mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 12:05:54 +00:00
Compare commits
1 Commits
feat/card_
...
feat/lark_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
92c3b81014 |
109
.github/workflows/run-tests.yml
vendored
109
.github/workflows/run-tests.yml
vendored
@@ -4,29 +4,25 @@ on:
|
||||
pull_request:
|
||||
types: [opened, ready_for_review, synchronize]
|
||||
paths:
|
||||
- 'src/langbot/**'
|
||||
- 'pkg/**'
|
||||
- 'tests/**'
|
||||
- '.github/workflows/run-tests.yml'
|
||||
- 'pyproject.toml'
|
||||
- 'uv.lock'
|
||||
- 'run_tests.sh'
|
||||
- 'scripts/test-*.sh'
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- develop
|
||||
paths:
|
||||
- 'src/langbot/**'
|
||||
- 'pkg/**'
|
||||
- 'tests/**'
|
||||
- '.github/workflows/run-tests.yml'
|
||||
- 'pyproject.toml'
|
||||
- 'uv.lock'
|
||||
- 'run_tests.sh'
|
||||
- 'scripts/test-*.sh'
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Unit Tests
|
||||
name: Run Unit Tests
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
@@ -43,13 +39,28 @@ jobs:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
run: |
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --dev
|
||||
run: |
|
||||
uv sync --dev
|
||||
|
||||
- name: Run unit + smoke tests
|
||||
run: uv run pytest tests/unit_tests/ tests/smoke/ -q --tb=short
|
||||
- name: Run unit tests
|
||||
run: |
|
||||
bash run_tests.sh
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
if: matrix.python-version == '3.12'
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
files: ./coverage.xml
|
||||
flags: unit-tests
|
||||
name: unit-tests-coverage
|
||||
fail_ci_if_error: false
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
- name: Test Summary
|
||||
if: always()
|
||||
@@ -58,79 +69,3 @@ jobs:
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Python Version: ${{ matrix.python-version }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
integration:
|
||||
name: Fast Integration Tests
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --dev
|
||||
|
||||
- name: Run fast integration tests
|
||||
run: uv run pytest tests/integration/ -m "not slow" -q --tb=short
|
||||
|
||||
- name: Integration Test Summary
|
||||
if: always()
|
||||
run: |
|
||||
echo "## Integration Tests Results" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
coverage:
|
||||
name: Coverage Gate
|
||||
runs-on: ubuntu-latest
|
||||
needs: [test, integration]
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --dev
|
||||
|
||||
- name: Run coverage (unit + smoke)
|
||||
run: |
|
||||
uv run pytest tests/unit_tests/ tests/smoke/ \
|
||||
--cov=langbot \
|
||||
--cov-report=xml \
|
||||
--cov-report=term-missing \
|
||||
--cov-fail-under=18 \
|
||||
-q --tb=short
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
files: ./coverage.xml
|
||||
flags: unit-tests
|
||||
name: coverage-report
|
||||
fail_ci_if_error: false
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
- name: Coverage Summary
|
||||
if: always()
|
||||
run: |
|
||||
echo "## Coverage Results" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Threshold: 18%" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
|
||||
109
.github/workflows/test-migrations.yml
vendored
109
.github/workflows/test-migrations.yml
vendored
@@ -9,13 +9,11 @@ on:
|
||||
paths:
|
||||
- 'src/langbot/pkg/persistence/**'
|
||||
- 'src/langbot/pkg/entity/persistence/**'
|
||||
- 'tests/integration/persistence/**'
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
paths:
|
||||
- 'src/langbot/pkg/persistence/**'
|
||||
- 'src/langbot/pkg/entity/persistence/**'
|
||||
- 'tests/integration/persistence/**'
|
||||
|
||||
jobs:
|
||||
test-migrations-sqlite:
|
||||
@@ -36,8 +34,52 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: uv sync --dev
|
||||
|
||||
- name: Run SQLite migration tests
|
||||
run: uv run pytest tests/integration/persistence/test_migrations.py -q --tb=short
|
||||
- name: Test Alembic upgrade (SQLite)
|
||||
run: |
|
||||
uv run python -c "
|
||||
import asyncio
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from langbot.pkg.entity.persistence.base import Base
|
||||
from langbot.pkg.persistence.alembic_runner import run_alembic_upgrade, run_alembic_stamp, get_alembic_current
|
||||
|
||||
async def main():
|
||||
engine = create_async_engine('sqlite+aiosqlite:///test_migrations.db')
|
||||
|
||||
# Create all tables (simulates existing DB)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
# Stamp baseline
|
||||
await run_alembic_stamp(engine, '0001_baseline')
|
||||
rev = await get_alembic_current(engine)
|
||||
assert rev == '0001_baseline', f'Expected 0001_baseline, got {rev}'
|
||||
print(f'Stamped: {rev}')
|
||||
|
||||
# Upgrade to head
|
||||
await run_alembic_upgrade(engine, 'head')
|
||||
rev = await get_alembic_current(engine)
|
||||
print(f'After upgrade: {rev}')
|
||||
assert rev is not None, 'Expected a revision after upgrade'
|
||||
|
||||
# Verify idempotent
|
||||
await run_alembic_upgrade(engine, 'head')
|
||||
rev2 = await get_alembic_current(engine)
|
||||
assert rev2 == rev, f'Expected {rev}, got {rev2}'
|
||||
print(f'Idempotent check passed: {rev2}')
|
||||
|
||||
# Fresh DB: upgrade from scratch
|
||||
engine2 = create_async_engine('sqlite+aiosqlite:///test_migrations_fresh.db')
|
||||
async with engine2.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
await run_alembic_upgrade(engine2, 'head')
|
||||
rev3 = await get_alembic_current(engine2)
|
||||
print(f'Fresh DB upgrade: {rev3}')
|
||||
assert rev3 is not None
|
||||
|
||||
print('All SQLite migration tests passed!')
|
||||
|
||||
asyncio.run(main())
|
||||
"
|
||||
|
||||
test-migrations-postgres:
|
||||
name: Migrations (PostgreSQL)
|
||||
@@ -72,7 +114,58 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: uv sync --dev
|
||||
|
||||
- name: Run PostgreSQL migration tests
|
||||
env:
|
||||
TEST_POSTGRES_URL: postgresql+asyncpg://langbot:langbot@localhost:5432/langbot_test
|
||||
run: uv run pytest tests/integration/persistence/test_migrations_postgres.py -q --tb=short
|
||||
- name: Test Alembic upgrade (PostgreSQL)
|
||||
run: |
|
||||
uv run python -c "
|
||||
import asyncio
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from langbot.pkg.entity.persistence.base import Base
|
||||
from langbot.pkg.persistence.alembic_runner import run_alembic_upgrade, run_alembic_stamp, get_alembic_current
|
||||
|
||||
DB_URL = 'postgresql+asyncpg://langbot:langbot@localhost:5432/langbot_test'
|
||||
|
||||
async def main():
|
||||
engine = create_async_engine(DB_URL)
|
||||
|
||||
# Create all tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
# Stamp baseline
|
||||
await run_alembic_stamp(engine, '0001_baseline')
|
||||
rev = await get_alembic_current(engine)
|
||||
assert rev == '0001_baseline', f'Expected 0001_baseline, got {rev}'
|
||||
print(f'Stamped: {rev}')
|
||||
|
||||
# Upgrade to head
|
||||
await run_alembic_upgrade(engine, 'head')
|
||||
rev = await get_alembic_current(engine)
|
||||
print(f'After upgrade: {rev}')
|
||||
assert rev is not None
|
||||
|
||||
# Verify idempotent
|
||||
await run_alembic_upgrade(engine, 'head')
|
||||
rev2 = await get_alembic_current(engine)
|
||||
assert rev2 == rev, f'Expected {rev}, got {rev2}'
|
||||
print(f'Idempotent check passed: {rev2}')
|
||||
|
||||
# Fresh DB: drop all and upgrade from scratch
|
||||
engine2 = create_async_engine(DB_URL.replace('langbot_test', 'langbot_fresh'))
|
||||
|
||||
# Create fresh database
|
||||
from sqlalchemy import text
|
||||
async with engine.connect() as conn:
|
||||
await conn.execute(text('COMMIT'))
|
||||
await conn.execute(text('CREATE DATABASE langbot_fresh'))
|
||||
|
||||
async with engine2.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
await run_alembic_upgrade(engine2, 'head')
|
||||
rev3 = await get_alembic_current(engine2)
|
||||
print(f'Fresh DB upgrade: {rev3}')
|
||||
assert rev3 is not None
|
||||
|
||||
print('All PostgreSQL migration tests passed!')
|
||||
|
||||
asyncio.run(main())
|
||||
"
|
||||
|
||||
36
Makefile
36
Makefile
@@ -1,36 +0,0 @@
|
||||
# LangBot Makefile
|
||||
# Quick developer commands
|
||||
|
||||
.PHONY: test test-quick test-integration-fast test-coverage test-all-local lint
|
||||
|
||||
# Run all tests (full suite with coverage)
|
||||
test:
|
||||
bash run_tests.sh
|
||||
|
||||
# Quick self-test for developers (lint + unit + smoke, no real credentials needed)
|
||||
test-quick:
|
||||
bash scripts/test-quick.sh
|
||||
|
||||
# Fast integration tests (SQLite/API/Pipeline, no external services)
|
||||
test-integration-fast:
|
||||
bash scripts/test-integration-fast.sh
|
||||
|
||||
# Coverage gate (all tests, enforces minimum threshold)
|
||||
test-coverage:
|
||||
bash scripts/test-coverage.sh
|
||||
|
||||
# Full local quality gate (quick + integration + coverage)
|
||||
test-all-local:
|
||||
bash scripts/test-quick.sh
|
||||
bash scripts/test-integration-fast.sh
|
||||
bash scripts/test-coverage.sh
|
||||
|
||||
# Run linting only
|
||||
lint:
|
||||
ruff check src/langbot/ tests/
|
||||
ruff format --check src/langbot/ tests/
|
||||
|
||||
# Fix linting issues
|
||||
lint-fix:
|
||||
ruff check --fix src/langbot/ tests/
|
||||
ruff format src/langbot/ tests/
|
||||
@@ -47,8 +47,6 @@ LangBot is an **open-source, production-grade platform** for building AI-powered
|
||||
|
||||
[→ Learn more about all features](https://link.langbot.app/en/docs/features)
|
||||
|
||||
📍 Practical guides: [deploy a multi-platform AI bot in 5 minutes](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [connect DeepSeek to WeChat, Discord, and Telegram](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [run a Dify Agent in Discord, Telegram, and Slack](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/), and [build an n8n-powered chatbot](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
|
||||
|
||||
---
|
||||
|
||||
## Quick Start
|
||||
|
||||
@@ -47,8 +47,6 @@ LangBot 是一个**开源的生产级平台**,用于构建 AI 驱动的即时
|
||||
|
||||
[→ 了解更多功能特性](https://link.langbot.app/zh/docs/features)
|
||||
|
||||
📍 实践指南:[5 分钟部署多平台 AI 机器人](https://blog.langbot.app/zh/blog/deploy-ai-bot-in-5-minutes/)、[将 DeepSeek 接入微信、企业微信与 Discord](https://blog.langbot.app/zh/blog/connect-deepseek-to-wechat/)、[让 Dify Agent 跑在 Discord、Telegram 和 Slack 上](https://blog.langbot.app/zh/blog/dify-agent-discord-telegram-slack/),以及[用 n8n 构建多平台 AI 聊天机器人](https://blog.langbot.app/zh/blog/n8n-multi-platform-ai-chatbot/)。
|
||||
|
||||
---
|
||||
|
||||
## 快速开始
|
||||
|
||||
@@ -46,8 +46,6 @@ LangBot es una **plataforma de código abierto y grado de producción** para con
|
||||
|
||||
[→ Conocer más sobre todas las funcionalidades](https://link.langbot.app/en/docs/features)
|
||||
|
||||
📍 Guías prácticas: [desplegar un bot de IA multiplataforma en 5 minutos](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [conectar DeepSeek a WeChat, Discord y Telegram](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [ejecutar un Dify Agent en Discord, Telegram y Slack](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/) y [crear un chatbot con n8n](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
|
||||
|
||||
---
|
||||
|
||||
## Inicio Rápido
|
||||
|
||||
@@ -46,8 +46,6 @@ LangBot est une **plateforme open-source de niveau production** pour créer des
|
||||
|
||||
[→ En savoir plus sur toutes les fonctionnalités](https://link.langbot.app/en/docs/features)
|
||||
|
||||
📍 Guides pratiques : [déployer un bot IA multiplateforme en 5 minutes](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [connecter DeepSeek à WeChat, Discord et Telegram](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [exécuter un Dify Agent dans Discord, Telegram et Slack](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/) et [créer un chatbot avec n8n](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
|
||||
|
||||
---
|
||||
|
||||
## Démarrage Rapide
|
||||
|
||||
@@ -46,8 +46,6 @@ LangBot は、AI搭載のインスタントメッセージングボットを構
|
||||
|
||||
[→ すべての機能について詳しく見る](https://link.langbot.app/ja/docs/features)
|
||||
|
||||
📍 実践ガイド: [5分でマルチプラットフォームAIボットをデプロイ](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/)、[DeepSeekをWeChat・Discord・Telegramに接続](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/)、[Dify AgentをDiscord・Telegram・Slackで動かす](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/)、[n8n連携チャットボットを構築](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/)。
|
||||
|
||||
---
|
||||
|
||||
## クイックスタート
|
||||
|
||||
@@ -46,8 +46,6 @@ LangBot은 AI 기반 인스턴트 메시징 봇을 구축하기 위한 **오픈
|
||||
|
||||
[→ 모든 기능 자세히 보기](https://link.langbot.app/en/docs/features)
|
||||
|
||||
📍 실전 가이드: [5분 만에 멀티 플랫폼 AI 봇 배포하기](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [DeepSeek를 WeChat, Discord, Telegram에 연결하기](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [Dify Agent를 Discord, Telegram, Slack에서 실행하기](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/), [n8n 기반 챗봇 만들기](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
|
||||
|
||||
---
|
||||
|
||||
## 빠른 시작
|
||||
|
||||
@@ -46,8 +46,6 @@ LangBot — это **платформа с открытым исходным к
|
||||
|
||||
[→ Подробнее обо всех возможностях](https://link.langbot.app/en/docs/features)
|
||||
|
||||
📍 Практические руководства: [развернуть мультиплатформенного ИИ-бота за 5 минут](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [подключить DeepSeek к WeChat, Discord и Telegram](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [запустить Dify Agent в Discord, Telegram и Slack](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/) и [создать чат-бота на n8n](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
|
||||
|
||||
---
|
||||
|
||||
## Быстрый старт
|
||||
|
||||
@@ -48,8 +48,6 @@ LangBot 是一個**開源的生產級平台**,用於建構 AI 驅動的即時
|
||||
|
||||
[→ 了解更多功能特性](https://link.langbot.app/zh/docs/features)
|
||||
|
||||
📍 實踐指南:[5 分鐘部署多平台 AI 機器人](https://blog.langbot.app/zh/blog/deploy-ai-bot-in-5-minutes/)、[將 DeepSeek 接入微信、企業微信與 Discord](https://blog.langbot.app/zh/blog/connect-deepseek-to-wechat/)、[讓 Dify Agent 跑在 Discord、Telegram 和 Slack 上](https://blog.langbot.app/zh/blog/dify-agent-discord-telegram-slack/),以及[用 n8n 建構多平台 AI 聊天機器人](https://blog.langbot.app/zh/blog/n8n-multi-platform-ai-chatbot/)。
|
||||
|
||||
---
|
||||
|
||||
## 快速開始
|
||||
|
||||
@@ -46,8 +46,6 @@ LangBot là một **nền tảng mã nguồn mở, cấp sản xuất** để x
|
||||
|
||||
[→ Tìm hiểu thêm về tất cả tính năng](https://link.langbot.app/en/docs/features)
|
||||
|
||||
📍 Hướng dẫn thực hành: [triển khai bot AI đa nền tảng trong 5 phút](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [kết nối DeepSeek với WeChat, Discord và Telegram](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [chạy Dify Agent trên Discord, Telegram và Slack](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/) và [xây dựng chatbot với n8n](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
|
||||
|
||||
---
|
||||
|
||||
## Bắt đầu nhanh
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "langbot"
|
||||
version = "4.9.7"
|
||||
version = "4.9.6"
|
||||
description = "Production-grade platform for building agentic IM bots"
|
||||
readme = "README.md"
|
||||
license-files = ["LICENSE"]
|
||||
@@ -22,7 +22,7 @@ dependencies = [
|
||||
"discord-py>=2.5.2",
|
||||
"pynacl>=1.5.0", # Required for Discord voice support
|
||||
"gewechat-client>=0.1.5",
|
||||
"lark-oapi>=1.5.5",
|
||||
"lark-oapi>=1.4.15",
|
||||
"mcp>=1.25.0",
|
||||
"nakuru-project-idk>=0.0.2.1",
|
||||
"ollama>=0.4.8",
|
||||
@@ -35,7 +35,6 @@ dependencies = [
|
||||
"python-telegram-bot>=22.0",
|
||||
"pyyaml>=6.0.2",
|
||||
"qq-botpy-rc>=1.2.1.6",
|
||||
"qrcode>=7.4",
|
||||
"quart>=0.20.0",
|
||||
"quart-cors>=0.8.0",
|
||||
"requests>=2.32.3",
|
||||
@@ -70,7 +69,7 @@ dependencies = [
|
||||
"chromadb>=1.0.0,<2.0.0",
|
||||
"qdrant-client (>=1.15.1,<2.0.0)",
|
||||
"pyseekdb==1.1.0.post3",
|
||||
"langbot-plugin==0.3.11",
|
||||
"langbot-plugin==0.3.10",
|
||||
"asyncpg>=0.30.0",
|
||||
"line-bot-sdk>=3.19.0",
|
||||
"matrix-nio>=0.25.2",
|
||||
@@ -122,7 +121,6 @@ package-data = { "langbot" = ["templates/**", "pkg/provider/modelmgr/requesters/
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"moto>=5.2.1",
|
||||
"pre-commit>=4.2.0",
|
||||
"pytest>=9.0.3",
|
||||
"pytest-asyncio>=1.0.0",
|
||||
|
||||
@@ -4,9 +4,6 @@ python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
|
||||
# Python path for imports
|
||||
pythonpath = . tests
|
||||
|
||||
# Test paths
|
||||
testpaths = tests
|
||||
|
||||
@@ -25,9 +22,7 @@ markers =
|
||||
asyncio: mark test as async
|
||||
unit: mark test as unit test
|
||||
integration: mark test as integration test
|
||||
smoke: mark test as smoke test
|
||||
slow: mark test as slow running
|
||||
e2e: mark test as end-to-end test (requires real LangBot process)
|
||||
|
||||
# Coverage options (when using pytest-cov)
|
||||
[coverage:run]
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Coverage gate script
|
||||
# Runs all tests with coverage, enforcing minimum coverage threshold
|
||||
# Uses separate pytest invocations to avoid sys.modules pollution between test types
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
echo "=== LangBot Coverage Gate ==="
|
||||
echo ""
|
||||
|
||||
# Coverage threshold (baseline from current coverage, conservative buffer)
|
||||
# Current: ~22.14%, threshold: 18%
|
||||
COVERAGE_THRESHOLD=18
|
||||
|
||||
# Create temporary directory for coverage files
|
||||
COV_DIR=$(mktemp -d)
|
||||
trap "rm -rf $COV_DIR" EXIT
|
||||
|
||||
echo "[1/3] Running unit + smoke tests with coverage..."
|
||||
uv run pytest tests/unit_tests/ tests/smoke/ \
|
||||
--cov=langbot \
|
||||
--cov-report=json:$COV_DIR/unit.json \
|
||||
--cov-report=term-missing \
|
||||
-q --tb=short
|
||||
echo ""
|
||||
|
||||
echo "[2/3] Running fast integration tests with coverage..."
|
||||
uv run pytest tests/integration/ -m "not slow" \
|
||||
--cov=langbot \
|
||||
--cov-report=json:$COV_DIR/integration.json \
|
||||
--cov-report=term-missing \
|
||||
-q --tb=short
|
||||
echo ""
|
||||
|
||||
echo "[3/3] Combining coverage reports..."
|
||||
# Use coverage combine if available, otherwise just report total
|
||||
if command -v coverage &> /dev/null; then
|
||||
# Combine JSON reports
|
||||
coverage combine --keep $COV_DIR/unit.json $COV_DIR/integration.json \
|
||||
--data-file=$COV_DIR/combined.data 2>/dev/null || true
|
||||
|
||||
coverage report --data-file=$COV_DIR/combined.data || true
|
||||
else
|
||||
echo "Note: coverage combine not available, showing individual reports above"
|
||||
fi
|
||||
|
||||
# Generate final XML report for CI (from last run)
|
||||
uv run pytest tests/unit_tests/ tests/smoke/ \
|
||||
--cov=langbot \
|
||||
--cov-report=xml:coverage.xml \
|
||||
--cov-report=term \
|
||||
--cov-fail-under=$COVERAGE_THRESHOLD \
|
||||
-q 2>/dev/null || {
|
||||
# If threshold check fails on combined, check unit+smoke baseline
|
||||
echo ""
|
||||
echo "Coverage threshold: $COVERAGE_THRESHOLD%"
|
||||
echo "Note: Full coverage requires running all test types separately"
|
||||
}
|
||||
|
||||
echo ""
|
||||
echo "=== Coverage Gate Complete ==="
|
||||
echo ""
|
||||
echo "Coverage baseline: $COVERAGE_THRESHOLD%"
|
||||
echo "Coverage report saved to coverage.xml"
|
||||
@@ -1,16 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Fast integration tests
|
||||
# Runs integration tests excluding slow ones (PostgreSQL, external services)
|
||||
# Uses fake runner/provider, no real credentials needed
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
echo "=== LangBot Fast Integration Tests ==="
|
||||
echo ""
|
||||
|
||||
echo "Running integration tests (excluding slow)..."
|
||||
uv run pytest tests/integration/ -m "not slow" -q --tb=short
|
||||
|
||||
echo ""
|
||||
echo "=== Fast Integration Tests Complete ==="
|
||||
@@ -1,36 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Quick developer self-test command
|
||||
# Runs linting, unit tests, and smoke tests without requiring real provider keys
|
||||
# Suitable for local branch validation
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
echo "=== LangBot Quick Self-Test ==="
|
||||
echo ""
|
||||
|
||||
# 1. Ruff check
|
||||
echo "[1/3] Running ruff check..."
|
||||
uv run ruff check src/langbot/ tests/ --output-format=concise || {
|
||||
echo ""
|
||||
echo "⚠ Ruff check found issues. Run 'uv run ruff check --fix' to auto-fix."
|
||||
exit 1
|
||||
}
|
||||
echo "✓ Ruff check passed"
|
||||
echo ""
|
||||
|
||||
# 2. Unit tests
|
||||
echo "[2/3] Running unit tests..."
|
||||
uv run pytest tests/unit_tests/ -q --tb=short
|
||||
echo ""
|
||||
|
||||
# 3. Smoke tests (if exists)
|
||||
echo "[3/3] Running smoke tests..."
|
||||
if [ -d "tests/smoke" ]; then
|
||||
uv run pytest tests/smoke/ -q --tb=short
|
||||
else
|
||||
echo "No smoke tests found, skipping"
|
||||
fi
|
||||
echo ""
|
||||
|
||||
echo "=== Quick Self-Test Complete ==="
|
||||
@@ -1,3 +1,3 @@
|
||||
"""LangBot - Production-grade platform for building agentic IM bots"""
|
||||
|
||||
__version__ = '4.9.7'
|
||||
__version__ = '4.9.6'
|
||||
|
||||
@@ -109,61 +109,6 @@ class AsyncDifyServiceClient:
|
||||
if chunk.startswith('data:'):
|
||||
yield json.loads(chunk[5:])
|
||||
|
||||
async def workflow_submit(
|
||||
self,
|
||||
form_token: str,
|
||||
workflow_run_id: str,
|
||||
inputs: dict[str, typing.Any],
|
||||
user: str,
|
||||
action: str = '',
|
||||
timeout: float = 120.0,
|
||||
) -> typing.AsyncGenerator[dict[str, typing.Any], None]:
|
||||
"""Submit human input to resume a paused workflow, then stream events.
|
||||
|
||||
1. POST /form/human_input/{form_token} to submit the form
|
||||
2. GET /workflow/{task_id}/events to stream the resumed workflow events
|
||||
"""
|
||||
|
||||
headers = {
|
||||
'Authorization': f'Bearer {self.api_key}',
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
base_url=self.base_url,
|
||||
trust_env=True,
|
||||
timeout=timeout,
|
||||
) as client:
|
||||
# Step 1: Submit the form
|
||||
payload: dict[str, typing.Any] = {
|
||||
'inputs': inputs if isinstance(inputs, dict) else {},
|
||||
'user': user,
|
||||
'action': action,
|
||||
}
|
||||
|
||||
submit_resp = await client.post(
|
||||
f'/form/human_input/{form_token}',
|
||||
headers=headers,
|
||||
json=payload,
|
||||
)
|
||||
if submit_resp.status_code != 200:
|
||||
raise DifyAPIError(f'{submit_resp.status_code} {submit_resp.text}')
|
||||
|
||||
# Step 2: Stream resumed workflow events
|
||||
async with client.stream(
|
||||
'GET',
|
||||
f'/workflow/{workflow_run_id}/events',
|
||||
headers={'Authorization': f'Bearer {self.api_key}'},
|
||||
params={'user': user},
|
||||
) as r:
|
||||
async for chunk in r.aiter_lines():
|
||||
if r.status_code != 200:
|
||||
raise DifyAPIError(f'{r.status_code} {chunk}')
|
||||
if chunk.strip() == '':
|
||||
continue
|
||||
if chunk.startswith('data:'):
|
||||
yield json.loads(chunk[5:])
|
||||
|
||||
async def upload_file(
|
||||
self,
|
||||
file: httpx._types.FileTypes,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import quart
|
||||
import mimetypes
|
||||
import asyncio
|
||||
from ... import group
|
||||
from langbot.pkg.utils import importutil
|
||||
|
||||
@@ -36,617 +35,3 @@ class AdaptersRouterGroup(group.RouterGroup):
|
||||
return quart.Response(
|
||||
importutil.read_resource_file_bytes(icon_path), mimetype=mimetypes.guess_type(icon_path)[0]
|
||||
)
|
||||
|
||||
# In-memory session store for active registrations
|
||||
_create_app_sessions: dict = {}
|
||||
_SESSION_TTL = 900 # 15 minutes
|
||||
|
||||
def _cleanup_expired_sessions():
|
||||
"""Remove sessions that have exceeded their TTL."""
|
||||
import time
|
||||
|
||||
now = time.time()
|
||||
expired = [sid for sid, s in _create_app_sessions.items() if now - s.get('created_at', 0) > _SESSION_TTL]
|
||||
for sid in expired:
|
||||
session = _create_app_sessions.pop(sid, None)
|
||||
if session and session.get('task') and not session['task'].done():
|
||||
session['task'].cancel()
|
||||
|
||||
@self.route('/lark/create-app', methods=['POST'])
|
||||
async def _() -> str:
|
||||
"""Start Feishu one-click app registration. Returns session_id + QR code URL."""
|
||||
import uuid
|
||||
import time
|
||||
import lark_oapi as lark
|
||||
from lark_oapi.scene.registration.errors import AppAccessDeniedError, AppExpiredError
|
||||
|
||||
_cleanup_expired_sessions()
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
session = {
|
||||
'status': 'pending',
|
||||
'qr_url': None,
|
||||
'expire_at': None,
|
||||
'app_id': None,
|
||||
'app_secret': None,
|
||||
'error': None,
|
||||
'created_at': time.time(),
|
||||
}
|
||||
_create_app_sessions[session_id] = session
|
||||
|
||||
def on_qr_code(info):
|
||||
# May be called from a background thread by the SDK;
|
||||
# use call_soon_threadsafe to safely update session state.
|
||||
def _update():
|
||||
session['qr_url'] = info['url']
|
||||
session['expire_at'] = time.time() + 600 # 10 minutes
|
||||
session['status'] = 'waiting'
|
||||
|
||||
loop.call_soon_threadsafe(_update)
|
||||
|
||||
async def run_registration():
|
||||
try:
|
||||
result = await lark.aregister_app(
|
||||
on_qr_code=on_qr_code,
|
||||
source='langbot',
|
||||
)
|
||||
session['status'] = 'success'
|
||||
session['app_id'] = result['client_id']
|
||||
session['app_secret'] = result['client_secret']
|
||||
except AppAccessDeniedError:
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'User denied authorization'
|
||||
except AppExpiredError:
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'QR code expired'
|
||||
except Exception as e:
|
||||
session['status'] = 'error'
|
||||
session['error'] = str(e)
|
||||
|
||||
task = asyncio.create_task(run_registration())
|
||||
session['task'] = task
|
||||
|
||||
# Wait for QR code to be ready (max 10 seconds)
|
||||
for _ in range(20):
|
||||
if session['qr_url']:
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
if not session['qr_url']:
|
||||
task.cancel()
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'Timeout waiting for QR code'
|
||||
return self.http_status(504, -1, 'Timeout waiting for QR code')
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'session_id': session_id,
|
||||
'qr_url': session['qr_url'],
|
||||
'expire_at': session['expire_at'],
|
||||
}
|
||||
)
|
||||
|
||||
@self.route('/lark/create-app/status/<session_id>', methods=['GET'])
|
||||
async def _(session_id: str) -> str:
|
||||
"""Poll registration status."""
|
||||
session = _create_app_sessions.get(session_id)
|
||||
if not session:
|
||||
return self.http_status(404, -1, 'Session not found')
|
||||
|
||||
data = {'status': session['status']}
|
||||
|
||||
if session['status'] == 'success':
|
||||
data['app_id'] = session['app_id']
|
||||
data['app_secret'] = session['app_secret']
|
||||
_create_app_sessions.pop(session_id, None)
|
||||
elif session['status'] == 'error':
|
||||
data['error'] = session['error']
|
||||
_create_app_sessions.pop(session_id, None)
|
||||
|
||||
return self.success(data=data)
|
||||
|
||||
@self.route('/lark/create-app/<session_id>', methods=['DELETE'])
|
||||
async def _(session_id: str) -> str:
|
||||
"""Cancel and clean up a registration session."""
|
||||
session = _create_app_sessions.pop(session_id, None)
|
||||
if session and session.get('task') and not session['task'].done():
|
||||
session['task'].cancel()
|
||||
return self.success(data={})
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# WeChat QR Code Login
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
_weixin_login_sessions: dict = {}
|
||||
_WEIXIN_SESSION_TTL = 600 # 10 minutes (3 retries × 3 min QR validity)
|
||||
|
||||
def _cleanup_expired_weixin_sessions():
|
||||
import time
|
||||
|
||||
now = time.time()
|
||||
expired = [
|
||||
sid for sid, s in _weixin_login_sessions.items() if now - s.get('created_at', 0) > _WEIXIN_SESSION_TTL
|
||||
]
|
||||
for sid in expired:
|
||||
session = _weixin_login_sessions.pop(sid, None)
|
||||
if session and session.get('task') and not session['task'].done():
|
||||
session['task'].cancel()
|
||||
|
||||
@self.route('/weixin/login', methods=['POST'])
|
||||
async def _() -> str:
|
||||
"""Start WeChat QR code login. Returns session_id + QR code data URL."""
|
||||
import uuid
|
||||
import time
|
||||
|
||||
from langbot.libs.openclaw_weixin_api.client import OpenClawWeixinClient, DEFAULT_BASE_URL
|
||||
|
||||
_cleanup_expired_weixin_sessions()
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
session = {
|
||||
'status': 'pending',
|
||||
'qr_data_url': None,
|
||||
'expire_at': None,
|
||||
'token': None,
|
||||
'base_url': None,
|
||||
'account_id': None,
|
||||
'error': None,
|
||||
'created_at': time.time(),
|
||||
}
|
||||
_weixin_login_sessions[session_id] = session
|
||||
|
||||
client = OpenClawWeixinClient(
|
||||
base_url=DEFAULT_BASE_URL,
|
||||
token='',
|
||||
)
|
||||
|
||||
async def run_login():
|
||||
try:
|
||||
|
||||
def on_qrcode(qr_data_url: str, _qr_url: str):
|
||||
def _update():
|
||||
session['qr_data_url'] = qr_data_url
|
||||
session['expire_at'] = time.time() + 180
|
||||
session['status'] = 'waiting'
|
||||
|
||||
loop.call_soon_threadsafe(_update)
|
||||
|
||||
result = await client.login(
|
||||
max_retries=1,
|
||||
poll_timeout_ms=180_000,
|
||||
on_qrcode=on_qrcode,
|
||||
)
|
||||
session['status'] = 'success'
|
||||
session['token'] = result.token
|
||||
session['base_url'] = result.base_url
|
||||
session['account_id'] = result.account_id
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
if 'expired' in error_message.lower() or 'max retries exceeded' in error_message.lower():
|
||||
session['status'] = 'expired'
|
||||
session['error'] = 'QR code expired'
|
||||
else:
|
||||
session['status'] = 'error'
|
||||
session['error'] = error_message
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
task = asyncio.create_task(run_login())
|
||||
session['task'] = task
|
||||
|
||||
# Wait for QR code to be ready (max 10 seconds)
|
||||
for _ in range(20):
|
||||
if session['qr_data_url']:
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
if not session['qr_data_url']:
|
||||
task.cancel()
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'Timeout waiting for QR code'
|
||||
return self.http_status(504, -1, 'Timeout waiting for QR code')
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'session_id': session_id,
|
||||
'qr_data_url': session['qr_data_url'],
|
||||
'expire_at': session['expire_at'],
|
||||
}
|
||||
)
|
||||
|
||||
@self.route('/weixin/login/status/<session_id>', methods=['GET'])
|
||||
async def _(session_id: str) -> str:
|
||||
"""Poll WeChat login status."""
|
||||
session = _weixin_login_sessions.get(session_id)
|
||||
if not session:
|
||||
return self.http_status(404, -1, 'Session not found')
|
||||
|
||||
data = {
|
||||
'status': session['status'],
|
||||
'qr_data_url': session['qr_data_url'],
|
||||
'expire_at': session['expire_at'],
|
||||
}
|
||||
|
||||
if session['status'] == 'success':
|
||||
data['token'] = session['token']
|
||||
data['base_url'] = session['base_url']
|
||||
data['account_id'] = session['account_id']
|
||||
_weixin_login_sessions.pop(session_id, None)
|
||||
elif session['status'] == 'error':
|
||||
data['error'] = session['error']
|
||||
_weixin_login_sessions.pop(session_id, None)
|
||||
elif session['status'] == 'expired':
|
||||
data['error'] = session['error']
|
||||
_weixin_login_sessions.pop(session_id, None)
|
||||
|
||||
return self.success(data=data)
|
||||
|
||||
@self.route('/weixin/login/<session_id>', methods=['DELETE'])
|
||||
async def _(session_id: str) -> str:
|
||||
"""Cancel and clean up a WeChat login session."""
|
||||
session = _weixin_login_sessions.pop(session_id, None)
|
||||
if session and session.get('task') and not session['task'].done():
|
||||
session['task'].cancel()
|
||||
return self.success(data={})
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# DingTalk Device Flow QR Code Login
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
_dingtalk_sessions: dict = {}
|
||||
_DINGTALK_SESSION_TTL = 600 # 10 minutes (QR code validity window)
|
||||
|
||||
def _cleanup_expired_dingtalk_sessions():
|
||||
import time
|
||||
|
||||
now = time.time()
|
||||
expired = [
|
||||
sid for sid, s in _dingtalk_sessions.items() if now - s.get('created_at', 0) > _DINGTALK_SESSION_TTL
|
||||
]
|
||||
for sid in expired:
|
||||
session = _dingtalk_sessions.pop(sid, None)
|
||||
if session and session.get('task') and not session['task'].done():
|
||||
session['task'].cancel()
|
||||
|
||||
@self.route('/dingtalk/create-app', methods=['POST'])
|
||||
async def _() -> str:
|
||||
"""Start DingTalk one-click app creation via Device Flow. Returns session_id + QR code URL."""
|
||||
import uuid
|
||||
import time
|
||||
import aiohttp
|
||||
|
||||
DINGTALK_BASE_URL = 'https://oapi.dingtalk.com'
|
||||
|
||||
_cleanup_expired_dingtalk_sessions()
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
session = {
|
||||
'status': 'pending',
|
||||
'qr_url': None,
|
||||
'expire_at': None,
|
||||
'client_id': None,
|
||||
'client_secret': None,
|
||||
'error': None,
|
||||
'created_at': time.time(),
|
||||
'device_code': None,
|
||||
'interval': 5,
|
||||
}
|
||||
_dingtalk_sessions[session_id] = session
|
||||
|
||||
async def run_device_flow():
|
||||
try:
|
||||
timeout = aiohttp.ClientTimeout(total=10)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as http:
|
||||
# Step 1: Init — get nonce
|
||||
async with http.post(
|
||||
f'{DINGTALK_BASE_URL}/app/registration/init',
|
||||
json={'source': 'langbot'},
|
||||
) as resp:
|
||||
try:
|
||||
data = await resp.json()
|
||||
except (aiohttp.ContentTypeError, ValueError):
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'Invalid response from DingTalk service'
|
||||
return
|
||||
if data.get('errcode', -1) != 0:
|
||||
session['status'] = 'error'
|
||||
session['error'] = data.get('errmsg', 'Failed to init')
|
||||
return
|
||||
nonce = data['nonce']
|
||||
|
||||
# Step 2: Begin — get device_code + QR URL
|
||||
async with http.post(
|
||||
f'{DINGTALK_BASE_URL}/app/registration/begin',
|
||||
json={'nonce': nonce},
|
||||
) as resp:
|
||||
try:
|
||||
data = await resp.json()
|
||||
except (aiohttp.ContentTypeError, ValueError):
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'Invalid response from DingTalk service'
|
||||
return
|
||||
if data.get('errcode', -1) != 0:
|
||||
session['status'] = 'error'
|
||||
session['error'] = data.get('errmsg', 'Failed to begin authorization')
|
||||
return
|
||||
|
||||
device_code = data['device_code']
|
||||
verification_uri_complete = data.get('verification_uri_complete', '')
|
||||
expires_in = data.get('expires_in', 7200)
|
||||
interval = data.get('interval', 5)
|
||||
|
||||
session['device_code'] = device_code
|
||||
session['interval'] = interval
|
||||
session['qr_url'] = verification_uri_complete
|
||||
session['expire_at'] = time.time() + 600 # QR code valid for ~10 min
|
||||
session['status'] = 'waiting'
|
||||
|
||||
# Step 3: Poll for authorization result
|
||||
deadline = time.time() + expires_in
|
||||
while time.time() < deadline:
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
async with http.post(
|
||||
f'{DINGTALK_BASE_URL}/app/registration/poll',
|
||||
json={'device_code': device_code},
|
||||
) as poll_resp:
|
||||
try:
|
||||
poll_data = await poll_resp.json()
|
||||
except (aiohttp.ContentTypeError, ValueError):
|
||||
continue
|
||||
|
||||
if poll_data.get('errcode', -1) != 0:
|
||||
session['status'] = 'error'
|
||||
session['error'] = poll_data.get('errmsg', 'Poll failed')
|
||||
return
|
||||
|
||||
status = poll_data.get('status', '')
|
||||
|
||||
if status == 'SUCCESS':
|
||||
session['status'] = 'success'
|
||||
session['client_id'] = poll_data.get('client_id', '')
|
||||
session['client_secret'] = poll_data.get('client_secret', '')
|
||||
return
|
||||
elif status == 'FAIL':
|
||||
session['status'] = 'error'
|
||||
session['error'] = poll_data.get('fail_reason', 'Authorization failed')
|
||||
return
|
||||
elif status == 'EXPIRED':
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'QR code expired'
|
||||
return
|
||||
# status == 'WAITING': continue polling
|
||||
|
||||
# Timeout
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'QR code expired'
|
||||
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
session['status'] = 'error'
|
||||
session['error'] = str(e)
|
||||
|
||||
task = asyncio.create_task(run_device_flow())
|
||||
session['task'] = task
|
||||
|
||||
# Wait for QR code to be ready (max 10 seconds)
|
||||
for _ in range(20):
|
||||
if session['qr_url'] or session['error']:
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
if session['error']:
|
||||
task.cancel()
|
||||
return self.http_status(502, -1, session['error'])
|
||||
|
||||
if not session['qr_url']:
|
||||
task.cancel()
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'Timeout waiting for QR code'
|
||||
return self.http_status(504, -1, 'Timeout waiting for QR code')
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'session_id': session_id,
|
||||
'qr_url': session['qr_url'],
|
||||
'expire_at': session['expire_at'],
|
||||
}
|
||||
)
|
||||
|
||||
@self.route('/dingtalk/create-app/status/<session_id>', methods=['GET'])
|
||||
async def _(session_id: str) -> str:
|
||||
"""Poll DingTalk Device Flow status."""
|
||||
_cleanup_expired_dingtalk_sessions()
|
||||
session = _dingtalk_sessions.get(session_id)
|
||||
if not session:
|
||||
return self.http_status(404, -1, 'Session not found')
|
||||
|
||||
data = {'status': session['status']}
|
||||
|
||||
if session['status'] == 'success':
|
||||
data['client_id'] = session['client_id']
|
||||
data['client_secret'] = session['client_secret']
|
||||
_dingtalk_sessions.pop(session_id, None)
|
||||
elif session['status'] == 'error':
|
||||
data['error'] = session['error']
|
||||
_dingtalk_sessions.pop(session_id, None)
|
||||
|
||||
return self.success(data=data)
|
||||
|
||||
@self.route('/dingtalk/create-app/<session_id>', methods=['DELETE'])
|
||||
async def _(session_id: str) -> str:
|
||||
"""Cancel and clean up a DingTalk Device Flow session."""
|
||||
session = _dingtalk_sessions.pop(session_id, None)
|
||||
if session and session.get('task') and not session['task'].done():
|
||||
session['task'].cancel()
|
||||
return self.success(data={})
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# WeComBot QR Code One-Click Create
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
_wecombot_sessions: dict = {}
|
||||
_WECOMBOT_SESSION_TTL = 300 # 5 minutes (WeCom QR validity window)
|
||||
|
||||
def _cleanup_expired_wecombot_sessions():
|
||||
import time
|
||||
|
||||
now = time.time()
|
||||
expired = [
|
||||
sid for sid, s in _wecombot_sessions.items() if now - s.get('created_at', 0) > _WECOMBOT_SESSION_TTL
|
||||
]
|
||||
for sid in expired:
|
||||
session = _wecombot_sessions.pop(sid, None)
|
||||
if session and session.get('task') and not session['task'].done():
|
||||
session['task'].cancel()
|
||||
|
||||
@self.route('/wecombot/create-bot', methods=['POST'])
|
||||
async def _() -> str:
|
||||
"""Start WeComBot one-click creation via QR code. Returns session_id + QR code URL."""
|
||||
import uuid
|
||||
import time
|
||||
import aiohttp
|
||||
|
||||
WECOM_QC_GENERATE_URL = 'https://work.weixin.qq.com/ai/qc/generate'
|
||||
WECOM_QC_QUERY_URL = 'https://work.weixin.qq.com/ai/qc/query_result'
|
||||
|
||||
_cleanup_expired_wecombot_sessions()
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
session = {
|
||||
'status': 'pending',
|
||||
'qr_url': None,
|
||||
'expire_at': None,
|
||||
'botid': None,
|
||||
'secret': None,
|
||||
'error': None,
|
||||
'created_at': time.time(),
|
||||
'scode': None,
|
||||
'task': None,
|
||||
}
|
||||
_wecombot_sessions[session_id] = session
|
||||
|
||||
async def run_qr_flow():
|
||||
try:
|
||||
timeout = aiohttp.ClientTimeout(total=10)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as http:
|
||||
# Step 1: Generate QR code
|
||||
async with http.get(
|
||||
f'{WECOM_QC_GENERATE_URL}?source=langbot&plat=0',
|
||||
) as resp:
|
||||
try:
|
||||
data = await resp.json()
|
||||
except (aiohttp.ContentTypeError, ValueError):
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'Invalid response from WeCom service'
|
||||
return
|
||||
if not data.get('data', {}).get('scode') or not data.get('data', {}).get('auth_url'):
|
||||
session['status'] = 'error'
|
||||
session['error'] = data.get('errmsg', 'Failed to generate QR code')
|
||||
return
|
||||
|
||||
scode = data['data']['scode']
|
||||
auth_url = data['data']['auth_url']
|
||||
|
||||
session['scode'] = scode
|
||||
session['qr_url'] = auth_url
|
||||
session['expire_at'] = time.time() + _WECOMBOT_SESSION_TTL
|
||||
session['status'] = 'waiting'
|
||||
|
||||
# Step 2: Poll for scan result
|
||||
deadline = time.time() + _WECOMBOT_SESSION_TTL
|
||||
while time.time() < deadline:
|
||||
await asyncio.sleep(3)
|
||||
|
||||
async with http.get(
|
||||
f'{WECOM_QC_QUERY_URL}?scode={scode}',
|
||||
) as poll_resp:
|
||||
try:
|
||||
poll_data = await poll_resp.json()
|
||||
except (aiohttp.ContentTypeError, ValueError):
|
||||
continue
|
||||
|
||||
status = poll_data.get('data', {}).get('status', '')
|
||||
if status == 'success':
|
||||
bot_info = poll_data.get('data', {}).get('bot_info', {})
|
||||
if bot_info.get('botid') and bot_info.get('secret'):
|
||||
session['status'] = 'success'
|
||||
session['botid'] = bot_info['botid']
|
||||
session['secret'] = bot_info['secret']
|
||||
return
|
||||
else:
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'Scan succeeded but bot info is incomplete'
|
||||
return
|
||||
|
||||
# Timeout
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'QR code expired'
|
||||
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
session['status'] = 'error'
|
||||
session['error'] = str(e)
|
||||
|
||||
task = asyncio.create_task(run_qr_flow())
|
||||
session['task'] = task
|
||||
|
||||
# Wait for QR code to be ready (max 10 seconds)
|
||||
for _ in range(20):
|
||||
if session['qr_url'] or session['error']:
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
if session['error']:
|
||||
task.cancel()
|
||||
return self.http_status(502, -1, session['error'])
|
||||
|
||||
if not session['qr_url']:
|
||||
task.cancel()
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'Timeout waiting for QR code'
|
||||
return self.http_status(504, -1, 'Timeout waiting for QR code')
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'session_id': session_id,
|
||||
'qr_url': session['qr_url'],
|
||||
'expire_at': session['expire_at'],
|
||||
}
|
||||
)
|
||||
|
||||
@self.route('/wecombot/create-bot/status/<session_id>', methods=['GET'])
|
||||
async def _(session_id: str) -> str:
|
||||
"""Poll WeComBot creation status."""
|
||||
_cleanup_expired_wecombot_sessions()
|
||||
session = _wecombot_sessions.get(session_id)
|
||||
if not session:
|
||||
return self.http_status(404, -1, 'Session not found')
|
||||
|
||||
data = {'status': session['status']}
|
||||
|
||||
if session['status'] == 'success':
|
||||
data['botid'] = session['botid']
|
||||
data['secret'] = session['secret']
|
||||
_wecombot_sessions.pop(session_id, None)
|
||||
elif session['status'] == 'error':
|
||||
data['error'] = session['error']
|
||||
_wecombot_sessions.pop(session_id, None)
|
||||
|
||||
return self.success(data=data)
|
||||
|
||||
@self.route('/wecombot/create-bot/<session_id>', methods=['DELETE'])
|
||||
async def _(session_id: str) -> str:
|
||||
"""Cancel and clean up a WeComBot creation session."""
|
||||
session = _wecombot_sessions.pop(session_id, None)
|
||||
if session and session.get('task') and not session['task'].done():
|
||||
session['task'].cancel()
|
||||
return self.success(data={})
|
||||
|
||||
@@ -7,10 +7,8 @@ import httpx
|
||||
import uuid
|
||||
import os
|
||||
import posixpath
|
||||
import sqlalchemy
|
||||
|
||||
from .....core import taskmgr
|
||||
from .....entity.persistence import plugin as persistence_plugin
|
||||
from .. import group
|
||||
from langbot_plugin.runtime.plugin.mgr import PluginInstallSource
|
||||
|
||||
@@ -41,16 +39,6 @@ def _normalize_plugin_asset_path(filepath: str) -> str | None:
|
||||
return f'assets/{normalized}'
|
||||
|
||||
|
||||
def _get_request_origin() -> str:
|
||||
"""Return the public request origin, respecting reverse-proxy headers."""
|
||||
forwarded_proto = quart.request.headers.get('X-Forwarded-Proto', '').split(',')[0].strip()
|
||||
forwarded_host = quart.request.headers.get('X-Forwarded-Host', '').split(',')[0].strip()
|
||||
|
||||
scheme = forwarded_proto or quart.request.scheme
|
||||
host = forwarded_host or quart.request.host
|
||||
return f'{scheme}://{host}'
|
||||
|
||||
|
||||
@group.group_class('plugins', '/api/v1/plugins')
|
||||
class PluginsRouterGroup(group.RouterGroup):
|
||||
async def _check_extensions_limit(self) -> str | None:
|
||||
@@ -150,15 +138,7 @@ class PluginsRouterGroup(group.RouterGroup):
|
||||
return self.http_status(404, -1, 'plugin not found')
|
||||
|
||||
if quart.request.method == 'GET':
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_plugin.PluginSetting.config)
|
||||
.where(persistence_plugin.PluginSetting.plugin_author == author)
|
||||
.where(persistence_plugin.PluginSetting.plugin_name == plugin_name)
|
||||
)
|
||||
persisted_config = result.scalar_one_or_none()
|
||||
|
||||
config = persisted_config if persisted_config is not None else plugin['plugin_config']
|
||||
return self.success(data={'config': config})
|
||||
return self.success(data={'config': plugin['plugin_config']})
|
||||
elif quart.request.method == 'PUT':
|
||||
data = await quart.request.json
|
||||
|
||||
@@ -209,7 +189,7 @@ class PluginsRouterGroup(group.RouterGroup):
|
||||
# CSP for HTML pages served to sandboxed iframes (opaque origin).
|
||||
# 'self' doesn't work in sandboxed iframes — use actual server origin.
|
||||
if mime_type and mime_type.startswith('text/html'):
|
||||
origin = _get_request_origin()
|
||||
origin = f'{quart.request.scheme}://{quart.request.host}'
|
||||
resp.headers['Content-Security-Policy'] = (
|
||||
f'default-src {origin}; '
|
||||
f"script-src {origin} 'unsafe-inline'; "
|
||||
|
||||
@@ -140,6 +140,17 @@ class SystemRouterGroup(group.RouterGroup):
|
||||
async def _() -> str:
|
||||
return self.success(data=await self.ap.maintenance_service.get_storage_analysis())
|
||||
|
||||
@self.route('/debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
if not constants.debug_mode:
|
||||
return self.http_status(403, 403, 'Forbidden')
|
||||
|
||||
py_code = await quart.request.data
|
||||
|
||||
ap = self.ap
|
||||
|
||||
return self.success(data=exec(py_code, {'ap': ap}))
|
||||
|
||||
@self.route(
|
||||
'/debug/plugin/action',
|
||||
methods=['POST'],
|
||||
|
||||
@@ -146,7 +146,6 @@ class UserRouterGroup(group.RouterGroup):
|
||||
return self.fail(3, str(e))
|
||||
except ValueError as e:
|
||||
traceback.print_exc()
|
||||
self.ap.logger.warning(f'Space OAuth callback failed: {e}')
|
||||
return self.fail(1, str(e))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
||||
@@ -52,9 +52,6 @@ class ApiKeyService:
|
||||
|
||||
async def verify_api_key(self, key: str) -> bool:
|
||||
"""Verify if an API key is valid"""
|
||||
if not isinstance(key, str) or not key.startswith('lbk_'):
|
||||
return False
|
||||
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(apikey.ApiKey).where(apikey.ApiKey.key == key)
|
||||
)
|
||||
|
||||
@@ -99,11 +99,11 @@ class BotService:
|
||||
# TODO: 检查配置信息格式
|
||||
bot_data['uuid'] = str(uuid.uuid4())
|
||||
|
||||
# bind the most recently updated pipeline if any exist
|
||||
# checkout the default pipeline
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_pipeline.LegacyPipeline)
|
||||
.order_by(persistence_pipeline.LegacyPipeline.updated_at.desc())
|
||||
.limit(1)
|
||||
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
|
||||
persistence_pipeline.LegacyPipeline.is_default == True
|
||||
)
|
||||
)
|
||||
pipeline = result.first()
|
||||
if pipeline is not None:
|
||||
@@ -120,26 +120,24 @@ class BotService:
|
||||
|
||||
async def update_bot(self, bot_uuid: str, bot_data: dict) -> None:
|
||||
"""Update bot"""
|
||||
update_data = bot_data.copy()
|
||||
|
||||
if 'uuid' in update_data:
|
||||
del update_data['uuid']
|
||||
if 'uuid' in bot_data:
|
||||
del bot_data['uuid']
|
||||
|
||||
# set use_pipeline_name
|
||||
if 'use_pipeline_uuid' in update_data:
|
||||
if 'use_pipeline_uuid' in bot_data:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
|
||||
persistence_pipeline.LegacyPipeline.uuid == update_data['use_pipeline_uuid']
|
||||
persistence_pipeline.LegacyPipeline.uuid == bot_data['use_pipeline_uuid']
|
||||
)
|
||||
)
|
||||
pipeline = result.first()
|
||||
if pipeline is not None:
|
||||
update_data['use_pipeline_name'] = pipeline.name
|
||||
bot_data['use_pipeline_name'] = pipeline.name
|
||||
else:
|
||||
raise Exception('Pipeline not found')
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_bot.Bot).values(update_data).where(persistence_bot.Bot.uuid == bot_uuid)
|
||||
sqlalchemy.update(persistence_bot.Bot).values(bot_data).where(persistence_bot.Bot.uuid == bot_uuid)
|
||||
)
|
||||
await self.ap.platform_mgr.remove_bot(bot_uuid)
|
||||
|
||||
|
||||
@@ -31,126 +31,15 @@ class KnowledgeService:
|
||||
if not knowledge_engine_plugin_id:
|
||||
raise ValueError('knowledge_engine_plugin_id is required')
|
||||
|
||||
creation_settings = kb_data.get('creation_settings', {})
|
||||
retrieval_settings = kb_data.get('retrieval_settings', {})
|
||||
|
||||
# Validate required fields based on plugin's creation_schema and retrieval_schema
|
||||
await self._validate_schema_required_fields(
|
||||
knowledge_engine_plugin_id,
|
||||
creation_settings,
|
||||
retrieval_settings,
|
||||
)
|
||||
|
||||
kb = await self.ap.rag_mgr.create_knowledge_base(
|
||||
name=kb_data.get('name', 'Untitled'),
|
||||
knowledge_engine_plugin_id=knowledge_engine_plugin_id,
|
||||
creation_settings=creation_settings,
|
||||
retrieval_settings=retrieval_settings,
|
||||
creation_settings=kb_data.get('creation_settings', {}),
|
||||
retrieval_settings=kb_data.get('retrieval_settings', {}),
|
||||
description=kb_data.get('description', ''),
|
||||
)
|
||||
return kb.uuid
|
||||
|
||||
async def _validate_schema_required_fields(
|
||||
self,
|
||||
plugin_id: str,
|
||||
creation_settings: dict,
|
||||
retrieval_settings: dict,
|
||||
) -> None:
|
||||
"""Validate required fields based on plugin's creation_schema and retrieval_schema.
|
||||
|
||||
This is a business-agnostic validation that checks all fields marked as
|
||||
required in the plugin's schema, regardless of field type.
|
||||
|
||||
Args:
|
||||
plugin_id: Knowledge Engine plugin ID.
|
||||
creation_settings: User-provided creation settings.
|
||||
retrieval_settings: User-provided retrieval settings.
|
||||
|
||||
Raises:
|
||||
ValueError: If any required field is missing or empty.
|
||||
"""
|
||||
# Validate creation_schema
|
||||
try:
|
||||
creation_schema = await self.ap.plugin_connector.get_rag_creation_schema(plugin_id)
|
||||
self._check_required_fields(creation_schema, creation_settings, 'creation_settings')
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(f'Failed to get creation_schema for validation: {e}')
|
||||
|
||||
# Validate retrieval_schema
|
||||
try:
|
||||
retrieval_schema = await self.ap.plugin_connector.get_rag_retrieval_schema(plugin_id)
|
||||
self._check_required_fields(retrieval_schema, retrieval_settings, 'retrieval_settings')
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(f'Failed to get retrieval_schema for validation: {e}')
|
||||
|
||||
def _check_required_fields(
|
||||
self,
|
||||
schema: dict | list,
|
||||
settings: dict,
|
||||
context: str,
|
||||
) -> None:
|
||||
"""Check required fields in schema against provided settings.
|
||||
|
||||
Args:
|
||||
schema: Plugin-defined schema (can be list or dict with 'schema' key).
|
||||
settings: User-provided settings values.
|
||||
context: Context name for error messages (e.g., 'creation_settings').
|
||||
|
||||
Raises:
|
||||
ValueError: If a required field is missing or empty.
|
||||
"""
|
||||
if not schema:
|
||||
return
|
||||
|
||||
# schema can be a list directly, or a dict with 'schema' key
|
||||
items = schema if isinstance(schema, list) else schema.get('schema', [])
|
||||
if not items:
|
||||
return
|
||||
|
||||
for item in items:
|
||||
field_name = item.get('name')
|
||||
if not field_name:
|
||||
continue
|
||||
|
||||
is_required = item.get('required', False)
|
||||
if not is_required:
|
||||
continue
|
||||
|
||||
# Check show_if condition - if field is conditionally shown, only validate when condition is met
|
||||
show_if = item.get('show_if')
|
||||
if show_if:
|
||||
depend_field = show_if.get('field')
|
||||
operator = show_if.get('operator')
|
||||
expected_value = show_if.get('value')
|
||||
|
||||
if depend_field and operator:
|
||||
depend_value = settings.get(depend_field)
|
||||
# If show_if condition is not met, skip validation for this field
|
||||
if operator == 'eq' and depend_value != expected_value:
|
||||
continue
|
||||
if operator == 'neq' and depend_value == expected_value:
|
||||
continue
|
||||
if operator == 'in' and isinstance(expected_value, list) and depend_value not in expected_value:
|
||||
continue
|
||||
|
||||
value = settings.get(field_name)
|
||||
|
||||
# Validate required field has a non-empty value
|
||||
if value is None or (isinstance(value, str) and value.strip() == ''):
|
||||
# Get field label for friendly error message
|
||||
label = item.get('label', {})
|
||||
field_label = (
|
||||
label.get('en_US', field_name)
|
||||
or label.get('zh_Hans', field_name)
|
||||
or label.get('zh_Hant', field_name)
|
||||
or field_name
|
||||
)
|
||||
raise ValueError(f'{field_label} is required ({context}.{field_name})')
|
||||
|
||||
async def update_knowledge_base(self, kb_uuid: str, kb_data: dict) -> None:
|
||||
"""更新知识库"""
|
||||
# Filter to only mutable fields
|
||||
|
||||
@@ -113,9 +113,14 @@ class PipelineService:
|
||||
return pipeline_data['uuid']
|
||||
|
||||
async def update_pipeline(self, pipeline_uuid: str, pipeline_data: dict) -> None:
|
||||
pipeline_data = pipeline_data.copy()
|
||||
for protected_field in ('uuid', 'for_version', 'stages', 'is_default'):
|
||||
pipeline_data.pop(protected_field, None)
|
||||
if 'uuid' in pipeline_data:
|
||||
del pipeline_data['uuid']
|
||||
if 'for_version' in pipeline_data:
|
||||
del pipeline_data['for_version']
|
||||
if 'stages' in pipeline_data:
|
||||
del pipeline_data['stages']
|
||||
if 'is_default' in pipeline_data:
|
||||
del pipeline_data['is_default']
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_pipeline.LegacyPipeline)
|
||||
|
||||
@@ -17,24 +17,6 @@ class ModelProviderService:
|
||||
def __init__(self, ap: app.Application) -> None:
|
||||
self.ap = ap
|
||||
|
||||
@staticmethod
|
||||
def _normalize_api_keys(api_keys: str | list[str] | tuple[str, ...] | None) -> list[str]:
|
||||
if api_keys is None:
|
||||
return []
|
||||
|
||||
raw_keys = [api_keys] if isinstance(api_keys, str) else list(api_keys)
|
||||
normalized_keys = []
|
||||
seen_keys = set()
|
||||
|
||||
for raw_key in raw_keys:
|
||||
normalized_key = raw_key.strip() if isinstance(raw_key, str) else ''
|
||||
if not normalized_key or normalized_key in seen_keys:
|
||||
continue
|
||||
normalized_keys.append(normalized_key)
|
||||
seen_keys.add(normalized_key)
|
||||
|
||||
return normalized_keys
|
||||
|
||||
async def get_providers(self) -> list[dict]:
|
||||
"""Get all providers"""
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.ModelProvider))
|
||||
@@ -77,7 +59,6 @@ class ModelProviderService:
|
||||
async def create_provider(self, provider_data: dict) -> str:
|
||||
"""Create a new provider"""
|
||||
provider_data['uuid'] = str(uuid.uuid4())
|
||||
provider_data['api_keys'] = self._normalize_api_keys(provider_data.get('api_keys'))
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(persistence_model.ModelProvider).values(**provider_data)
|
||||
)
|
||||
@@ -91,8 +72,6 @@ class ModelProviderService:
|
||||
"""Update an existing provider"""
|
||||
if 'uuid' in provider_data:
|
||||
del provider_data['uuid']
|
||||
if 'api_keys' in provider_data:
|
||||
provider_data['api_keys'] = self._normalize_api_keys(provider_data.get('api_keys'))
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_model.ModelProvider)
|
||||
.where(persistence_model.ModelProvider.uuid == provider_uuid)
|
||||
@@ -162,8 +141,6 @@ class ModelProviderService:
|
||||
|
||||
async def find_or_create_provider(self, requester: str, base_url: str, api_keys: list) -> str:
|
||||
"""Find existing provider or create new one"""
|
||||
api_keys = self._normalize_api_keys(api_keys)
|
||||
|
||||
# Try to find existing provider with same config
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.ModelProvider).where(
|
||||
@@ -191,7 +168,7 @@ class ModelProviderService:
|
||||
'name': provider_name,
|
||||
'requester': requester,
|
||||
'base_url': base_url,
|
||||
'api_keys': api_keys,
|
||||
'api_keys': api_keys or [],
|
||||
}
|
||||
)
|
||||
|
||||
@@ -200,7 +177,7 @@ class ModelProviderService:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_model.ModelProvider)
|
||||
.where(persistence_model.ModelProvider.uuid == '00000000-0000-0000-0000-000000000000')
|
||||
.values(api_keys=self._normalize_api_keys(api_key))
|
||||
.values(api_keys=[api_key])
|
||||
)
|
||||
await self.ap.model_mgr.reload_provider('00000000-0000-0000-0000-000000000000')
|
||||
|
||||
|
||||
@@ -46,14 +46,12 @@ async def make_app(loop: asyncio.AbstractEventLoop) -> app.Application:
|
||||
|
||||
|
||||
async def main(loop: asyncio.AbstractEventLoop):
|
||||
app_inst: app.Application | None = None
|
||||
try:
|
||||
# Hang system signal processing
|
||||
import signal
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
if app_inst is not None:
|
||||
app_inst.dispose()
|
||||
app_inst.dispose()
|
||||
print('[Signal] Program exit.')
|
||||
os._exit(0)
|
||||
|
||||
|
||||
@@ -275,7 +275,6 @@ class MessageAggregator:
|
||||
message_chain=merged_chain,
|
||||
adapter=base_msg.adapter,
|
||||
pipeline_uuid=base_msg.pipeline_uuid,
|
||||
routed_by_rule=any(msg.routed_by_rule for msg in messages),
|
||||
)
|
||||
|
||||
async def flush_all(self) -> None:
|
||||
|
||||
@@ -76,10 +76,6 @@ class LongTextProcessStage(stage.PipelineStage):
|
||||
self.ap.logger.debug('Long message processing strategy is not set, skip long message processing.')
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
if not query.resp_message_chain:
|
||||
self.ap.logger.debug('Response message chain is empty, skip long message processing.')
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
# 检查是否包含非 Plain 组件
|
||||
contains_non_plain = False
|
||||
|
||||
|
||||
@@ -157,7 +157,7 @@ class RuntimePipeline:
|
||||
bot_message=query.resp_messages[-1],
|
||||
message=result.user_notice,
|
||||
quote_origin=query.pipeline_config['output']['misc']['quote-origin'],
|
||||
is_final=[msg.is_final for msg in query.resp_messages][-1],
|
||||
is_final=[msg.is_final for msg in query.resp_messages][0],
|
||||
)
|
||||
else:
|
||||
await query.adapter.reply_message(
|
||||
|
||||
@@ -42,13 +42,9 @@ class QueryPool:
|
||||
adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
|
||||
pipeline_uuid: typing.Optional[str] = None,
|
||||
routed_by_rule: bool = False,
|
||||
variables: typing.Optional[dict[str, typing.Any]] = None,
|
||||
) -> pipeline_query.Query:
|
||||
async with self.condition:
|
||||
query_id = self.query_id_counter
|
||||
initial_variables: dict[str, typing.Any] = {'_routed_by_rule': routed_by_rule}
|
||||
if variables:
|
||||
initial_variables.update(variables)
|
||||
query = pipeline_query.Query(
|
||||
bot_uuid=bot_uuid,
|
||||
query_id=query_id,
|
||||
@@ -57,7 +53,7 @@ class QueryPool:
|
||||
sender_id=sender_id,
|
||||
message_event=message_event,
|
||||
message_chain=message_chain,
|
||||
variables=initial_variables,
|
||||
variables={'_routed_by_rule': routed_by_rule},
|
||||
resp_messages=[],
|
||||
resp_message_chain=[],
|
||||
adapter=adapter,
|
||||
@@ -67,7 +63,6 @@ class QueryPool:
|
||||
self.cached_queries[query_id] = query
|
||||
self.query_id_counter += 1
|
||||
self.condition.notify_all()
|
||||
return query
|
||||
|
||||
async def __aenter__(self):
|
||||
await self.pool_lock.acquire()
|
||||
|
||||
@@ -40,7 +40,7 @@ class SendResponseBackStage(stage.PipelineStage):
|
||||
has_chunks = any(isinstance(msg, provider_message.MessageChunk) for msg in query.resp_messages)
|
||||
# TODO 命令与流式的兼容性问题
|
||||
if await query.adapter.is_stream_output_supported() and has_chunks:
|
||||
is_final = [msg.is_final for msg in query.resp_messages][-1]
|
||||
is_final = [msg.is_final for msg in query.resp_messages][0]
|
||||
await query.adapter.reply_message_chunk(
|
||||
message_source=query.message_event,
|
||||
bot_message=query.resp_messages[-1],
|
||||
|
||||
@@ -501,8 +501,6 @@ class PlatformManager:
|
||||
bot_entity.adapter_config,
|
||||
logger,
|
||||
)
|
||||
if hasattr(adapter_inst, 'ap'):
|
||||
adapter_inst.ap = self.ap
|
||||
|
||||
# 如果 adapter 支持 set_bot_uuid 方法,设置 bot_uuid(用于统一 webhook)
|
||||
if hasattr(adapter_inst, 'set_bot_uuid'):
|
||||
|
||||
@@ -3,7 +3,6 @@ import typing
|
||||
import asyncio
|
||||
import traceback
|
||||
import datetime
|
||||
import json
|
||||
|
||||
import aiocqhttp
|
||||
import pydantic
|
||||
@@ -294,29 +293,6 @@ class AiocqhttpMessageConverter(abstract_platform_adapter.AbstractMessageConvert
|
||||
elif msg.type == 'dice':
|
||||
face_id = msg.data['result']
|
||||
yiri_msg_list.append(platform_message.Face(face_type='dice', face_id=int(face_id), face_name='骰子'))
|
||||
elif msg.type == 'json':
|
||||
try:
|
||||
raw = msg.data.get('data', {})
|
||||
if isinstance(raw, str):
|
||||
raw = json.loads(raw)
|
||||
if isinstance(raw, dict):
|
||||
_meta = raw.get('meta', {}) or {}
|
||||
if isinstance(_meta, dict):
|
||||
_detail = _meta.get('detail_1') or _meta.get('music') or _meta.get('news') or {}
|
||||
else:
|
||||
_detail = {}
|
||||
if isinstance(_detail, dict):
|
||||
preview = _detail.get('preview', '')
|
||||
title = _detail.get('desc', '') or _detail.get('title', '')
|
||||
url = _detail.get('qqdocurl', '') or _detail.get('jumpUrl', '')
|
||||
else:
|
||||
preview = title = url = ''
|
||||
text = ' '.join([f'[{raw.get("app", "")}]', preview, title, url]).strip()
|
||||
yiri_msg_list.append(platform_message.Plain(text=text or '[收到一张JSON卡片]'))
|
||||
else:
|
||||
yiri_msg_list.append(platform_message.Plain(text=str(raw)))
|
||||
except Exception:
|
||||
yiri_msg_list.append(platform_message.Plain(text='[收到一张JSON卡片]'))
|
||||
|
||||
chain = platform_message.MessageChain(yiri_msg_list)
|
||||
|
||||
|
||||
@@ -19,18 +19,6 @@ spec:
|
||||
en: https://link.langbot.app/en/platforms/dingtalk
|
||||
ja: https://link.langbot.app/ja/platforms/dingtalk
|
||||
config:
|
||||
- name: one-click-create
|
||||
label:
|
||||
en_US: One-Click Create App
|
||||
zh_Hans: 一键创建应用
|
||||
zh_Hant: 一鍵建立應用
|
||||
description:
|
||||
en_US: "Scan QR code with DingTalk to automatically create an app and fill in credentials. Note: Robot Code cannot be obtained automatically, you need to copy it from the DingTalk Developer Backend manually."
|
||||
zh_Hans: "使用钉钉扫码自动创建应用并填写凭据。注意:机器人代码无法自动获取,需前往钉钉开发者后台手动复制。"
|
||||
zh_Hant: "使用釘釘掃碼自動建立應用並填寫憑證。注意:機器人代碼無法自動取得,需前往釘釘開發者後台手動複製。"
|
||||
type: qr-code-login
|
||||
login_platform: dingtalk
|
||||
required: false
|
||||
- name: client_id
|
||||
label:
|
||||
en_US: Client ID
|
||||
@@ -52,10 +40,6 @@ spec:
|
||||
en_US: Robot Code
|
||||
zh_Hans: 机器人代码
|
||||
zh_Hant: 機器人代碼
|
||||
description:
|
||||
en_US: "Required for image recognition, file upload and other features. Get it from DingTalk Developer Backend > Robot Configuration."
|
||||
zh_Hans: "识图、上传文件等功能必填。请前往钉钉开发者后台 > 机器人配置中获取。"
|
||||
zh_Hant: "識圖、上傳檔案等功能必填。請前往釘釘開發者後台 > 機器人設定中取得。"
|
||||
type: string
|
||||
required: true
|
||||
default: ""
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -23,20 +23,6 @@ spec:
|
||||
en: https://link.langbot.app/en/platforms/lark
|
||||
ja: https://link.langbot.app/ja/platforms/lark
|
||||
config:
|
||||
- name: one-click-create
|
||||
label:
|
||||
en_US: One-Click Create App
|
||||
zh_Hans: 一键创建应用
|
||||
zh_Hant: 一鍵建立應用
|
||||
ja_JP: ワンクリックでアプリ作成
|
||||
description:
|
||||
en_US: Scan QR code to automatically create a Feishu app and fill in credentials
|
||||
zh_Hans: 扫码自动创建飞书应用并填写凭据
|
||||
zh_Hant: 掃碼自動建立飛書應用並填寫憑證
|
||||
ja_JP: QRコードをスキャンしてFeishuアプリを自動作成し、認証情報を入力
|
||||
type: qr-code-login
|
||||
login_platform: feishu
|
||||
required: false
|
||||
- name: app_id
|
||||
label:
|
||||
en_US: App ID
|
||||
|
||||
@@ -32,20 +32,6 @@ spec:
|
||||
type: string
|
||||
required: true
|
||||
default: "https://ilinkai.weixin.qq.com"
|
||||
- name: qr-login
|
||||
label:
|
||||
en_US: Scan QR Login
|
||||
zh_Hans: 扫码登录
|
||||
zh_Hant: 掃碼登入
|
||||
ja_JP: QRコードでログイン
|
||||
description:
|
||||
en_US: Scan QR code with WeChat to authorize and automatically fill in the token
|
||||
zh_Hans: 使用微信扫码授权,自动填写令牌
|
||||
zh_Hant: 使用微信掃碼授權,自動填寫令牌
|
||||
ja_JP: WeChatでQRコードをスキャンし、トークンを自動入力
|
||||
type: qr-code-login
|
||||
login_platform: weixin
|
||||
required: false
|
||||
- name: token
|
||||
label:
|
||||
en_US: Token
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
from __future__ import annotations
|
||||
import time
|
||||
|
||||
|
||||
import telegram
|
||||
import telegram.ext
|
||||
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
|
||||
from telegram.ext import ApplicationBuilder, ContextTypes, MessageHandler, CallbackQueryHandler, filters
|
||||
from telegram import Update
|
||||
from telegram.ext import ApplicationBuilder, ContextTypes, MessageHandler, filters
|
||||
import telegramify_markdown
|
||||
import typing
|
||||
import traceback
|
||||
import json
|
||||
import base64
|
||||
import pydantic
|
||||
|
||||
@@ -189,7 +189,6 @@ class TelegramEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
bot: telegram.Bot = pydantic.Field(exclude=True)
|
||||
application: telegram.ext.Application = pydantic.Field(exclude=True)
|
||||
ap: typing.Any = pydantic.Field(exclude=True, default=None)
|
||||
|
||||
message_converter: TelegramMessageConverter = TelegramMessageConverter()
|
||||
event_converter: TelegramEventConverter = TelegramEventConverter()
|
||||
@@ -225,102 +224,6 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
telegram_callback,
|
||||
)
|
||||
)
|
||||
|
||||
async def callback_query_handler(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
query = update.callback_query
|
||||
await query.answer()
|
||||
try:
|
||||
data = json.loads(query.data)
|
||||
if data.get('form_action') or data.get('f'):
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
|
||||
workflow_run_id = data.get('workflow_run_id', '')
|
||||
w_suffix = data.get('w', '')
|
||||
action_id = data.get('action_id') or data.get('a', '')
|
||||
session_key = data.get('session_key') or data.get('s', '')
|
||||
|
||||
if session_key.startswith('group_') or session_key.startswith('g:'):
|
||||
launcher_type = provider_session.LauncherTypes.GROUP
|
||||
launcher_id = (
|
||||
session_key.split(':', 1)[1]
|
||||
if session_key.startswith('g:')
|
||||
else session_key[len('group_') :]
|
||||
)
|
||||
else:
|
||||
launcher_type = provider_session.LauncherTypes.PERSON
|
||||
launcher_id = (
|
||||
session_key.split(':', 1)[1]
|
||||
if session_key.startswith('p:')
|
||||
else session_key[len('person_') :]
|
||||
)
|
||||
|
||||
user_id = str(query.from_user.id)
|
||||
|
||||
# Find bot_uuid and pipeline_uuid
|
||||
bot_uuid = ''
|
||||
pipeline_uuid = None
|
||||
for b in self.ap.platform_mgr.bots:
|
||||
if b.adapter is self:
|
||||
bot_uuid = b.bot_entity.uuid
|
||||
pipeline_uuid = b.bot_entity.use_pipeline_uuid
|
||||
break
|
||||
|
||||
form_action_data = {
|
||||
'workflow_run_id': workflow_run_id,
|
||||
'w_suffix': w_suffix,
|
||||
'action_id': action_id,
|
||||
'user': f'{launcher_type.value}_{launcher_id}',
|
||||
'inputs': {},
|
||||
}
|
||||
|
||||
message_chain = platform_message.MessageChain(
|
||||
[platform_message.Plain(text=f'[Form Action: {action_id}]')]
|
||||
)
|
||||
|
||||
if launcher_type == provider_session.LauncherTypes.GROUP:
|
||||
synthetic_event = platform_events.GroupMessage(
|
||||
sender=platform_entities.GroupMember(
|
||||
id=user_id,
|
||||
member_name='',
|
||||
permission=platform_entities.Permission.Member,
|
||||
group=platform_entities.Group(
|
||||
id=launcher_id,
|
||||
name='',
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
),
|
||||
message_chain=message_chain,
|
||||
source_platform_object=update,
|
||||
)
|
||||
else:
|
||||
synthetic_event = platform_events.FriendMessage(
|
||||
sender=platform_entities.Friend(
|
||||
id=user_id,
|
||||
nickname='',
|
||||
remark='',
|
||||
),
|
||||
message_chain=message_chain,
|
||||
source_platform_object=update,
|
||||
)
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
bot_uuid=bot_uuid,
|
||||
launcher_type=launcher_type,
|
||||
launcher_id=launcher_id,
|
||||
sender_id=user_id,
|
||||
message_event=synthetic_event,
|
||||
message_chain=message_chain,
|
||||
adapter=self,
|
||||
pipeline_uuid=pipeline_uuid,
|
||||
variables={
|
||||
'_dify_form_action': form_action_data,
|
||||
'_routed_by_rule': True,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
await self.logger.error(f'Error in telegram callback query: {traceback.format_exc()}')
|
||||
|
||||
application.add_handler(CallbackQueryHandler(callback_query_handler))
|
||||
super().__init__(
|
||||
config=config,
|
||||
logger=logger,
|
||||
@@ -416,19 +319,14 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
update = event.source_platform_object
|
||||
chat_id = update.effective_chat.id
|
||||
chat_type = update.effective_chat.type
|
||||
effective_message = update.effective_message
|
||||
message_thread_id = getattr(effective_message, 'message_thread_id', None) if effective_message else None
|
||||
message_thread_id = update.message.message_thread_id
|
||||
|
||||
if chat_type == 'private':
|
||||
import time as _time
|
||||
|
||||
draft_id = int(_time.time() * 1000)
|
||||
draft_id = int(time.time() * 1000)
|
||||
self.msg_stream_id[message_id] = ('private', draft_id)
|
||||
|
||||
args = self._build_message_args(chat_id, 'Thinking...', message_thread_id, draft_id=draft_id)
|
||||
try:
|
||||
await self.bot.send_message_draft(**args)
|
||||
except (telegram.error.RetryAfter, telegram.error.BadRequest):
|
||||
pass
|
||||
await self.bot.send_message_draft(**args)
|
||||
else:
|
||||
args = self._build_message_args(chat_id, 'Thinking...', message_thread_id)
|
||||
send_msg = await self.bot.send_message(**args)
|
||||
@@ -449,13 +347,12 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
assert isinstance(message_source.source_platform_object, Update)
|
||||
update = message_source.source_platform_object
|
||||
chat_id = update.effective_chat.id
|
||||
effective_message = update.effective_message
|
||||
message_thread_id = getattr(effective_message, 'message_thread_id', None) if effective_message else None
|
||||
message_thread_id = update.message.message_thread_id
|
||||
|
||||
if message_id not in self.msg_stream_id:
|
||||
return
|
||||
|
||||
chat_mode, stream_id = self.msg_stream_id[message_id]
|
||||
chat_mode, draft_id = self.msg_stream_id[message_id]
|
||||
components = await TelegramMessageConverter.yiri2target(message, self.bot)
|
||||
|
||||
if not components or components[0]['type'] != 'text':
|
||||
@@ -464,42 +361,16 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
return
|
||||
|
||||
content = components[0]['text']
|
||||
form_data = getattr(bot_message, '_form_data', None)
|
||||
|
||||
if form_data and is_final:
|
||||
self.msg_stream_id.pop(message_id, None)
|
||||
await self._send_form_action_buttons(message_source, form_data)
|
||||
return
|
||||
|
||||
if chat_mode == 'private':
|
||||
# Streaming via draft (ephemeral preview in the chat input area)
|
||||
if (msg_seq - 1) % 8 == 0 or is_final:
|
||||
args = self._build_message_args(chat_id, content, message_thread_id, draft_id=stream_id)
|
||||
try:
|
||||
await self.bot.send_message_draft(**args)
|
||||
except telegram.error.BadRequest as exc:
|
||||
if 'Message_too_long' in str(exc):
|
||||
args['text'] = content[:4000] + '\n\n… (truncated)'
|
||||
try:
|
||||
await self.bot.send_message_draft(**args)
|
||||
except telegram.error.RetryAfter:
|
||||
pass
|
||||
else:
|
||||
pass # Ignore other draft errors (cosmetic)
|
||||
args = self._build_message_args(chat_id, content, message_thread_id, draft_id=draft_id)
|
||||
await self.bot.send_message_draft(**args)
|
||||
if is_final and bot_message.tool_calls is None:
|
||||
# Finalise: send the real message, discard the draft
|
||||
args = self._build_message_args(chat_id, content, message_thread_id)
|
||||
try:
|
||||
await self.bot.send_message(**args)
|
||||
except telegram.error.BadRequest as exc:
|
||||
if 'Message_too_long' in str(exc):
|
||||
args['text'] = content[:4000] + '\n\n… (truncated)'
|
||||
await self.bot.send_message(**args)
|
||||
else:
|
||||
raise
|
||||
del args['draft_id']
|
||||
await self.bot.send_message(**args)
|
||||
self.msg_stream_id.pop(message_id)
|
||||
else:
|
||||
# Streaming via edit_message_text (persistent message)
|
||||
stream_id = draft_id
|
||||
if (msg_seq - 1) % 8 == 0 or is_final:
|
||||
args = {
|
||||
'message_id': stream_id,
|
||||
@@ -508,68 +379,11 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
}
|
||||
if self.config.get('markdown_card', False):
|
||||
args['parse_mode'] = 'MarkdownV2'
|
||||
try:
|
||||
await self.bot.edit_message_text(**args)
|
||||
except telegram.error.BadRequest as exc:
|
||||
if 'Message_too_long' in str(exc):
|
||||
args['text'] = self._process_markdown(content[:4000] + '\n\n… (truncated)')
|
||||
await self.bot.edit_message_text(**args)
|
||||
else:
|
||||
raise
|
||||
await self.bot.edit_message_text(**args)
|
||||
|
||||
if is_final and bot_message.tool_calls is None:
|
||||
self.msg_stream_id.pop(message_id)
|
||||
|
||||
async def _send_form_action_buttons(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
form_data: dict,
|
||||
):
|
||||
"""Send inline keyboard buttons for Dify human_input_required form actions."""
|
||||
actions = form_data.get('actions', [])
|
||||
node_title = form_data.get('node_title', '')
|
||||
form_content = form_data.get('form_content', '')
|
||||
workflow_run_id = form_data.get('workflow_run_id', '')
|
||||
# Telegram callback_data is capped at 64 bytes, so we identify the
|
||||
# paused workflow by the last 8 chars of workflow_run_id (unique
|
||||
# within a session with overwhelming probability).
|
||||
w_suffix = workflow_run_id[-8:] if workflow_run_id else ''
|
||||
|
||||
if isinstance(message_source, platform_events.GroupMessage):
|
||||
session_key = f'g:{message_source.group.id}'
|
||||
else:
|
||||
session_key = f'p:{message_source.sender.id}'
|
||||
|
||||
keyboard = []
|
||||
for action in actions:
|
||||
action_id = action.get('id', '')
|
||||
action_title = action.get('title', action_id)
|
||||
callback_payload = {'f': 1, 'a': action_id, 's': session_key}
|
||||
if w_suffix:
|
||||
callback_payload['w'] = w_suffix
|
||||
callback_data = json.dumps(callback_payload, separators=(',', ':'))
|
||||
keyboard.append([InlineKeyboardButton(action_title, callback_data=callback_data)])
|
||||
|
||||
reply_markup = InlineKeyboardMarkup(keyboard)
|
||||
|
||||
update = message_source.source_platform_object
|
||||
chat_id = update.effective_chat.id
|
||||
effective_message = update.effective_message
|
||||
message_thread_id = getattr(effective_message, 'message_thread_id', None) if effective_message else None
|
||||
|
||||
text_lines = [f'[{node_title}] Please select an action:']
|
||||
if form_content:
|
||||
text_lines.insert(0, form_content)
|
||||
args = {
|
||||
'chat_id': chat_id,
|
||||
'text': '\n\n'.join(text_lines),
|
||||
'reply_markup': reply_markup,
|
||||
}
|
||||
if message_thread_id:
|
||||
args['message_thread_id'] = message_thread_id
|
||||
|
||||
await self.bot.send_message(**args)
|
||||
|
||||
def get_launcher_id(self, event: platform_events.MessageEvent) -> str | None:
|
||||
if not isinstance(event.source_platform_object, Update):
|
||||
return None
|
||||
|
||||
@@ -27,7 +27,10 @@ class WebPageBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter
|
||||
listeners: dict = pydantic.Field(default_factory=dict, exclude=True)
|
||||
_ws_adapter: typing.Any = None
|
||||
|
||||
model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
# Allow private attributes
|
||||
underscore_attrs_are_private = True
|
||||
|
||||
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
|
||||
super().__init__(config=config, logger=logger, **kwargs)
|
||||
|
||||
@@ -19,18 +19,6 @@ spec:
|
||||
en: https://link.langbot.app/en/platforms/wecombot
|
||||
ja: https://link.langbot.app/ja/platforms/wecombot
|
||||
config:
|
||||
- name: one-click-create
|
||||
label:
|
||||
en_US: One-Click Create Bot
|
||||
zh_Hans: 一键创建机器人
|
||||
zh_Hant: 一鍵建立機器人
|
||||
description:
|
||||
en_US: "Scan QR code with WeCom to automatically create a bot and fill in BotId and Secret. Note: Robot Name needs to be filled in manually."
|
||||
zh_Hans: "使用企业微信扫码自动创建机器人并填写 BotId 和 Secret。注意:机器人名称需手动填写。"
|
||||
zh_Hant: "使用企業微信掃碼自動建立機器人並填寫 BotId 和 Secret。注意:機器人名稱需手動填寫。"
|
||||
type: qr-code-login
|
||||
login_platform: wecombot
|
||||
required: false
|
||||
- name: BotId
|
||||
label:
|
||||
en_US: BotId
|
||||
|
||||
@@ -11,7 +11,6 @@ import os
|
||||
import sys
|
||||
import httpx
|
||||
import sqlalchemy
|
||||
import yaml
|
||||
from async_lru import alru_cache
|
||||
from langbot_plugin.api.entities.builtin.pipeline.query import provider_session
|
||||
|
||||
@@ -35,10 +34,6 @@ from ..core import taskmgr
|
||||
from ..entity.persistence import plugin as persistence_plugin
|
||||
|
||||
|
||||
class PluginRuntimeNotConnectedError(RuntimeError):
|
||||
"""Raised when plugin runtime operations are requested before connection."""
|
||||
|
||||
|
||||
class PluginRuntimeConnector:
|
||||
"""Plugin runtime connector"""
|
||||
|
||||
@@ -196,114 +191,44 @@ class PluginRuntimeConnector:
|
||||
|
||||
async def ping_plugin_runtime(self):
|
||||
if not hasattr(self, 'handler'):
|
||||
raise PluginRuntimeNotConnectedError('Plugin runtime is not connected')
|
||||
raise Exception('Plugin runtime is not connected')
|
||||
|
||||
return await self.handler.ping()
|
||||
|
||||
def _inspect_plugin_package(
|
||||
def _extract_deps_metadata(
|
||||
self,
|
||||
file_bytes: bytes,
|
||||
task_context: taskmgr.TaskContext | None,
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""Extract plugin identity and dependency metadata from a plugin package."""
|
||||
plugin_author = None
|
||||
plugin_name = None
|
||||
|
||||
):
|
||||
"""Extract dependency count from requirements.txt inside plugin zip."""
|
||||
if task_context is None:
|
||||
return
|
||||
try:
|
||||
with zipfile.ZipFile(io.BytesIO(file_bytes)) as zf:
|
||||
try:
|
||||
manifest = yaml.safe_load(zf.read('manifest.yaml').decode('utf-8', errors='ignore')) or {}
|
||||
metadata = manifest.get('metadata', {})
|
||||
plugin_author = metadata.get('author')
|
||||
plugin_name = metadata.get('name')
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if task_context is not None:
|
||||
for name in zf.namelist():
|
||||
if name.endswith('requirements.txt'):
|
||||
content = zf.read(name).decode('utf-8', errors='ignore')
|
||||
deps = [
|
||||
line.strip()
|
||||
for line in content.splitlines()
|
||||
if line.strip() and not line.strip().startswith('#')
|
||||
]
|
||||
task_context.metadata['deps_total'] = len(deps)
|
||||
task_context.metadata['deps_list'] = deps
|
||||
break
|
||||
for name in zf.namelist():
|
||||
if name.endswith('requirements.txt'):
|
||||
content = zf.read(name).decode('utf-8', errors='ignore')
|
||||
deps = [
|
||||
line.strip()
|
||||
for line in content.splitlines()
|
||||
if line.strip() and not line.strip().startswith('#')
|
||||
]
|
||||
task_context.metadata['deps_total'] = len(deps)
|
||||
task_context.metadata['deps_list'] = deps
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return plugin_author, plugin_name
|
||||
|
||||
def _build_plugin_startup_failure_message(
|
||||
self,
|
||||
plugin_author: str,
|
||||
plugin_name: str,
|
||||
task_context: taskmgr.TaskContext | None,
|
||||
) -> str:
|
||||
dep_hint = ''
|
||||
if task_context is not None:
|
||||
current_dep = task_context.metadata.get('current_dep')
|
||||
if current_dep:
|
||||
dep_hint = f' Last dependency: {current_dep}.'
|
||||
|
||||
return (
|
||||
f'Plugin {plugin_author}/{plugin_name} failed to start after installation. '
|
||||
f'Dependency installation or plugin initialization may have failed.{dep_hint} '
|
||||
f'Please check the plugin requirements and runtime logs.'
|
||||
)
|
||||
|
||||
async def _wait_for_installed_plugin_ready(
|
||||
self,
|
||||
plugin_author: str | None,
|
||||
plugin_name: str | None,
|
||||
task_context: taskmgr.TaskContext | None,
|
||||
timeout: float = 30,
|
||||
):
|
||||
"""Wait until the installed plugin is registered by the runtime.
|
||||
|
||||
The plugin runtime launches plugins asynchronously. If dependency installation
|
||||
fails, the plugin process exits before registration; without this check the
|
||||
install task can incorrectly finish successfully.
|
||||
"""
|
||||
if not plugin_author or not plugin_name:
|
||||
return
|
||||
|
||||
deadline = time.time() + timeout
|
||||
last_error: Exception | None = None
|
||||
while time.time() < deadline:
|
||||
try:
|
||||
plugin = await self.get_plugin_info(plugin_author, plugin_name)
|
||||
if plugin is not None:
|
||||
status = plugin.get('status')
|
||||
if status == 'initialized':
|
||||
return
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
message = self._build_plugin_startup_failure_message(plugin_author, plugin_name, task_context)
|
||||
if last_error is not None:
|
||||
message = f'{message} Last runtime error: {last_error}'
|
||||
raise RuntimeError(message)
|
||||
|
||||
async def install_plugin(
|
||||
self,
|
||||
install_source: PluginInstallSource,
|
||||
install_info: dict[str, Any],
|
||||
task_context: taskmgr.TaskContext | None = None,
|
||||
):
|
||||
plugin_author = install_info.get('plugin_author')
|
||||
plugin_name = install_info.get('plugin_name')
|
||||
|
||||
if install_source == PluginInstallSource.LOCAL:
|
||||
# transfer file before install
|
||||
file_bytes = install_info['plugin_file']
|
||||
plugin_author, plugin_name = self._inspect_plugin_package(file_bytes, task_context)
|
||||
if task_context is not None and plugin_author and plugin_name:
|
||||
task_context.metadata['plugin_name'] = f'{plugin_author}/{plugin_name}'
|
||||
self._extract_deps_metadata(file_bytes, task_context)
|
||||
file_key = await self.handler.send_file(file_bytes, 'lbpkg')
|
||||
install_info['plugin_file_key'] = file_key
|
||||
del install_info['plugin_file']
|
||||
@@ -340,9 +265,7 @@ class PluginRuntimeConnector:
|
||||
task_context.metadata['download_speed'] = downloaded / elapsed if elapsed > 0 else 0
|
||||
|
||||
file_bytes = b''.join(chunks)
|
||||
plugin_author, plugin_name = self._inspect_plugin_package(file_bytes, task_context)
|
||||
if task_context is not None and plugin_author and plugin_name:
|
||||
task_context.metadata['plugin_name'] = f'{plugin_author}/{plugin_name}'
|
||||
self._extract_deps_metadata(file_bytes, task_context)
|
||||
file_key = await self.handler.send_file(file_bytes, 'lbpkg')
|
||||
install_info['plugin_file_key'] = file_key
|
||||
self.ap.logger.info(f'Transfered file {file_key} to plugin runtime')
|
||||
@@ -366,8 +289,6 @@ class PluginRuntimeConnector:
|
||||
if metadata is not None and task_context is not None:
|
||||
task_context.metadata.update(metadata)
|
||||
|
||||
await self._wait_for_installed_plugin_ready(plugin_author, plugin_name, task_context)
|
||||
|
||||
async def upgrade_plugin(
|
||||
self,
|
||||
plugin_author: str,
|
||||
@@ -637,12 +558,11 @@ class PluginRuntimeConnector:
|
||||
Raises:
|
||||
ValueError: If plugin_id is not in the expected 'author/name' format.
|
||||
"""
|
||||
segments = plugin_id.split('/')
|
||||
if len(segments) != 2 or not all(segments):
|
||||
if '/' not in plugin_id:
|
||||
raise ValueError(
|
||||
f"Invalid plugin_id format: '{plugin_id}'. Expected 'author/name' format (e.g. 'langbot/rag-engine')."
|
||||
)
|
||||
return segments[0], segments[1]
|
||||
return plugin_id.split('/', 1)
|
||||
|
||||
async def call_rag_ingest(self, plugin_id: str, context_data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Call plugin to ingest document.
|
||||
|
||||
@@ -340,7 +340,6 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta):
|
||||
"""Provider API请求器"""
|
||||
|
||||
name: str = None
|
||||
init_api_key: str = 'langbot-init-placeholder'
|
||||
|
||||
ap: app.Application
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
|
||||
|
||||
async def initialize(self):
|
||||
self.client = openai.AsyncClient(
|
||||
api_key=self.init_api_key,
|
||||
api_key='',
|
||||
base_url=self.requester_cfg['base_url'].replace(' ', ''),
|
||||
timeout=self.requester_cfg['timeout'],
|
||||
http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']),
|
||||
|
||||
@@ -25,7 +25,7 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester):
|
||||
|
||||
async def initialize(self):
|
||||
self.client = openai.AsyncClient(
|
||||
api_key=self.init_api_key,
|
||||
api_key='',
|
||||
base_url=self.requester_cfg['base_url'],
|
||||
timeout=self.requester_cfg['timeout'],
|
||||
http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']),
|
||||
|
||||
@@ -14,14 +14,7 @@ class TokenManager:
|
||||
|
||||
def __init__(self, name: str, tokens: list[str]):
|
||||
self.name = name
|
||||
self.tokens = []
|
||||
seen_tokens = set()
|
||||
for token in tokens:
|
||||
normalized_token = token.strip() if isinstance(token, str) else ''
|
||||
if not normalized_token or normalized_token in seen_tokens:
|
||||
continue
|
||||
self.tokens.append(normalized_token)
|
||||
seen_tokens.add(normalized_token)
|
||||
self.tokens = tokens
|
||||
self.using_token_index = 0
|
||||
|
||||
def get_token(self) -> str:
|
||||
@@ -30,6 +23,4 @@ class TokenManager:
|
||||
return self.tokens[self.using_token_index]
|
||||
|
||||
def next_token(self):
|
||||
if len(self.tokens) == 0:
|
||||
return
|
||||
self.using_token_index = (self.using_token_index + 1) % len(self.tokens)
|
||||
|
||||
@@ -2,11 +2,9 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
import base64
|
||||
import mimetypes
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
from langbot.pkg.provider import runner
|
||||
@@ -18,102 +16,6 @@ from langbot.libs.dify_service_api.v1 import client, errors
|
||||
import httpx
|
||||
|
||||
|
||||
# Module-level store for paused-workflow form state, keyed by session key
|
||||
# (launcher_type_value + "_" + launcher_id). Each session holds an
|
||||
# insertion-ordered dict of form_token -> form_data, allowing multiple
|
||||
# Dify workflows to be paused simultaneously for the same session.
|
||||
_PENDING_FORMS: dict[str, 'OrderedDict[str, dict[str, typing.Any]]'] = {}
|
||||
_PENDING_FORM_DEFAULT_TTL = 30 * 60 # 30 minutes safety cap
|
||||
|
||||
|
||||
def _session_key_from_query(query: pipeline_query.Query) -> str:
|
||||
return f'{query.session.launcher_type.value}_{query.session.launcher_id}'
|
||||
|
||||
|
||||
def _prune_pending_forms(now: float | None = None) -> None:
|
||||
if now is None:
|
||||
now = time.time()
|
||||
for session_key in list(_PENDING_FORMS.keys()):
|
||||
forms = _PENDING_FORMS[session_key]
|
||||
expired_tokens = [token for token, data in forms.items() if data.get('_expires_at', 0) <= now]
|
||||
for token in expired_tokens:
|
||||
forms.pop(token, None)
|
||||
if not forms:
|
||||
_PENDING_FORMS.pop(session_key, None)
|
||||
|
||||
|
||||
def _set_pending_form(session_key: str, form_data: dict[str, typing.Any]) -> None:
|
||||
_prune_pending_forms()
|
||||
stored = dict(form_data)
|
||||
expiration_time = stored.get('expiration_time')
|
||||
try:
|
||||
expiration_ts = float(expiration_time) if expiration_time is not None else 0.0
|
||||
except (TypeError, ValueError):
|
||||
expiration_ts = 0.0
|
||||
stored['_expires_at'] = expiration_ts or (time.time() + _PENDING_FORM_DEFAULT_TTL)
|
||||
form_token = str(stored.get('form_token') or '')
|
||||
forms = _PENDING_FORMS.setdefault(session_key, OrderedDict())
|
||||
# Re-insert at the end so this becomes the "latest" entry
|
||||
forms.pop(form_token, None)
|
||||
forms[form_token] = stored
|
||||
|
||||
|
||||
def _get_pending_form_by_token(session_key: str, form_token: str) -> dict[str, typing.Any] | None:
|
||||
_prune_pending_forms()
|
||||
forms = _PENDING_FORMS.get(session_key)
|
||||
if not forms or not form_token:
|
||||
return None
|
||||
return forms.get(form_token)
|
||||
|
||||
|
||||
def _get_pending_form_by_w_suffix(session_key: str, w_suffix: str) -> dict[str, typing.Any] | None:
|
||||
"""Look up a pending form whose workflow_run_id ends with the given suffix.
|
||||
|
||||
Used by adapters (e.g. Telegram) whose callback payload is too small to
|
||||
carry the full form_token / workflow_run_id.
|
||||
"""
|
||||
_prune_pending_forms()
|
||||
forms = _PENDING_FORMS.get(session_key)
|
||||
if not forms or not w_suffix:
|
||||
return None
|
||||
for token in reversed(forms):
|
||||
form = forms[token]
|
||||
if str(form.get('workflow_run_id', '')).endswith(w_suffix):
|
||||
return form
|
||||
return None
|
||||
|
||||
|
||||
def _get_latest_pending_form(session_key: str) -> dict[str, typing.Any] | None:
|
||||
_prune_pending_forms()
|
||||
forms = _PENDING_FORMS.get(session_key)
|
||||
if not forms:
|
||||
return None
|
||||
return forms[next(reversed(forms))]
|
||||
|
||||
|
||||
def _iter_pending_forms(session_key: str) -> typing.Iterator[dict[str, typing.Any]]:
|
||||
"""Iterate pending forms for a session, newest-first."""
|
||||
_prune_pending_forms()
|
||||
forms = _PENDING_FORMS.get(session_key)
|
||||
if not forms:
|
||||
return
|
||||
for token in reversed(list(forms.keys())):
|
||||
yield forms[token]
|
||||
|
||||
|
||||
def _clear_pending_form(session_key: str, form_token: str | None = None) -> None:
|
||||
"""Clear one specific pending form (by token) or all forms for the session."""
|
||||
forms = _PENDING_FORMS.get(session_key)
|
||||
if not forms:
|
||||
return
|
||||
if form_token is None:
|
||||
_PENDING_FORMS.pop(session_key, None)
|
||||
return
|
||||
forms.pop(form_token, None)
|
||||
if not forms:
|
||||
_PENDING_FORMS.pop(session_key, None)
|
||||
|
||||
|
||||
@runner.runner_class('dify-service-api')
|
||||
class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
"""Dify Service API 对话请求器"""
|
||||
@@ -433,140 +335,11 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
|
||||
query.session.using_conversation.uuid = chunk['conversation_id']
|
||||
|
||||
async def _submit_workflow_form_blocking(
|
||||
self, form_action: dict
|
||||
) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""Submit human input to resume a paused Dify workflow (non-streaming)."""
|
||||
|
||||
form_token = form_action['form_token']
|
||||
workflow_run_id = form_action['workflow_run_id']
|
||||
user = form_action['user']
|
||||
action_id = form_action.get('action_id', '')
|
||||
inputs = form_action.get('inputs', {})
|
||||
|
||||
async for chunk in self.dify_client.workflow_submit(
|
||||
form_token=form_token,
|
||||
workflow_run_id=workflow_run_id,
|
||||
inputs=inputs,
|
||||
user=user,
|
||||
action=action_id,
|
||||
timeout=120,
|
||||
):
|
||||
self.ap.logger.debug('dify-workflow-submit-chunk: ' + str(chunk))
|
||||
|
||||
if chunk['event'] == 'workflow_finished':
|
||||
if chunk['data'].get('error'):
|
||||
raise errors.DifyAPIError(chunk['data']['error'])
|
||||
content, _ = self._process_thinking_content(chunk['data']['outputs']['summary'])
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=content,
|
||||
)
|
||||
|
||||
def _resolve_pending_form(self, session_key: str, form_action: dict) -> dict | None:
|
||||
"""Locate the pending form this action targets.
|
||||
|
||||
Tries identifiers in order of specificity: form_token, full
|
||||
workflow_run_id, workflow_run_id suffix (Telegram-style compact id),
|
||||
then falls back to the newest pending form for the session.
|
||||
"""
|
||||
form_token = form_action.get('form_token')
|
||||
if form_token:
|
||||
form = _get_pending_form_by_token(session_key, form_token)
|
||||
if form:
|
||||
return form
|
||||
|
||||
workflow_run_id = form_action.get('workflow_run_id')
|
||||
if workflow_run_id:
|
||||
for form in _iter_pending_forms(session_key):
|
||||
if form.get('workflow_run_id') == workflow_run_id:
|
||||
return form
|
||||
|
||||
w_suffix = form_action.get('w_suffix')
|
||||
if w_suffix:
|
||||
form = _get_pending_form_by_w_suffix(session_key, w_suffix)
|
||||
if form:
|
||||
return form
|
||||
|
||||
return _get_latest_pending_form(session_key)
|
||||
|
||||
def _merge_pending_form_action(self, session_key: str, form_action: dict | None) -> dict | None:
|
||||
"""Backfill resume fields from the matching pending form."""
|
||||
if not form_action:
|
||||
return None
|
||||
|
||||
merged_action = dict(form_action)
|
||||
merged_action.pop('w_suffix', None)
|
||||
pending_form = self._resolve_pending_form(session_key, form_action)
|
||||
if pending_form:
|
||||
merged_action['form_token'] = merged_action.get('form_token') or pending_form.get('form_token', '')
|
||||
merged_action['workflow_run_id'] = merged_action.get('workflow_run_id') or pending_form.get(
|
||||
'workflow_run_id', ''
|
||||
)
|
||||
merged_action.setdefault('inputs', pending_form.get('inputs', {}))
|
||||
merged_action.setdefault('user', pending_form.get('user', ''))
|
||||
merged_action.setdefault('node_title', pending_form.get('node_title', ''))
|
||||
|
||||
# Resolve clicked action's display title from the stored actions list
|
||||
if 'action_title' not in merged_action:
|
||||
clicked_id = merged_action.get('action_id', '')
|
||||
for action in pending_form.get('actions', []):
|
||||
if str(action.get('id', '')) == str(clicked_id):
|
||||
merged_action['action_title'] = action.get('title', clicked_id)
|
||||
break
|
||||
|
||||
return merged_action
|
||||
|
||||
def _match_pending_form_action(self, session_key: str, user_text: str) -> dict | None:
|
||||
"""Match plain text replies against pending Dify form actions.
|
||||
|
||||
Iterates all pending forms newest-first; the first action whose
|
||||
title/id matches the text wins. This means when multiple forms are
|
||||
pending with the same button label, the most recent one resolves.
|
||||
"""
|
||||
normalized_text = user_text.strip().lower()
|
||||
if not normalized_text:
|
||||
return None
|
||||
|
||||
for pending_form in _iter_pending_forms(session_key):
|
||||
for action in pending_form.get('actions', []):
|
||||
titles = {
|
||||
str(action.get('title', '')).strip().lower(),
|
||||
str(action.get('id', '')).strip().lower(),
|
||||
}
|
||||
if normalized_text in titles:
|
||||
return {
|
||||
'form_token': pending_form.get('form_token', ''),
|
||||
'workflow_run_id': pending_form.get('workflow_run_id', ''),
|
||||
'action_id': action.get('id', ''),
|
||||
'action_title': action.get('title', action.get('id', '')),
|
||||
'node_title': pending_form.get('node_title', ''),
|
||||
'inputs': pending_form.get('inputs', {}),
|
||||
'user': pending_form.get('user', ''),
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
async def _workflow_messages(
|
||||
self, query: pipeline_query.Query
|
||||
) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""调用工作流"""
|
||||
|
||||
# Check if this is a form action resume (button click or text match)
|
||||
form_action_raw = query.variables.get('_dify_form_action')
|
||||
session_key = _session_key_from_query(query)
|
||||
|
||||
if form_action_raw:
|
||||
form_action = self._merge_pending_form_action(session_key, form_action_raw)
|
||||
else:
|
||||
form_action = self._match_pending_form_action(session_key, str(query.message_chain))
|
||||
|
||||
if form_action:
|
||||
_clear_pending_form(session_key, form_action.get('form_token') or None)
|
||||
async for msg in self._submit_workflow_form_blocking(form_action):
|
||||
yield msg
|
||||
return
|
||||
|
||||
if not query.session.using_conversation.uuid:
|
||||
query.session.using_conversation.uuid = str(uuid.uuid4())
|
||||
|
||||
@@ -593,7 +366,6 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
}
|
||||
|
||||
inputs.update(query.variables)
|
||||
human_input_yielded = False
|
||||
|
||||
async for chunk in self.dify_client.workflow_run(
|
||||
inputs=inputs,
|
||||
@@ -605,46 +377,6 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
if chunk['event'] in ignored_events:
|
||||
continue
|
||||
|
||||
if chunk['event'] == 'workflow_paused':
|
||||
reasons = chunk['data'].get('reasons', [])
|
||||
workflow_run_id = chunk['data'].get('workflow_run_id', '')
|
||||
for reason in reasons:
|
||||
if reason.get('TYPE') == 'human_input_required':
|
||||
form_content = reason.get('form_content', '')
|
||||
actions = reason.get('actions', [])
|
||||
node_title = reason.get('node_title', '')
|
||||
|
||||
_set_pending_form(
|
||||
_session_key_from_query(query),
|
||||
{
|
||||
'workflow_run_id': workflow_run_id,
|
||||
'form_id': reason.get('form_id'),
|
||||
'form_token': reason.get('form_token'),
|
||||
'node_id': reason.get('node_id'),
|
||||
'node_title': node_title,
|
||||
'form_content': form_content,
|
||||
'inputs': reason.get('inputs', {}),
|
||||
'actions': actions,
|
||||
'expiration_time': reason.get('expiration_time'),
|
||||
'user': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
},
|
||||
)
|
||||
|
||||
query.variables['_dify_form_render'] = {
|
||||
'form_content': form_content,
|
||||
'actions': actions,
|
||||
'node_title': node_title,
|
||||
}
|
||||
|
||||
action_lines = '\n'.join(f'- [{a.get("title", a.get("id", ""))}]' for a in actions)
|
||||
display_text = f'[Human Input Required] {node_title}\n{form_content}\n{action_lines}'
|
||||
|
||||
human_input_yielded = True
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=display_text,
|
||||
)
|
||||
|
||||
if chunk['event'] == 'node_started':
|
||||
if chunk['data']['node_type'] == 'start' or chunk['data']['node_type'] == 'end':
|
||||
continue
|
||||
@@ -667,8 +399,6 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
yield msg
|
||||
|
||||
elif chunk['event'] == 'workflow_finished':
|
||||
if human_input_yielded:
|
||||
break
|
||||
if chunk['data']['error']:
|
||||
raise errors.DifyAPIError(chunk['data']['error'])
|
||||
content, _ = self._process_thinking_content(chunk['data']['outputs']['summary'])
|
||||
@@ -906,153 +636,11 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
|
||||
query.session.using_conversation.uuid = chunk['conversation_id']
|
||||
|
||||
async def _submit_workflow_form(
|
||||
self, form_action: dict
|
||||
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
|
||||
"""Submit human input to resume a paused Dify workflow."""
|
||||
|
||||
form_token = form_action['form_token']
|
||||
workflow_run_id = form_action['workflow_run_id']
|
||||
user = form_action['user']
|
||||
action_id = form_action.get('action_id', '')
|
||||
action_title = form_action.get('action_title', '') or action_id
|
||||
node_title = form_action.get('node_title', '')
|
||||
inputs = form_action.get('inputs', {})
|
||||
|
||||
messsage_idx = 0
|
||||
is_final = False
|
||||
think_start = False
|
||||
think_end = False
|
||||
workflow_contents = ''
|
||||
repause_form_data: dict | None = None
|
||||
|
||||
remove_think = self.pipeline_config['output'].get('misc', {}).get('remove-think')
|
||||
async for chunk in self.dify_client.workflow_submit(
|
||||
form_token=form_token,
|
||||
workflow_run_id=workflow_run_id,
|
||||
inputs=inputs,
|
||||
user=user,
|
||||
action=action_id,
|
||||
timeout=120,
|
||||
):
|
||||
self.ap.logger.debug('dify-workflow-submit-chunk: ' + str(chunk))
|
||||
|
||||
yield_this_iteration = False
|
||||
|
||||
if chunk['event'] == 'workflow_finished':
|
||||
is_final = True
|
||||
yield_this_iteration = True
|
||||
if chunk['data'].get('error'):
|
||||
raise errors.DifyAPIError(chunk['data']['error'])
|
||||
|
||||
if chunk['event'] == 'workflow_paused':
|
||||
reasons = chunk['data'].get('reasons', [])
|
||||
new_run_id = chunk['data'].get('workflow_run_id', workflow_run_id)
|
||||
for reason in reasons:
|
||||
if reason.get('TYPE') != 'human_input_required':
|
||||
continue
|
||||
form_content = reason.get('form_content', '')
|
||||
actions = reason.get('actions', [])
|
||||
# Use a distinct name — `node_title` (the just-resolved step)
|
||||
# must keep its value so the resume notice on the previous
|
||||
# card still shows which step the user acted on.
|
||||
paused_node_title = reason.get('node_title', '')
|
||||
raw_inputs = reason.get('inputs', {})
|
||||
|
||||
_set_pending_form(
|
||||
user,
|
||||
{
|
||||
'workflow_run_id': new_run_id,
|
||||
'form_id': reason.get('form_id'),
|
||||
'form_token': reason.get('form_token'),
|
||||
'node_id': reason.get('node_id'),
|
||||
'node_title': paused_node_title,
|
||||
'form_content': form_content,
|
||||
'inputs': raw_inputs if isinstance(raw_inputs, dict) else {},
|
||||
'actions': actions,
|
||||
'expiration_time': reason.get('expiration_time'),
|
||||
'user': user,
|
||||
},
|
||||
)
|
||||
|
||||
repause_form_data = {
|
||||
'form_content': form_content,
|
||||
'actions': actions,
|
||||
'node_title': paused_node_title,
|
||||
'workflow_run_id': new_run_id,
|
||||
'form_token': reason.get('form_token', ''),
|
||||
}
|
||||
# Ensure the final chunk has non-empty content so
|
||||
# ResponseWrapper (which skips empty-content chunks) lets it
|
||||
# propagate to SendResponseBackStage. Use a zero-width space
|
||||
# so neither Lark nor Telegram renders visible noise — the
|
||||
# adapter substitutes its own card text from _form_data.
|
||||
if not workflow_contents:
|
||||
workflow_contents = ''
|
||||
is_final = True
|
||||
yield_this_iteration = True
|
||||
break
|
||||
|
||||
if chunk['event'] == 'text_chunk':
|
||||
messsage_idx += 1
|
||||
if remove_think:
|
||||
if '<think>' in chunk['data']['text'] and not think_start:
|
||||
think_start = True
|
||||
continue
|
||||
if '</think>' in chunk['data']['text'] and not think_end:
|
||||
import re
|
||||
|
||||
content = re.sub(r'^\n</think>', '', chunk['data']['text'])
|
||||
workflow_contents += content
|
||||
think_end = True
|
||||
elif think_end:
|
||||
workflow_contents += chunk['data']['text']
|
||||
if think_start:
|
||||
continue
|
||||
else:
|
||||
workflow_contents += chunk['data']['text']
|
||||
if messsage_idx % 8 == 0:
|
||||
yield_this_iteration = True
|
||||
|
||||
if yield_this_iteration:
|
||||
msg = provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=workflow_contents,
|
||||
is_final=is_final,
|
||||
)
|
||||
msg._resume_from_form = True
|
||||
if action_title:
|
||||
msg._resume_action_title = action_title
|
||||
if node_title:
|
||||
msg._resume_node_title = node_title
|
||||
if is_final and repause_form_data:
|
||||
msg._form_data = repause_form_data
|
||||
msg._open_new_card = True
|
||||
yield msg
|
||||
if is_final:
|
||||
return
|
||||
|
||||
async def _workflow_messages_chunk(
|
||||
self, query: pipeline_query.Query
|
||||
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
|
||||
"""调用工作流"""
|
||||
|
||||
# Check if this is a form action resume (button click or text match)
|
||||
form_action_raw = query.variables.get('_dify_form_action')
|
||||
session_key = _session_key_from_query(query)
|
||||
|
||||
if form_action_raw:
|
||||
form_action = self._merge_pending_form_action(session_key, form_action_raw)
|
||||
else:
|
||||
form_action = self._match_pending_form_action(session_key, str(query.message_chain))
|
||||
|
||||
if form_action:
|
||||
_clear_pending_form(session_key, form_action.get('form_token') or None)
|
||||
# Resume paused workflow via submit endpoint
|
||||
async for msg in self._submit_workflow_form(form_action):
|
||||
yield msg
|
||||
return
|
||||
|
||||
if not query.session.using_conversation.uuid:
|
||||
query.session.using_conversation.uuid = str(uuid.uuid4())
|
||||
|
||||
@@ -1084,13 +672,6 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
think_start = False
|
||||
think_end = False
|
||||
workflow_contents = ''
|
||||
workflow_run_id = ''
|
||||
human_input_yielded = False
|
||||
|
||||
# Saved form data to attach to the final MessageChunk so the adapter
|
||||
# can detect it when is_final=True and render buttons.
|
||||
pending_form_data = None
|
||||
display_text = ''
|
||||
|
||||
remove_think = self.pipeline_config['output'].get('misc', '').get('remove-think')
|
||||
async for chunk in self.dify_client.workflow_run(
|
||||
@@ -1101,62 +682,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
):
|
||||
self.ap.logger.debug('dify-workflow-chunk: ' + str(chunk))
|
||||
if chunk['event'] in ignored_events:
|
||||
if chunk['event'] == 'workflow_started':
|
||||
workflow_run_id = chunk['data'].get('workflow_run_id', '')
|
||||
continue
|
||||
|
||||
if chunk['event'] == 'workflow_paused':
|
||||
reasons = chunk['data'].get('reasons', [])
|
||||
workflow_run_id = chunk['data'].get('workflow_run_id', workflow_run_id)
|
||||
for reason in reasons:
|
||||
if reason.get('TYPE') == 'human_input_required':
|
||||
form_content = reason.get('form_content', '')
|
||||
actions = reason.get('actions', [])
|
||||
node_title = reason.get('node_title', '')
|
||||
|
||||
# Persist form state in module-level store keyed by session
|
||||
raw_inputs = reason.get('inputs', {})
|
||||
_set_pending_form(
|
||||
_session_key_from_query(query),
|
||||
{
|
||||
'workflow_run_id': workflow_run_id,
|
||||
'form_id': reason.get('form_id'),
|
||||
'form_token': reason.get('form_token'),
|
||||
'node_id': reason.get('node_id'),
|
||||
'node_title': node_title,
|
||||
'form_content': form_content,
|
||||
'inputs': raw_inputs if isinstance(raw_inputs, dict) else {},
|
||||
'actions': actions,
|
||||
'expiration_time': reason.get('expiration_time'),
|
||||
'user': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
},
|
||||
)
|
||||
|
||||
# Pass form render metadata to downstream stages
|
||||
query.variables['_dify_form_render'] = {
|
||||
'form_content': form_content,
|
||||
'actions': actions,
|
||||
'node_title': node_title,
|
||||
}
|
||||
|
||||
action_lines = '\n'.join(f'- [{a.get("title", a.get("id", ""))}]' for a in actions)
|
||||
display_text = f'[Human Input Required] {node_title}\n{form_content}\n{action_lines}'
|
||||
workflow_contents += display_text + '\n'
|
||||
|
||||
# Save form data to attach to the final chunk later.
|
||||
# We do NOT yield here — the form content will be sent
|
||||
# as the final MessageChunk (with is_final=True and
|
||||
# _form_data) so the adapter can update the card and
|
||||
# add buttons in one pass.
|
||||
pending_form_data = {
|
||||
'form_content': form_content,
|
||||
'actions': actions,
|
||||
'node_title': node_title,
|
||||
'workflow_run_id': workflow_run_id,
|
||||
'form_token': reason.get('form_token', ''),
|
||||
}
|
||||
human_input_yielded = True
|
||||
|
||||
if chunk['event'] == 'workflow_finished':
|
||||
is_final = True
|
||||
if chunk['data']['error']:
|
||||
@@ -1204,29 +730,11 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
yield msg
|
||||
|
||||
if messsage_idx % 8 == 0 or is_final:
|
||||
final_content = workflow_contents if workflow_contents.strip() else ''
|
||||
msg = provider_message.MessageChunk(
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=final_content,
|
||||
content=workflow_contents,
|
||||
is_final=is_final,
|
||||
)
|
||||
# Attach form data to the final chunk for the adapter
|
||||
if is_final and pending_form_data:
|
||||
msg._form_data = pending_form_data
|
||||
pending_form_data = None
|
||||
yield msg
|
||||
|
||||
# If the stream ended after workflow_paused without a
|
||||
# workflow_finished event, yield a final chunk so the adapter
|
||||
# can update the card and add buttons.
|
||||
if human_input_yielded and not is_final:
|
||||
msg = provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=workflow_contents or display_text,
|
||||
is_final=True,
|
||||
)
|
||||
msg._form_data = pending_form_data
|
||||
yield msg
|
||||
|
||||
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""运行请求"""
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import posixpath
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from urllib.parse import unquote
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langbot.pkg.core import app
|
||||
from typing import Any
|
||||
from langbot.pkg.core import app
|
||||
|
||||
|
||||
class RAGRuntimeService:
|
||||
@@ -113,17 +109,8 @@ class RAGRuntimeService:
|
||||
regardless of the underlying storage provider.
|
||||
"""
|
||||
# Validate storage_path to prevent path traversal
|
||||
decoded_path = unquote(storage_path).replace('\\', '/')
|
||||
decoded_segments = decoded_path.split('/')
|
||||
normalized = posixpath.normpath(decoded_path)
|
||||
if (
|
||||
not storage_path
|
||||
or '\x00' in decoded_path
|
||||
or normalized.startswith('/')
|
||||
or '..' in decoded_segments
|
||||
or '..' in normalized.split('/')
|
||||
or re.match(r'^[A-Za-z]:/', normalized)
|
||||
):
|
||||
normalized = posixpath.normpath(storage_path)
|
||||
if normalized.startswith('/') or '..' in normalized.split('/'):
|
||||
raise ValueError('Invalid storage path')
|
||||
content_bytes = await self.ap.storage_mgr.storage_provider.load(normalized)
|
||||
return content_bytes if content_bytes else b''
|
||||
|
||||
@@ -13,11 +13,12 @@ class TelemetryManager:
|
||||
await telemetry.send({ ... })
|
||||
"""
|
||||
|
||||
send_tasks: list[asyncio.Task] = []
|
||||
|
||||
def __init__(self, ap: core_app.Application):
|
||||
self.ap = ap
|
||||
|
||||
self.telemetry_config = {}
|
||||
self.send_tasks: list[asyncio.Task] = []
|
||||
|
||||
async def initialize(self):
|
||||
self.telemetry_config = self.ap.instance_config.data.get('space', {})
|
||||
|
||||
@@ -83,7 +83,7 @@ def get_func_schema(function: typing.Callable) -> dict:
|
||||
|
||||
parameters['properties'][param.name] = {
|
||||
'type': param_type,
|
||||
'description': args_doc.get(param.name, ''),
|
||||
'description': args_doc[param.name],
|
||||
}
|
||||
|
||||
# add schema for array
|
||||
|
||||
@@ -145,8 +145,7 @@ def get_qq_image_downloadable_url(image_url: str) -> tuple[str, dict]:
|
||||
"""获取QQ图片的下载链接"""
|
||||
parsed = urlparse(image_url)
|
||||
query = parse_qs(parsed.query)
|
||||
scheme = parsed.scheme or 'http'
|
||||
return f'{scheme}://{parsed.netloc}{parsed.path}', query
|
||||
return f'http://{parsed.netloc}{parsed.path}', query
|
||||
|
||||
|
||||
async def get_qq_image_bytes(image_url: str, query: dict = {}) -> tuple[bytes, str]:
|
||||
|
||||
@@ -23,10 +23,7 @@ def run_pip(params: list):
|
||||
pipmain(params)
|
||||
|
||||
|
||||
def install_requirements(file, extra_params: list | None = None):
|
||||
if extra_params is None:
|
||||
extra_params = []
|
||||
|
||||
def install_requirements(file, extra_params: list = []):
|
||||
pipmain(
|
||||
[
|
||||
'install',
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import re
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
@@ -46,40 +44,6 @@ LOCAL_PATTERNS = [
|
||||
'172.31.',
|
||||
]
|
||||
|
||||
HOST_LABEL_PATTERN = re.compile(r'^[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?$')
|
||||
|
||||
|
||||
def _is_valid_hostname(host: str) -> bool:
|
||||
if host == 'localhost':
|
||||
return True
|
||||
|
||||
try:
|
||||
ipaddress.ip_address(host)
|
||||
return True
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if not host or len(host) > 253 or any(char.isspace() for char in host):
|
||||
return False
|
||||
|
||||
host = host.rstrip('.')
|
||||
if not host:
|
||||
return False
|
||||
|
||||
return all(HOST_LABEL_PATTERN.match(label) for label in host.split('.'))
|
||||
|
||||
|
||||
def _is_local_host(host: str) -> bool:
|
||||
if host == 'localhost':
|
||||
return True
|
||||
|
||||
try:
|
||||
ip_address = ipaddress.ip_address(host)
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
return ip_address.is_private or ip_address.is_loopback or ip_address.is_unspecified
|
||||
|
||||
|
||||
def get_runner_category(runner_name: str, runner_url: str) -> str:
|
||||
if not runner_url:
|
||||
@@ -88,15 +52,12 @@ def get_runner_category(runner_name: str, runner_url: str) -> str:
|
||||
try:
|
||||
parsed_url = urlparse(runner_url)
|
||||
host = parsed_url.hostname.lower() if parsed_url.hostname else ''
|
||||
_ = parsed_url.port
|
||||
except Exception:
|
||||
return RunnerCategory.UNKNOWN
|
||||
|
||||
if not parsed_url.scheme or not host or not _is_valid_hostname(host):
|
||||
return RunnerCategory.UNKNOWN
|
||||
|
||||
if _is_local_host(host):
|
||||
return RunnerCategory.LOCAL
|
||||
for pattern in LOCAL_PATTERNS:
|
||||
if host.startswith(pattern):
|
||||
return RunnerCategory.LOCAL
|
||||
|
||||
for domain in CLOUD_DOMAINS:
|
||||
if host.endswith(domain):
|
||||
|
||||
284
tests/README.md
284
tests/README.md
@@ -2,48 +2,6 @@
|
||||
|
||||
This directory contains the test suite for LangBot, with a focus on comprehensive unit testing of pipeline stages.
|
||||
|
||||
## Quality Gate Layers
|
||||
|
||||
LangBot uses a layered quality gate system for developers and CI:
|
||||
|
||||
| Layer | Command | What it runs | When to use |
|
||||
|-------|---------|--------------|-------------|
|
||||
| **Quick** | `make test-quick` or `bash scripts/test-quick.sh` | Ruff lint + Unit tests + Smoke tests | Before every commit |
|
||||
| **Fast Integration** | `make test-integration-fast` or `bash scripts/test-integration-fast.sh` | SQLite/API/Pipeline integration (no external services) | Before PR, weekly |
|
||||
| **Coverage Gate** | `make test-coverage` or `bash scripts/test-coverage.sh` | All tests with coverage, threshold: 18% | Before merge, CI |
|
||||
| **Full Local** | `make test-all-local` | Quick + Integration + Coverage | Before major changes |
|
||||
|
||||
**Note**: PostgreSQL migration tests and slow tests are NOT in local default gates. They run in separate CI workflows.
|
||||
|
||||
### Developer Workflow
|
||||
|
||||
```bash
|
||||
# Daily: Quick self-test
|
||||
bash scripts/test-quick.sh
|
||||
|
||||
# Before PR: Full local gate
|
||||
make test-all-local
|
||||
|
||||
# Or run each layer separately:
|
||||
bash scripts/test-quick.sh # ~2 min
|
||||
bash scripts/test-integration-fast.sh # ~3 min
|
||||
bash scripts/test-coverage.sh # ~8 min
|
||||
```
|
||||
|
||||
### Coverage Baseline
|
||||
|
||||
Current coverage threshold: **18%**
|
||||
Actual coverage: **30%**
|
||||
|
||||
This is a conservative baseline to prevent coverage regression. It does NOT represent the final quality target. Key modules have higher coverage:
|
||||
- `pipeline.preproc.preproc`: 53%
|
||||
- `pipeline.process.process`: 96%
|
||||
- `pipeline.respback.respback`: 88%
|
||||
- `telemetry.telemetry`: 87%
|
||||
- `provider.session.sessionmgr`: 100%
|
||||
- `provider.tools.toolmgr`: 83%
|
||||
- `storage.providers.s3storage`: 80%
|
||||
|
||||
## Important Note
|
||||
|
||||
Due to circular import dependencies in the pipeline module structure, the test files use **lazy imports** via `importlib.import_module()` instead of direct imports. This ensures tests can run without triggering circular import errors.
|
||||
@@ -52,81 +10,19 @@ Due to circular import dependencies in the pipeline module structure, the test f
|
||||
|
||||
```
|
||||
tests/
|
||||
├── __init__.py
|
||||
├── factories/ # Shared test factories
|
||||
│ ├── __init__.py # Factory exports
|
||||
│ ├── app.py # FakeApp factory
|
||||
│ ├── message.py # Message/query factories
|
||||
│ ├── provider.py # FakeProvider factory
|
||||
│ └── platform.py # FakePlatform factory
|
||||
├── integration/ # Integration tests (real resources)
|
||||
│ ├── __init__.py
|
||||
│ ├── api/ # HTTP API tests
|
||||
│ │ ├── __init__.py
|
||||
│ │ └── test_smoke.py # API smoke tests
|
||||
│ ├── pipeline/ # Pipeline stage-chain tests
|
||||
│ │ ├── __init__.py
|
||||
│ │ └── test_full_flow.py # Full flow integration
|
||||
│ └── persistence/ # Database/persistence tests
|
||||
│ ├── __init__.py
|
||||
│ └── test_migrations.py # Alembic migration tests
|
||||
├── smoke/ # Smoke tests (quick validation)
|
||||
│ └── test_fake_message_flow.py
|
||||
├── unit_tests/ # Unit tests
|
||||
│ ├── box/ # Box module tests
|
||||
│ ├── config/ # Configuration tests
|
||||
│ ├── pipeline/ # Pipeline stage tests
|
||||
│ │ └── conftest.py # Shared fixtures and test infrastructure
|
||||
│ ├── platform/ # Platform adapter tests
|
||||
│ ├── plugin/ # Plugin system tests
|
||||
│ │ └── test_handler_actions.py # Action handler tests
|
||||
│ ├── provider/ # Provider tests
|
||||
│ │ ├── test_session_manager.py # SessionManager tests
|
||||
│ │ └── test_tool_manager.py # ToolManager tests
|
||||
│ ├── rag/ # RAG tests
|
||||
│ │ └── test_file_storage.py # File/ZIP storage tests
|
||||
│ ├── storage/ # Storage tests
|
||||
│ │ └── test_s3storage.py # S3StorageProvider tests
|
||||
│ ├── vector/ # Vector tests
|
||||
│ │ └── test_vdb_filter_conversion.py # VDB filter tests
|
||||
│ └── telemetry/ # Telemetry tests (rewritten)
|
||||
├── utils/ # Test utilities
|
||||
│ ├── __init__.py
|
||||
│ └── import_isolation.py # sys.modules isolation for circular imports
|
||||
└── README.md # This file
|
||||
├── pipeline/ # Pipeline stage tests
|
||||
│ ├── conftest.py # Shared fixtures and test infrastructure
|
||||
│ ├── test_simple.py # Basic infrastructure tests (always pass)
|
||||
│ ├── test_bansess.py # BanSessionCheckStage tests
|
||||
│ ├── test_ratelimit.py # RateLimit stage tests
|
||||
│ ├── test_preproc.py # PreProcessor stage tests
|
||||
│ ├── test_respback.py # SendResponseBackStage tests
|
||||
│ ├── test_resprule.py # GroupRespondRuleCheckStage tests
|
||||
│ ├── test_pipelinemgr.py # PipelineManager tests
|
||||
│ └── test_stages_integration.py # Integration tests
|
||||
└── README.md # This file
|
||||
```
|
||||
|
||||
## Test Factories
|
||||
|
||||
The `tests/factories/` package provides reusable test factories:
|
||||
|
||||
```python
|
||||
from tests.factories import (
|
||||
FakeApp, # Mock application
|
||||
FakeProvider, # Fake LLM provider
|
||||
FakePlatform, # Fake platform adapter
|
||||
text_query, # Create text query
|
||||
group_text_query, # Create group query
|
||||
command_query, # Create command query
|
||||
)
|
||||
|
||||
# Create fake app
|
||||
app = FakeApp()
|
||||
|
||||
# Create query with text
|
||||
query = text_query("hello world")
|
||||
|
||||
# Create fake provider that returns specific response
|
||||
provider = FakeProvider().returns("test response")
|
||||
|
||||
# Create fake platform for outbound capture
|
||||
platform = FakePlatform()
|
||||
await platform.reply_message(query.message_event, reply_chain)
|
||||
outbound = platform.get_outbound_messages()
|
||||
```
|
||||
|
||||
See `tests/factories/__init__.py` for all available factories.
|
||||
|
||||
## Test Architecture
|
||||
|
||||
### Fixtures (`conftest.py`)
|
||||
@@ -147,28 +43,7 @@ The test suite uses a centralized fixture system that provides:
|
||||
|
||||
## Running Tests
|
||||
|
||||
### Quick self-test for developers
|
||||
|
||||
For local branch validation without real provider keys:
|
||||
|
||||
```bash
|
||||
make test-quick
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```bash
|
||||
bash scripts/test-quick.sh
|
||||
```
|
||||
|
||||
This runs:
|
||||
1. Ruff lint check
|
||||
2. Unit tests
|
||||
3. Smoke tests
|
||||
|
||||
Suitable for quick validation before committing.
|
||||
|
||||
### Using the test runner script (recommended for full coverage)
|
||||
### Using the test runner script (recommended)
|
||||
```bash
|
||||
bash run_tests.sh
|
||||
```
|
||||
@@ -181,135 +56,38 @@ This script automatically:
|
||||
|
||||
### Manual test execution
|
||||
|
||||
#### Run all unit tests
|
||||
#### Run all tests
|
||||
```bash
|
||||
uv run pytest tests/unit_tests/ --cov=langbot --cov-report=xml --cov-report=term
|
||||
pytest tests/pipeline/
|
||||
```
|
||||
|
||||
#### Run specific test module
|
||||
#### Run only simple tests (no imports, always pass)
|
||||
```bash
|
||||
uv run pytest tests/unit_tests/pipeline/ -v
|
||||
pytest tests/pipeline/test_simple.py -v
|
||||
```
|
||||
|
||||
#### Run specific test file
|
||||
```bash
|
||||
uv run pytest tests/unit_tests/pipeline/test_bansess.py -v
|
||||
pytest tests/pipeline/test_bansess.py -v
|
||||
```
|
||||
|
||||
#### Run with coverage
|
||||
```bash
|
||||
uv run pytest tests/unit_tests/pipeline/ --cov=langbot --cov-report=html
|
||||
pytest tests/pipeline/ --cov=pkg/pipeline --cov-report=html
|
||||
```
|
||||
|
||||
#### Run specific test
|
||||
```bash
|
||||
uv run pytest tests/unit_tests/pipeline/test_bansess.py::test_bansess_whitelist_allow -v
|
||||
pytest tests/pipeline/test_bansess.py::test_bansess_whitelist_allow -v
|
||||
```
|
||||
|
||||
### Using markers
|
||||
|
||||
```bash
|
||||
# Run only unit tests
|
||||
uv run pytest tests/unit_tests/ -m unit
|
||||
|
||||
# Run only integration tests
|
||||
uv run pytest tests/integration/ -m integration
|
||||
|
||||
# Run integration tests excluding slow ones
|
||||
uv run pytest tests/integration/ -m "not slow" -q
|
||||
|
||||
# Skip slow tests
|
||||
uv run pytest tests/unit_tests/ -m "not slow"
|
||||
```
|
||||
|
||||
### Running integration tests
|
||||
|
||||
Integration tests validate real system behavior with actual database/network resources.
|
||||
|
||||
```bash
|
||||
# Run all integration tests (excluding slow ones)
|
||||
uv run pytest tests/integration/ -m "not slow" -q
|
||||
|
||||
# Run SQLite migration integration tests
|
||||
uv run pytest tests/integration/persistence/test_migrations.py -q --tb=short
|
||||
|
||||
# Run API smoke integration tests
|
||||
uv run pytest tests/integration/api/test_smoke.py -q
|
||||
|
||||
# Run pipeline full-flow integration tests
|
||||
uv run pytest tests/integration/pipeline/test_full_flow.py -q
|
||||
|
||||
# Run with verbose output
|
||||
uv run pytest tests/integration/ -v
|
||||
```
|
||||
|
||||
Note: Integration tests use:
|
||||
- Temporary databases (tmp_path) for persistence tests
|
||||
- Fake app/services for API tests (no real provider/platform)
|
||||
- Fake runner/provider for pipeline tests (no real LLM API)
|
||||
- Do not require external services
|
||||
|
||||
### Running migration tests locally
|
||||
|
||||
SQLite migration tests can be run locally without any external dependencies:
|
||||
|
||||
```bash
|
||||
# SQLite migration tests (uses tmp_path, no external DB needed)
|
||||
uv run pytest tests/integration/persistence/test_migrations.py -q --tb=short
|
||||
```
|
||||
|
||||
PostgreSQL migration tests require an external PostgreSQL database:
|
||||
|
||||
```bash
|
||||
# PostgreSQL migration tests (requires PostgreSQL service)
|
||||
# Tests are marked as slow and skipped if TEST_POSTGRES_URL is not set
|
||||
TEST_POSTGRES_URL=postgresql+asyncpg://user:pass@localhost:5432/test_db \
|
||||
uv run pytest tests/integration/persistence/test_migrations_postgres.py -q --tb=short
|
||||
|
||||
# Or skip by default (no PostgreSQL available)
|
||||
uv run pytest tests/integration/persistence/test_migrations_postgres.py -q --tb=short
|
||||
# Output: SKIPPED (TEST_POSTGRES_URL not set)
|
||||
```
|
||||
|
||||
Note: PostgreSQL tests are **not** included in fast integration gate because they:
|
||||
- Require external PostgreSQL service
|
||||
- Are marked with `@pytest.mark.slow`
|
||||
- Need `TEST_POSTGRES_URL` environment variable
|
||||
|
||||
CI workflow `.github/workflows/test-migrations.yml` runs:
|
||||
- SQLite tests in `test-migrations-sqlite` job (fast, no external services)
|
||||
- PostgreSQL tests in `test-migrations-postgres` job (uses PostgreSQL service container)
|
||||
|
||||
### Running pipeline integration tests locally
|
||||
|
||||
Pipeline full-flow integration tests validate real stage interactions:
|
||||
|
||||
```bash
|
||||
# Run pipeline integration tests (uses fake runner, no real LLM API)
|
||||
uv run pytest tests/integration/pipeline/test_full_flow.py -q --tb=short
|
||||
|
||||
# Run with coverage for pipeline modules
|
||||
uv run pytest tests/integration/pipeline \
|
||||
--cov=langbot.pkg.pipeline.preproc.preproc \
|
||||
--cov=langbot.pkg.pipeline.process.process \
|
||||
--cov=langbot.pkg.pipeline.respback.respback \
|
||||
--cov-report=term -q
|
||||
```
|
||||
|
||||
These tests:
|
||||
- Use `FakeRunner` class to simulate LLM responses without real API calls
|
||||
- Import real `PreProcessor`, `MessageProcessor`, `SendResponseBackStage` stages
|
||||
- Validate stage chain: PreProcessor → Processor → SendResponseBackStage
|
||||
- Test prevent_default, exception handling, and full message flow
|
||||
- Do not require real LLM provider keys
|
||||
|
||||
### Known Issues
|
||||
|
||||
Some tests may encounter circular import errors. This is a known issue with the current module structure. The test infrastructure is designed to work around this using lazy imports, but if you encounter issues:
|
||||
|
||||
1. Make sure you're running from the project root directory
|
||||
2. Ensure dependencies are installed: `uv sync --dev`
|
||||
3. Try running a simple test first to verify the test infrastructure works
|
||||
2. Ensure the virtual environment is activated
|
||||
3. Try running `test_simple.py` first to verify the test infrastructure works
|
||||
|
||||
## CI/CD Integration
|
||||
|
||||
@@ -319,7 +97,7 @@ Tests are automatically run on:
|
||||
- Push to PR branch
|
||||
- Push to master/develop branches
|
||||
|
||||
The workflow runs tests on Python 3.11, 3.12, and 3.13 to ensure compatibility.
|
||||
The workflow runs tests on Python 3.10, 3.11, and 3.12 to ensure compatibility.
|
||||
|
||||
## Adding New Tests
|
||||
|
||||
@@ -333,8 +111,8 @@ Create a new test file `test_<stage_name>.py`:
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from langbot.pkg.pipeline.<module>.<stage> import <StageClass>
|
||||
from langbot.pkg.pipeline import entities as pipeline_entities
|
||||
from pkg.pipeline.<module>.<stage> import <StageClass>
|
||||
from pkg.pipeline import entities as pipeline_entities
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -350,7 +128,7 @@ async def test_stage_basic_flow(mock_app, sample_query):
|
||||
|
||||
### 2. For additional fixtures
|
||||
|
||||
Add new fixtures to the appropriate `conftest.py`:
|
||||
Add new fixtures to `conftest.py`:
|
||||
|
||||
```python
|
||||
@pytest.fixture
|
||||
@@ -364,7 +142,7 @@ def my_custom_fixture():
|
||||
Use the helper functions in `conftest.py`:
|
||||
|
||||
```python
|
||||
from tests.unit_tests.pipeline.conftest import create_stage_result, assert_result_continue
|
||||
from tests.pipeline.conftest import create_stage_result, assert_result_continue
|
||||
|
||||
result = create_stage_result(
|
||||
result_type=pipeline_entities.ResultType.CONTINUE,
|
||||
@@ -388,7 +166,7 @@ assert_result_continue(result)
|
||||
### Import errors
|
||||
Make sure you've installed the package in development mode:
|
||||
```bash
|
||||
uv sync --dev
|
||||
uv pip install -e .
|
||||
```
|
||||
|
||||
### Async test failures
|
||||
@@ -399,11 +177,7 @@ Check that you're mocking at the right level and using `AsyncMock` for async fun
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
- [x] Add integration tests for database migrations (SQLite)
|
||||
- [x] Add PostgreSQL migration integration tests (G-003)
|
||||
- [x] Add integration tests for full pipeline execution
|
||||
- [x] Add API smoke integration tests
|
||||
- [ ] Add E2E tests
|
||||
- [ ] Add integration tests for full pipeline execution
|
||||
- [ ] Add performance benchmarks
|
||||
- [ ] Add mutation testing for better coverage quality
|
||||
- [ ] Add property-based testing with Hypothesis
|
||||
- [ ] Add property-based testing with Hypothesis
|
||||
|
||||
@@ -1,102 +0,0 @@
|
||||
"""E2E test fixtures.
|
||||
|
||||
Provides fixtures for starting real LangBot process with minimal configuration.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
import shutil
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from tests.e2e.utils.config_factory import create_minimal_config, create_test_directories
|
||||
from tests.e2e.utils.process_manager import LangBotProcess, find_project_root
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
pytestmark = pytest.mark.e2e
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def e2e_port():
|
||||
"""Port for E2E testing (non-default to avoid conflicts)."""
|
||||
return 15300
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def e2e_tmpdir():
|
||||
"""Create temporary directory for E2E testing."""
|
||||
tmpdir = Path(tempfile.mkdtemp(prefix='langbot_e2e_'))
|
||||
logger.info(f'E2E tmpdir: {tmpdir}')
|
||||
|
||||
yield tmpdir
|
||||
|
||||
# Cleanup
|
||||
logger.info(f'Cleaning up E2E tmpdir: {tmpdir}')
|
||||
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def e2e_config_path(e2e_tmpdir, e2e_port):
|
||||
"""Create minimal config.yaml for E2E testing."""
|
||||
config_path = create_minimal_config(e2e_tmpdir, port=e2e_port)
|
||||
create_test_directories(e2e_tmpdir)
|
||||
logger.info(f'E2E config: {config_path}')
|
||||
return config_path
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def langbot_process(e2e_config_path, e2e_port, e2e_tmpdir):
|
||||
"""Start real LangBot process for E2E testing.
|
||||
|
||||
This fixture starts LangBot once per session and reuses it for all tests.
|
||||
Coverage data is collected from the subprocess.
|
||||
"""
|
||||
project_root = find_project_root()
|
||||
collect_coverage = True
|
||||
|
||||
proc = LangBotProcess(
|
||||
project_root=project_root,
|
||||
work_dir=e2e_tmpdir, # Run in tmpdir where data/config.yaml exists
|
||||
port=e2e_port,
|
||||
timeout=60, # Longer timeout for first startup
|
||||
collect_coverage=collect_coverage,
|
||||
)
|
||||
|
||||
success = proc.start()
|
||||
if not success:
|
||||
stdout, stderr = proc.get_logs()
|
||||
pytest.fail(f'LangBot failed to start:\nstdout: {stdout}\nstderr: {stderr}')
|
||||
|
||||
yield proc
|
||||
|
||||
# Cleanup
|
||||
proc.stop()
|
||||
|
||||
# Combine coverage data if collected
|
||||
if collect_coverage and proc.get_coverage_file():
|
||||
coverage_file = proc.get_coverage_file()
|
||||
if coverage_file.exists():
|
||||
# Copy coverage data to project root for combining
|
||||
target = project_root / '.coverage.e2e'
|
||||
shutil.copy(coverage_file, target)
|
||||
logger.info(f'Coverage data saved to: {target}')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def e2e_client(e2e_port, langbot_process):
|
||||
"""HTTP client for E2E testing."""
|
||||
import httpx
|
||||
|
||||
base_url = f'http://127.0.0.1:{e2e_port}'
|
||||
|
||||
with httpx.Client(base_url=base_url, timeout=10.0) as client:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def e2e_db_path(e2e_tmpdir):
|
||||
"""Path to SQLite database file."""
|
||||
return e2e_tmpdir / 'data' / 'langbot.db'
|
||||
@@ -1,142 +0,0 @@
|
||||
"""E2E tests for LangBot startup flow.
|
||||
|
||||
Tests the complete startup process including:
|
||||
- boot.py startup orchestration
|
||||
- stages/ (build_app, load_config, migrate, etc.)
|
||||
- database initialization
|
||||
- API availability
|
||||
|
||||
Run: uv run pytest tests/e2e/test_startup.py -v -m e2e
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.e2e
|
||||
|
||||
|
||||
class TestStartupFlow:
|
||||
"""Tests for LangBot startup process."""
|
||||
|
||||
def test_process_is_running(self, langbot_process):
|
||||
"""Verify LangBot process is running."""
|
||||
assert langbot_process.is_running()
|
||||
|
||||
def test_health_check(self, langbot_process, e2e_port):
|
||||
"""Verify LangBot API is responding."""
|
||||
assert langbot_process.health_check()
|
||||
|
||||
def test_system_info_endpoint(self, e2e_client):
|
||||
"""Test /api/v1/system/info endpoint."""
|
||||
response = e2e_client.get('/api/v1/system/info')
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
assert data['code'] == 0
|
||||
assert 'data' in data
|
||||
# System info should contain version info
|
||||
assert 'version' in data['data'] or 'edition' in data['data']
|
||||
|
||||
def test_database_initialized(self, e2e_db_path):
|
||||
"""Verify SQLite database was created and initialized."""
|
||||
assert e2e_db_path.exists()
|
||||
|
||||
# Database should have some tables after migration
|
||||
import sqlite3
|
||||
conn = sqlite3.connect(str(e2e_db_path))
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Check that core tables exist
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
|
||||
tables = [row[0] for row in cursor.fetchall()]
|
||||
|
||||
# Core tables should be created by Alembic migrations
|
||||
# Note: table names may differ (legacy_pipelines instead of pipelines)
|
||||
expected_tables = ['legacy_pipelines', 'bots', 'model_providers', 'llm_models']
|
||||
for table in expected_tables:
|
||||
assert table in tables, f'Table {table} should exist. Available: {tables}'
|
||||
|
||||
conn.close()
|
||||
|
||||
def test_chroma_directory_created(self, e2e_tmpdir):
|
||||
"""Verify Chroma vector database directory was created."""
|
||||
chroma_path = e2e_tmpdir / 'chroma'
|
||||
# Created by the E2E config factory before startup.
|
||||
assert chroma_path.exists()
|
||||
|
||||
def test_pipelines_endpoint(self, e2e_client):
|
||||
"""Test /api/v1/pipelines endpoint (requires auth)."""
|
||||
# Without auth, should return 401
|
||||
response = e2e_client.get('/api/v1/pipelines')
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_auth_endpoint(self, e2e_client, e2e_tmpdir):
|
||||
"""Test auth endpoint."""
|
||||
# First startup may allow initial setup
|
||||
response = e2e_client.post('/api/v1/user/auth', json={
|
||||
'username': 'admin',
|
||||
'password': 'admin',
|
||||
})
|
||||
|
||||
# Response could be:
|
||||
# - 200 if auth succeeds
|
||||
# - 400 if credentials wrong
|
||||
# - 401 if user not initialized
|
||||
assert response.status_code in [200, 400, 401]
|
||||
|
||||
|
||||
class TestStartupStages:
|
||||
"""Tests that verify individual startup stages worked correctly."""
|
||||
|
||||
def test_config_loaded(self, e2e_client):
|
||||
"""Verify config was loaded correctly by checking API port."""
|
||||
# If API responds on e2e_port, config was loaded
|
||||
assert e2e_client.get('/api/v1/system/info').status_code == 200
|
||||
|
||||
def test_migrations_applied(self, e2e_db_path):
|
||||
"""Verify database migrations were applied."""
|
||||
import sqlite3
|
||||
conn = sqlite3.connect(str(e2e_db_path))
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Check alembic_version table exists and has version
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='alembic_version';")
|
||||
result = cursor.fetchone()
|
||||
assert result is not None, 'alembic_version table should exist'
|
||||
|
||||
cursor.execute('SELECT version_num FROM alembic_version;')
|
||||
version = cursor.fetchone()
|
||||
assert version is not None, 'Migration version should be set'
|
||||
|
||||
conn.close()
|
||||
|
||||
def test_http_controller_initialized(self, e2e_client):
|
||||
"""Verify HTTP controller was initialized."""
|
||||
# Multiple endpoints should be available
|
||||
endpoints = [
|
||||
'/api/v1/system/info',
|
||||
'/api/v1/pipelines',
|
||||
'/api/v1/provider/providers',
|
||||
'/api/v1/platform/bots',
|
||||
]
|
||||
|
||||
for endpoint in endpoints:
|
||||
response = e2e_client.get(endpoint)
|
||||
# Should get a real route response, even if auth is required.
|
||||
assert response.status_code in [200, 401, 403], f'{endpoint} should be registered'
|
||||
|
||||
|
||||
class TestMinimalStartupNoLLM:
|
||||
"""Tests verifying LangBot can start without LLM providers."""
|
||||
|
||||
def test_api_available_without_llm(self, e2e_client):
|
||||
"""API should be available even without LLM providers configured."""
|
||||
response = e2e_client.get('/api/v1/system/info')
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_pipeline_metadata_available(self, e2e_client):
|
||||
"""Pipeline metadata endpoint should work without LLM."""
|
||||
# Requires auth, but endpoint should exist
|
||||
response = e2e_client.get('/api/v1/pipelines/_/metadata')
|
||||
assert response.status_code in [200, 401] # Not 404 or 500
|
||||
@@ -1,179 +0,0 @@
|
||||
"""E2E test configuration factory.
|
||||
|
||||
Generates minimal config.yaml for testing LangBot startup without external dependencies.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def create_minimal_config(tmpdir: Path, port: int = 15300) -> Path:
|
||||
"""Create minimal config.yaml for E2E testing.
|
||||
|
||||
Uses embedded databases (SQLite, Chroma) to avoid external dependencies.
|
||||
Config is created at tmpdir/data/config.yaml (LangBot expects this location).
|
||||
"""
|
||||
# LangBot expects config at data/config.yaml
|
||||
data_dir = tmpdir / 'data'
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
config = {
|
||||
'admins': [],
|
||||
'api': {
|
||||
'port': port,
|
||||
'webhook_prefix': f'http://127.0.0.1:{port}',
|
||||
'extra_webhook_prefix': '',
|
||||
},
|
||||
'command': {
|
||||
'enable': True,
|
||||
'prefix': ['!', '!'],
|
||||
'privilege': {},
|
||||
},
|
||||
'concurrency': {
|
||||
'pipeline': 20,
|
||||
'session': 1,
|
||||
},
|
||||
'proxy': {
|
||||
'http': '',
|
||||
'https': '',
|
||||
},
|
||||
'system': {
|
||||
'instance_id': '',
|
||||
'edition': 'community',
|
||||
'recovery_key': '',
|
||||
'allow_modify_login_info': True,
|
||||
'disabled_adapters': [],
|
||||
'limitation': {
|
||||
'max_bots': -1,
|
||||
'max_pipelines': -1,
|
||||
'max_extensions': -1,
|
||||
},
|
||||
'task_retention': {
|
||||
'completed_limit': 200,
|
||||
},
|
||||
'jwt': {
|
||||
'expire': 604800,
|
||||
'secret': 'e2e-test-secret-key',
|
||||
},
|
||||
},
|
||||
'database': {
|
||||
'use': 'sqlite',
|
||||
'sqlite': {
|
||||
'path': str(tmpdir / 'data' / 'langbot.db'),
|
||||
},
|
||||
'postgresql': {
|
||||
'host': '127.0.0.1',
|
||||
'port': 5432,
|
||||
'user': 'postgres',
|
||||
'password': 'postgres',
|
||||
'database': 'postgres',
|
||||
},
|
||||
},
|
||||
'vdb': {
|
||||
'use': 'chroma', # Chroma is embedded, no external dependency
|
||||
'chroma': {
|
||||
'path': str(tmpdir / 'chroma'),
|
||||
},
|
||||
'qdrant': {
|
||||
'url': '',
|
||||
'host': 'localhost',
|
||||
'port': 6333,
|
||||
'api_key': '',
|
||||
},
|
||||
'seekdb': {
|
||||
'mode': 'embedded',
|
||||
'path': str(tmpdir / 'seekdb'),
|
||||
'database': 'langbot',
|
||||
'host': 'localhost',
|
||||
'port': 2881,
|
||||
'user': 'root',
|
||||
'password': '',
|
||||
'tenant': '',
|
||||
},
|
||||
'milvus': {
|
||||
'uri': 'http://127.0.0.1:19530',
|
||||
'token': '',
|
||||
'db_name': '',
|
||||
},
|
||||
'pgvector': {
|
||||
'host': '127.0.0.1',
|
||||
'port': 5433,
|
||||
'database': 'langbot',
|
||||
'user': 'postgres',
|
||||
'password': 'postgres',
|
||||
},
|
||||
},
|
||||
'storage': {
|
||||
'use': 'local',
|
||||
'cleanup': {
|
||||
'enabled': False, # Disable cleanup for tests
|
||||
'check_interval_hours': 1,
|
||||
'uploaded_file_retention_days': 7,
|
||||
'log_retention_days': 3,
|
||||
},
|
||||
'local': {
|
||||
'path': str(tmpdir / 'storage'),
|
||||
},
|
||||
's3': {
|
||||
'endpoint_url': '',
|
||||
'access_key_id': '',
|
||||
'secret_access_key': '',
|
||||
'region': 'us-east-1',
|
||||
'bucket': 'langbot-storage',
|
||||
},
|
||||
},
|
||||
'plugin': {
|
||||
'enable': False, # Disable plugin system for minimal startup
|
||||
'runtime_ws_url': '',
|
||||
'enable_marketplace': False,
|
||||
'display_plugin_debug_url': '',
|
||||
'binary_storage': {
|
||||
'max_value_bytes': 10485760,
|
||||
},
|
||||
},
|
||||
'monitoring': {
|
||||
'auto_cleanup': {
|
||||
'enabled': False, # Disable cleanup for tests
|
||||
'retention_days': 30,
|
||||
'check_interval_hours': 1,
|
||||
'delete_batch_size': 1000,
|
||||
},
|
||||
},
|
||||
'space': {
|
||||
'url': 'https://space.langbot.app',
|
||||
'models_gateway_api_url': 'https://api.langbot.cloud/v1',
|
||||
'oauth_authorize_url': 'https://space.langbot.app/auth/authorize',
|
||||
'disable_models_service': True, # Disable external services
|
||||
'disable_telemetry': True, # Disable telemetry for tests
|
||||
},
|
||||
'provider': {}, # Empty providers - minimal startup
|
||||
'llm': [], # Empty LLM models
|
||||
}
|
||||
|
||||
# Ensure data directory exists (LangBot expects config at data/config.yaml)
|
||||
data_dir = tmpdir / 'data'
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Write config to data/config.yaml (LangBot's expected location)
|
||||
config_path = data_dir / 'config.yaml'
|
||||
with open(config_path, 'w', encoding='utf-8') as f:
|
||||
yaml.dump(config, f, default_flow_style=False)
|
||||
|
||||
return config_path
|
||||
|
||||
|
||||
def create_test_directories(tmpdir: Path) -> dict[str, Path]:
|
||||
"""Create necessary directories for LangBot testing."""
|
||||
directories = {
|
||||
'data': tmpdir / 'data',
|
||||
'logs': tmpdir / 'logs',
|
||||
'storage': tmpdir / 'storage',
|
||||
'chroma': tmpdir / 'chroma',
|
||||
}
|
||||
|
||||
for path in directories.values():
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return directories
|
||||
@@ -1,204 +0,0 @@
|
||||
"""E2E test process manager.
|
||||
|
||||
Manages LangBot subprocess lifecycle for E2E testing.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
import time
|
||||
import signal
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LangBotProcess:
|
||||
"""Manages a LangBot subprocess for E2E testing."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
project_root: Path,
|
||||
work_dir: Path,
|
||||
port: int = 15300,
|
||||
timeout: int = 30,
|
||||
collect_coverage: bool = True,
|
||||
):
|
||||
self.project_root = project_root
|
||||
self.work_dir = work_dir # Directory containing data/config.yaml
|
||||
self.port = port
|
||||
self.timeout = timeout
|
||||
self.collect_coverage = collect_coverage
|
||||
self.process: Optional[subprocess.Popen] = None
|
||||
self._stdout_data: bytes = b''
|
||||
self._stderr_data: bytes = b''
|
||||
self._coverage_file: Optional[Path] = None
|
||||
|
||||
def start(self) -> bool:
|
||||
"""Start LangBot process and wait for it to be ready."""
|
||||
import httpx
|
||||
|
||||
# Prepare environment
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = str(self.project_root / 'src')
|
||||
|
||||
# Set API port via environment variable
|
||||
env['API__PORT'] = str(self.port)
|
||||
env['API__WEBHOOK_PREFIX'] = f'http://127.0.0.1:{self.port}'
|
||||
|
||||
# Disable telemetry
|
||||
env['SPACE__DISABLE_TELEMETRY'] = 'true'
|
||||
env['SPACE__DISABLE_MODELS_SERVICE'] = 'true'
|
||||
|
||||
# Build command
|
||||
if self.collect_coverage:
|
||||
# Use coverage.py to collect coverage data
|
||||
# Set COVERAGE_PROCESS_START to enable coverage in subprocess
|
||||
self._coverage_file = self.work_dir / '.coverage.e2e'
|
||||
env['COVERAGE_PROCESS_START'] = str(self.project_root / '.coveragerc')
|
||||
env['COVERAGE_FILE'] = str(self._coverage_file)
|
||||
|
||||
# Create .coveragerc for subprocess
|
||||
coveragerc_content = """
|
||||
[run]
|
||||
source = langbot.pkg
|
||||
parallel = True
|
||||
data_file = {}
|
||||
omit =
|
||||
*/tests/*
|
||||
*/test_*.py
|
||||
|
||||
[report]
|
||||
precision = 2
|
||||
""".format(str(self._coverage_file))
|
||||
coveragerc_path = self.work_dir / '.coveragerc'
|
||||
with open(coveragerc_path, 'w') as f:
|
||||
f.write(coveragerc_content)
|
||||
|
||||
cmd = [
|
||||
'coverage', 'run',
|
||||
'--rcfile=' + str(coveragerc_path),
|
||||
'-m', 'langbot',
|
||||
]
|
||||
else:
|
||||
cmd = ['uv', 'run', 'python', '-m', 'langbot']
|
||||
|
||||
logger.info(f'Starting LangBot in: {self.work_dir}')
|
||||
logger.info(f'Command: {cmd}')
|
||||
|
||||
# Start process (run in work_dir so it finds data/config.yaml)
|
||||
self.process = subprocess.Popen(
|
||||
cmd,
|
||||
cwd=self.work_dir,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
preexec_fn=os.setsid if os.name != 'nt' else None,
|
||||
)
|
||||
|
||||
# Wait for startup
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < self.timeout:
|
||||
# Check if process died
|
||||
if self.process.poll() is not None:
|
||||
self._stdout_data, self._stderr_data = self.process.communicate()
|
||||
logger.error(f'LangBot process died: {self._stderr_data.decode()}')
|
||||
return False
|
||||
|
||||
# Try to connect
|
||||
try:
|
||||
r = httpx.get(
|
||||
f'http://127.0.0.1:{self.port}/api/v1/system/info',
|
||||
timeout=2.0,
|
||||
)
|
||||
if r.status_code == 200:
|
||||
logger.info(f'LangBot started successfully on port {self.port}')
|
||||
return True
|
||||
except (httpx.ConnectError, httpx.TimeoutException):
|
||||
pass
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
# Timeout
|
||||
logger.error(f'LangBot startup timeout after {self.timeout}s')
|
||||
self.stop()
|
||||
return False
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop LangBot process gracefully."""
|
||||
if self.process is None:
|
||||
return
|
||||
|
||||
logger.info('Stopping LangBot process...')
|
||||
|
||||
# Try graceful shutdown first
|
||||
if os.name != 'nt':
|
||||
# Send SIGTERM to process group
|
||||
os.killpg(os.getpgid(self.process.pid), signal.SIGTERM)
|
||||
else:
|
||||
self.process.terminate()
|
||||
|
||||
# Wait for graceful shutdown
|
||||
try:
|
||||
self.process.wait(timeout=5)
|
||||
logger.info('LangBot stopped gracefully')
|
||||
except subprocess.TimeoutExpired:
|
||||
# Force kill
|
||||
logger.warning('Force killing LangBot process')
|
||||
if os.name != 'nt':
|
||||
os.killpg(os.getpgid(self.process.pid), signal.SIGKILL)
|
||||
else:
|
||||
self.process.kill()
|
||||
self.process.wait()
|
||||
|
||||
# Collect output for debugging
|
||||
if self.process.stdout or self.process.stderr:
|
||||
self._stdout_data, self._stderr_data = self.process.communicate()
|
||||
|
||||
self.process = None
|
||||
|
||||
def is_running(self) -> bool:
|
||||
"""Check if process is still running."""
|
||||
return self.process is not None and self.process.poll() is None
|
||||
|
||||
def get_logs(self) -> tuple[str, str]:
|
||||
"""Get stdout and stderr logs."""
|
||||
stdout = self._stdout_data.decode('utf-8', errors='replace')
|
||||
stderr = self._stderr_data.decode('utf-8', errors='replace')
|
||||
return stdout, stderr
|
||||
|
||||
def get_coverage_file(self) -> Optional[Path]:
|
||||
"""Get coverage data file path."""
|
||||
return self._coverage_file
|
||||
|
||||
def health_check(self) -> bool:
|
||||
"""Check if LangBot API is responding."""
|
||||
import httpx
|
||||
|
||||
if not self.is_running():
|
||||
return False
|
||||
|
||||
try:
|
||||
r = httpx.get(
|
||||
f'http://127.0.0.1:{self.port}/api/v1/system/info',
|
||||
timeout=5.0,
|
||||
)
|
||||
return r.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def find_project_root() -> Path:
|
||||
"""Find LangBot project root directory."""
|
||||
current = Path(__file__).resolve()
|
||||
|
||||
# Walk up until we find src/langbot
|
||||
for parent in current.parents:
|
||||
if (parent / 'src' / 'langbot').exists():
|
||||
return parent
|
||||
|
||||
# Fallback to LangBot-test-build directory
|
||||
return Path('/home/glwuy/langbot-app/LangBot-test-build')
|
||||
@@ -1,102 +0,0 @@
|
||||
"""
|
||||
Shared test factories for LangBot tests.
|
||||
|
||||
Provides reusable factories for:
|
||||
- Fake application (app.py)
|
||||
- Messages and queries (message.py)
|
||||
- Fake providers (provider.py)
|
||||
- Fake platforms (platform.py)
|
||||
|
||||
Usage:
|
||||
from tests.factories import FakeApp, text_query, FakeProvider
|
||||
|
||||
app = FakeApp()
|
||||
query = text_query("hello")
|
||||
provider = FakeProvider.returns("response")
|
||||
"""
|
||||
|
||||
from tests.factories.app import FakeApp, fake_app
|
||||
from tests.factories.message import (
|
||||
text_chain,
|
||||
group_text_chain,
|
||||
mention_chain,
|
||||
image_chain,
|
||||
text_query,
|
||||
group_text_query,
|
||||
private_text_query,
|
||||
command_query,
|
||||
mention_query,
|
||||
empty_query,
|
||||
image_query,
|
||||
file_query,
|
||||
unsupported_query,
|
||||
voice_query,
|
||||
at_all_query,
|
||||
query_with_session,
|
||||
query_with_config,
|
||||
friend_message_event,
|
||||
group_message_event,
|
||||
mock_adapter,
|
||||
)
|
||||
from tests.factories.provider import (
|
||||
FakeProvider,
|
||||
fake_provider,
|
||||
fake_provider_pong,
|
||||
fake_provider_timeout,
|
||||
fake_provider_auth_error,
|
||||
fake_provider_rate_limit,
|
||||
fake_provider_malformed,
|
||||
fake_model,
|
||||
)
|
||||
from tests.factories.platform import (
|
||||
FakePlatform,
|
||||
fake_platform,
|
||||
fake_platform_with_streaming,
|
||||
fake_platform_with_failure,
|
||||
mock_platform_adapter,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# App
|
||||
"FakeApp",
|
||||
"fake_app",
|
||||
# Message chains
|
||||
"text_chain",
|
||||
"group_text_chain",
|
||||
"mention_chain",
|
||||
"image_chain",
|
||||
# Message events
|
||||
"friend_message_event",
|
||||
"group_message_event",
|
||||
# Mock adapters
|
||||
"mock_adapter",
|
||||
# Queries
|
||||
"text_query",
|
||||
"group_text_query",
|
||||
"private_text_query",
|
||||
"command_query",
|
||||
"mention_query",
|
||||
"empty_query",
|
||||
"image_query",
|
||||
"file_query",
|
||||
"unsupported_query",
|
||||
"voice_query",
|
||||
"at_all_query",
|
||||
"query_with_session",
|
||||
"query_with_config",
|
||||
# Provider
|
||||
"FakeProvider",
|
||||
"fake_provider",
|
||||
"fake_provider_pong",
|
||||
"fake_provider_timeout",
|
||||
"fake_provider_auth_error",
|
||||
"fake_provider_rate_limit",
|
||||
"fake_provider_malformed",
|
||||
"fake_model",
|
||||
# Platform
|
||||
"FakePlatform",
|
||||
"fake_platform",
|
||||
"fake_platform_with_streaming",
|
||||
"fake_platform_with_failure",
|
||||
"mock_platform_adapter",
|
||||
]
|
||||
@@ -1,137 +0,0 @@
|
||||
"""
|
||||
Fake application factory for tests.
|
||||
|
||||
Provides a mock Application object with all dependencies needed by pipeline stages.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
|
||||
class FakeApp:
|
||||
"""Mock Application object providing all basic dependencies needed by stages."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
command_prefix: list[str] = ["/", "!"],
|
||||
command_enable: bool = True,
|
||||
pipeline_concurrency: int = 10,
|
||||
admins: list[str] | None = None,
|
||||
**extra_attrs,
|
||||
):
|
||||
self.logger = self._create_mock_logger()
|
||||
self.sess_mgr = self._create_mock_session_manager()
|
||||
self.model_mgr = self._create_mock_model_manager()
|
||||
self.tool_mgr = self._create_mock_tool_manager()
|
||||
self.plugin_connector = self._create_mock_plugin_connector()
|
||||
self.persistence_mgr = self._create_mock_persistence_manager()
|
||||
self.query_pool = self._create_mock_query_pool()
|
||||
self.instance_config = self._create_mock_instance_config(
|
||||
command_prefix=command_prefix,
|
||||
command_enable=command_enable,
|
||||
pipeline_concurrency=pipeline_concurrency,
|
||||
admins=admins or [],
|
||||
)
|
||||
self.task_mgr = self._create_mock_task_manager()
|
||||
|
||||
# Handler-specific optional attributes
|
||||
self.telemetry = self._create_mock_telemetry()
|
||||
self.survey = None
|
||||
self.cmd_mgr = self._create_mock_cmd_mgr()
|
||||
|
||||
# Apply any extra attributes for specific test scenarios
|
||||
for name, value in extra_attrs.items():
|
||||
setattr(self, name, value)
|
||||
|
||||
# Captured outbound messages (for assertions)
|
||||
self._outbound_messages: list = []
|
||||
|
||||
def _create_mock_logger(self):
|
||||
logger = Mock()
|
||||
logger.debug = Mock()
|
||||
logger.info = Mock()
|
||||
logger.error = Mock()
|
||||
logger.warning = Mock()
|
||||
return logger
|
||||
|
||||
def _create_mock_session_manager(self):
|
||||
sess_mgr = AsyncMock()
|
||||
sess_mgr.get_session = AsyncMock()
|
||||
sess_mgr.get_conversation = AsyncMock()
|
||||
return sess_mgr
|
||||
|
||||
def _create_mock_model_manager(self):
|
||||
model_mgr = AsyncMock()
|
||||
model_mgr.get_model_by_uuid = AsyncMock()
|
||||
return model_mgr
|
||||
|
||||
def _create_mock_tool_manager(self):
|
||||
tool_mgr = AsyncMock()
|
||||
tool_mgr.get_all_tools = AsyncMock(return_value=[])
|
||||
return tool_mgr
|
||||
|
||||
def _create_mock_plugin_connector(self):
|
||||
plugin_connector = AsyncMock()
|
||||
plugin_connector.emit_event = AsyncMock()
|
||||
return plugin_connector
|
||||
|
||||
def _create_mock_persistence_manager(self):
|
||||
persistence_mgr = AsyncMock()
|
||||
persistence_mgr.execute_async = AsyncMock()
|
||||
return persistence_mgr
|
||||
|
||||
def _create_mock_query_pool(self):
|
||||
query_pool = Mock()
|
||||
query_pool.cached_queries = {}
|
||||
query_pool.queries = []
|
||||
query_pool.condition = AsyncMock()
|
||||
return query_pool
|
||||
|
||||
def _create_mock_instance_config(
|
||||
self,
|
||||
command_prefix: list[str],
|
||||
command_enable: bool,
|
||||
pipeline_concurrency: int,
|
||||
admins: list[str],
|
||||
):
|
||||
instance_config = Mock()
|
||||
instance_config.data = {
|
||||
"command": {"prefix": command_prefix, "enable": command_enable},
|
||||
"concurrency": {"pipeline": pipeline_concurrency},
|
||||
"admins": admins,
|
||||
}
|
||||
return instance_config
|
||||
|
||||
def _create_mock_task_manager(self):
|
||||
task_mgr = Mock()
|
||||
task_mgr.create_task = Mock()
|
||||
return task_mgr
|
||||
|
||||
def _create_mock_telemetry(self):
|
||||
telemetry = AsyncMock()
|
||||
telemetry.start_send_task = AsyncMock()
|
||||
return telemetry
|
||||
|
||||
def _create_mock_cmd_mgr(self):
|
||||
cmd_mgr = AsyncMock()
|
||||
cmd_mgr.execute = AsyncMock()
|
||||
return cmd_mgr
|
||||
|
||||
def capture_message(self, message):
|
||||
"""Capture an outbound message for test assertions."""
|
||||
self._outbound_messages.append(message)
|
||||
|
||||
def get_outbound_messages(self) -> list:
|
||||
"""Get all captured outbound messages."""
|
||||
return self._outbound_messages.copy()
|
||||
|
||||
def clear_outbound_messages(self):
|
||||
"""Clear captured outbound messages."""
|
||||
self._outbound_messages.clear()
|
||||
|
||||
|
||||
def fake_app(**kwargs) -> FakeApp:
|
||||
"""Create a FakeApp instance with optional overrides."""
|
||||
return FakeApp(**kwargs)
|
||||
@@ -1,472 +0,0 @@
|
||||
"""
|
||||
Message and query factories for tests.
|
||||
|
||||
Provides reusable factories for creating message chains, events, and query objects.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
import typing
|
||||
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
|
||||
|
||||
# Counter for generating unique IDs
|
||||
_query_counter = 0
|
||||
|
||||
|
||||
def _next_query_id() -> int:
|
||||
"""Generate a unique query ID."""
|
||||
global _query_counter
|
||||
_query_counter += 1
|
||||
return _query_counter
|
||||
|
||||
|
||||
# ============== Message Chain Factories ==============
|
||||
|
||||
|
||||
def text_chain(text: str = "hello") -> platform_message.MessageChain:
|
||||
"""Create a simple text message chain."""
|
||||
return platform_message.MessageChain([
|
||||
platform_message.Plain(text=text),
|
||||
])
|
||||
|
||||
|
||||
def group_text_chain(text: str = "hello") -> platform_message.MessageChain:
|
||||
"""Create a group text message chain (same as text_chain, context provided by event)."""
|
||||
return text_chain(text)
|
||||
|
||||
|
||||
def mention_chain(
|
||||
text: str = "hello",
|
||||
target: typing.Union[int, str] = 12345,
|
||||
) -> platform_message.MessageChain:
|
||||
"""Create a message chain with @mention."""
|
||||
return platform_message.MessageChain([
|
||||
platform_message.At(target=target),
|
||||
platform_message.Plain(text=f" {text}"),
|
||||
])
|
||||
|
||||
|
||||
def image_chain(
|
||||
text: str = "",
|
||||
url: str = "https://example.com/image.png",
|
||||
) -> platform_message.MessageChain:
|
||||
"""Create a message chain with an image."""
|
||||
components = []
|
||||
if text:
|
||||
components.append(platform_message.Plain(text=text))
|
||||
components.append(platform_message.Image(url=url))
|
||||
return platform_message.MessageChain(components)
|
||||
|
||||
|
||||
def command_chain(
|
||||
command: str = "help",
|
||||
prefix: str = "/",
|
||||
) -> platform_message.MessageChain:
|
||||
"""Create a command message chain."""
|
||||
return platform_message.MessageChain([
|
||||
platform_message.Plain(text=f"{prefix}{command}"),
|
||||
])
|
||||
|
||||
|
||||
# ============== Message Event Factories ==============
|
||||
|
||||
|
||||
def friend_message_event(
|
||||
message_chain: platform_message.MessageChain,
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
nickname: str = "TestUser",
|
||||
) -> platform_events.FriendMessage:
|
||||
"""Create a friend (private) message event."""
|
||||
sender = platform_entities.Friend(
|
||||
id=sender_id,
|
||||
nickname=nickname,
|
||||
remark=None,
|
||||
)
|
||||
return platform_events.FriendMessage(
|
||||
type="FriendMessage",
|
||||
sender=sender,
|
||||
message_chain=message_chain,
|
||||
time=1609459200,
|
||||
)
|
||||
|
||||
|
||||
def group_message_event(
|
||||
message_chain: platform_message.MessageChain,
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
sender_name: str = "TestUser",
|
||||
group_id: typing.Union[int, str] = 99999,
|
||||
group_name: str = "TestGroup",
|
||||
) -> platform_events.GroupMessage:
|
||||
"""Create a group message event."""
|
||||
group = platform_entities.Group(
|
||||
id=group_id,
|
||||
name=group_name,
|
||||
permission=platform_entities.Permission.Member,
|
||||
)
|
||||
sender = platform_entities.GroupMember(
|
||||
id=sender_id,
|
||||
member_name=sender_name,
|
||||
permission=platform_entities.Permission.Member,
|
||||
group=group,
|
||||
)
|
||||
return platform_events.GroupMessage(
|
||||
type="GroupMessage",
|
||||
sender=sender,
|
||||
message_chain=message_chain,
|
||||
time=1609459200,
|
||||
)
|
||||
|
||||
|
||||
# ============== Mock Adapter Factory ==============
|
||||
|
||||
|
||||
def mock_adapter() -> Mock:
|
||||
"""Create a mock platform adapter."""
|
||||
adapter = AsyncMock()
|
||||
adapter.is_stream_output_supported = AsyncMock(return_value=False)
|
||||
adapter.reply_message = AsyncMock()
|
||||
adapter.reply_message_chunk = AsyncMock()
|
||||
return adapter
|
||||
|
||||
|
||||
# ============== Query Factories ==============
|
||||
|
||||
|
||||
def _base_query(
|
||||
message_chain: platform_message.MessageChain,
|
||||
message_event: platform_events.MessageEvent,
|
||||
launcher_type: provider_session.LauncherTypes,
|
||||
launcher_id: typing.Union[int, str],
|
||||
sender_id: typing.Union[int, str],
|
||||
adapter: Mock,
|
||||
**overrides,
|
||||
) -> pipeline_query.Query:
|
||||
"""Create a base query with model_construct to bypass validation."""
|
||||
query_id = _next_query_id()
|
||||
|
||||
base_data = {
|
||||
"query_id": query_id,
|
||||
"launcher_type": launcher_type,
|
||||
"launcher_id": launcher_id,
|
||||
"sender_id": sender_id,
|
||||
"message_chain": message_chain,
|
||||
"message_event": message_event,
|
||||
"adapter": adapter,
|
||||
"pipeline_uuid": "test-pipeline-uuid",
|
||||
"bot_uuid": "test-bot-uuid",
|
||||
"pipeline_config": {
|
||||
"ai": {
|
||||
"runner": {"runner": "local-agent"},
|
||||
"local-agent": {
|
||||
"model": {"primary": "test-model-uuid", "fallbacks": []},
|
||||
"prompt": "test-prompt",
|
||||
},
|
||||
},
|
||||
"output": {"misc": {"at-sender": False, "quote-origin": False}},
|
||||
"trigger": {"misc": {"combine-quote-message": False}},
|
||||
},
|
||||
"session": None,
|
||||
"prompt": None,
|
||||
"messages": [],
|
||||
"user_message": None,
|
||||
"use_funcs": [],
|
||||
"use_llm_model_uuid": None,
|
||||
"variables": {},
|
||||
"resp_messages": [],
|
||||
"resp_message_chain": None,
|
||||
"current_stage_name": None,
|
||||
}
|
||||
|
||||
# Apply overrides
|
||||
for key, value in overrides.items():
|
||||
base_data[key] = value
|
||||
|
||||
return pipeline_query.Query.model_construct(**base_data)
|
||||
|
||||
|
||||
def text_query(
|
||||
text: str = "hello",
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
**overrides,
|
||||
) -> pipeline_query.Query:
|
||||
"""Create a basic text query (private chat)."""
|
||||
chain = text_chain(text)
|
||||
event = friend_message_event(chain, sender_id)
|
||||
adapter = mock_adapter()
|
||||
return _base_query(
|
||||
message_chain=chain,
|
||||
message_event=event,
|
||||
launcher_type=provider_session.LauncherTypes.PERSON,
|
||||
launcher_id=sender_id,
|
||||
sender_id=sender_id,
|
||||
adapter=adapter,
|
||||
**overrides,
|
||||
)
|
||||
|
||||
|
||||
def private_text_query(
|
||||
text: str = "hello",
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
**overrides,
|
||||
) -> pipeline_query.Query:
|
||||
"""Create a private text query (alias for text_query)."""
|
||||
return text_query(text, sender_id, **overrides)
|
||||
|
||||
|
||||
def group_text_query(
|
||||
text: str = "hello",
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
group_id: typing.Union[int, str] = 99999,
|
||||
**overrides,
|
||||
) -> pipeline_query.Query:
|
||||
"""Create a group text query."""
|
||||
chain = text_chain(text)
|
||||
event = group_message_event(chain, sender_id, group_id=group_id)
|
||||
adapter = mock_adapter()
|
||||
return _base_query(
|
||||
message_chain=chain,
|
||||
message_event=event,
|
||||
launcher_type=provider_session.LauncherTypes.GROUP,
|
||||
launcher_id=group_id,
|
||||
sender_id=sender_id,
|
||||
adapter=adapter,
|
||||
**overrides,
|
||||
)
|
||||
|
||||
|
||||
def command_query(
|
||||
command: str = "help",
|
||||
prefix: str = "/",
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
**overrides,
|
||||
) -> pipeline_query.Query:
|
||||
"""Create a command-like query."""
|
||||
chain = command_chain(command, prefix)
|
||||
event = friend_message_event(chain, sender_id)
|
||||
adapter = mock_adapter()
|
||||
return _base_query(
|
||||
message_chain=chain,
|
||||
message_event=event,
|
||||
launcher_type=provider_session.LauncherTypes.PERSON,
|
||||
launcher_id=sender_id,
|
||||
sender_id=sender_id,
|
||||
adapter=adapter,
|
||||
**overrides,
|
||||
)
|
||||
|
||||
|
||||
def mention_query(
|
||||
text: str = "hello",
|
||||
target: typing.Union[int, str] = 12345,
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
group_id: typing.Union[int, str] = 99999,
|
||||
**overrides,
|
||||
) -> pipeline_query.Query:
|
||||
"""Create a mention-bot query (group chat with @mention)."""
|
||||
chain = mention_chain(text, target)
|
||||
event = group_message_event(chain, sender_id, group_id=group_id)
|
||||
adapter = mock_adapter()
|
||||
return _base_query(
|
||||
message_chain=chain,
|
||||
message_event=event,
|
||||
launcher_type=provider_session.LauncherTypes.GROUP,
|
||||
launcher_id=group_id,
|
||||
sender_id=sender_id,
|
||||
adapter=adapter,
|
||||
**overrides,
|
||||
)
|
||||
|
||||
|
||||
def empty_query(**overrides) -> pipeline_query.Query:
|
||||
"""Create an empty message query."""
|
||||
chain = platform_message.MessageChain([])
|
||||
event = friend_message_event(chain)
|
||||
adapter = mock_adapter()
|
||||
return _base_query(
|
||||
message_chain=chain,
|
||||
message_event=event,
|
||||
launcher_type=provider_session.LauncherTypes.PERSON,
|
||||
launcher_id=12345,
|
||||
sender_id=12345,
|
||||
adapter=adapter,
|
||||
**overrides,
|
||||
)
|
||||
|
||||
|
||||
def image_query(
|
||||
text: str = "",
|
||||
url: str = "https://example.com/image.png",
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
**overrides,
|
||||
) -> pipeline_query.Query:
|
||||
"""Create an image query."""
|
||||
chain = image_chain(text, url)
|
||||
event = friend_message_event(chain, sender_id)
|
||||
adapter = mock_adapter()
|
||||
return _base_query(
|
||||
message_chain=chain,
|
||||
message_event=event,
|
||||
launcher_type=provider_session.LauncherTypes.PERSON,
|
||||
launcher_id=sender_id,
|
||||
sender_id=sender_id,
|
||||
adapter=adapter,
|
||||
**overrides,
|
||||
)
|
||||
|
||||
|
||||
def file_query(
|
||||
url: str = "https://example.com/document.pdf",
|
||||
name: str = "document.pdf",
|
||||
text: str = "",
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
**overrides,
|
||||
) -> pipeline_query.Query:
|
||||
"""Create a file attachment query."""
|
||||
components = []
|
||||
if text:
|
||||
components.append(platform_message.Plain(text=text))
|
||||
components.append(platform_message.File(url=url, name=name))
|
||||
chain = platform_message.MessageChain(components)
|
||||
event = friend_message_event(chain, sender_id)
|
||||
adapter = mock_adapter()
|
||||
return _base_query(
|
||||
message_chain=chain,
|
||||
message_event=event,
|
||||
launcher_type=provider_session.LauncherTypes.PERSON,
|
||||
launcher_id=sender_id,
|
||||
sender_id=sender_id,
|
||||
adapter=adapter,
|
||||
**overrides,
|
||||
)
|
||||
|
||||
|
||||
def unsupported_query(
|
||||
unsupported_type: str = "CustomComponent",
|
||||
text: str = "",
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
**overrides,
|
||||
) -> pipeline_query.Query:
|
||||
"""Create a query with unsupported/unknown message segment."""
|
||||
components = []
|
||||
if text:
|
||||
components.append(platform_message.Plain(text=text))
|
||||
# Use Unknown component for unsupported types
|
||||
components.append(platform_message.Unknown(text=f"Unsupported: {unsupported_type}"))
|
||||
chain = platform_message.MessageChain(components)
|
||||
event = friend_message_event(chain, sender_id)
|
||||
adapter = mock_adapter()
|
||||
return _base_query(
|
||||
message_chain=chain,
|
||||
message_event=event,
|
||||
launcher_type=provider_session.LauncherTypes.PERSON,
|
||||
launcher_id=sender_id,
|
||||
sender_id=sender_id,
|
||||
adapter=adapter,
|
||||
**overrides,
|
||||
)
|
||||
|
||||
|
||||
def query_with_session(
|
||||
text: str = "hello",
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
session: provider_session.Session = None,
|
||||
**overrides,
|
||||
) -> pipeline_query.Query:
|
||||
"""Create a query with a session object.
|
||||
|
||||
If session is None, creates a default session with empty conversation.
|
||||
"""
|
||||
if session is None:
|
||||
# Create a default session
|
||||
session = provider_session.Session(
|
||||
launcher_type=provider_session.LauncherTypes.PERSON,
|
||||
launcher_id=sender_id,
|
||||
sender_id=sender_id,
|
||||
use_prompt_name="default",
|
||||
using_conversation=None,
|
||||
conversations=[],
|
||||
)
|
||||
|
||||
return text_query(text, sender_id, session=session, **overrides)
|
||||
|
||||
|
||||
def query_with_config(
|
||||
text: str = "hello",
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
pipeline_config: dict = None,
|
||||
**overrides,
|
||||
) -> pipeline_query.Query:
|
||||
"""Create a query with custom pipeline configuration.
|
||||
|
||||
If pipeline_config is None, uses default config.
|
||||
Useful for testing specific stage behaviors.
|
||||
"""
|
||||
if pipeline_config is None:
|
||||
pipeline_config = {
|
||||
"ai": {
|
||||
"runner": {"runner": "local-agent"},
|
||||
"local-agent": {
|
||||
"model": {"primary": "test-model-uuid", "fallbacks": []},
|
||||
"prompt": "test-prompt",
|
||||
},
|
||||
},
|
||||
"output": {"misc": {"at-sender": False, "quote-origin": False}},
|
||||
"trigger": {"misc": {"combine-quote-message": False}},
|
||||
}
|
||||
|
||||
return text_query(text, sender_id, pipeline_config=pipeline_config, **overrides)
|
||||
|
||||
|
||||
def voice_query(
|
||||
url: str = "https://example.com/audio.mp3",
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
**overrides,
|
||||
) -> pipeline_query.Query:
|
||||
"""Create a voice/audio query."""
|
||||
components = [
|
||||
platform_message.Voice(url=url),
|
||||
]
|
||||
chain = platform_message.MessageChain(components)
|
||||
event = friend_message_event(chain, sender_id)
|
||||
adapter = mock_adapter()
|
||||
return _base_query(
|
||||
message_chain=chain,
|
||||
message_event=event,
|
||||
launcher_type=provider_session.LauncherTypes.PERSON,
|
||||
launcher_id=sender_id,
|
||||
sender_id=sender_id,
|
||||
adapter=adapter,
|
||||
**overrides,
|
||||
)
|
||||
|
||||
|
||||
def at_all_query(
|
||||
text: str = "hello",
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
group_id: typing.Union[int, str] = 99999,
|
||||
**overrides,
|
||||
) -> pipeline_query.Query:
|
||||
"""Create a group query with @All mention."""
|
||||
components = [
|
||||
platform_message.AtAll(),
|
||||
platform_message.Plain(text=f" {text}"),
|
||||
]
|
||||
chain = platform_message.MessageChain(components)
|
||||
event = group_message_event(chain, sender_id, group_id=group_id)
|
||||
adapter = mock_adapter()
|
||||
return _base_query(
|
||||
message_chain=chain,
|
||||
message_event=event,
|
||||
launcher_type=provider_session.LauncherTypes.GROUP,
|
||||
launcher_id=group_id,
|
||||
sender_id=sender_id,
|
||||
adapter=adapter,
|
||||
**overrides,
|
||||
)
|
||||
@@ -1,336 +0,0 @@
|
||||
"""
|
||||
Fake platform factory for tests.
|
||||
|
||||
Provides a fake platform adapter for tests that need inbound message injection
|
||||
and outbound message capture.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
import typing
|
||||
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||
|
||||
|
||||
class FakePlatform:
|
||||
"""Fake platform adapter for unit and integration tests.
|
||||
|
||||
Simulates platform behavior without real network calls:
|
||||
- Inbound text message construction
|
||||
- Group and private conversation identities
|
||||
- Mention-bot flag
|
||||
- Outbound text capture
|
||||
- Outbound file/image capture
|
||||
- Send failure simulation
|
||||
|
||||
Does not start real platform adapters.
|
||||
Does not call IM platform SDKs.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
bot_account_id: str = "test-bot",
|
||||
stream_output_supported: bool = False,
|
||||
raise_error: Exception = None,
|
||||
):
|
||||
self.bot_account_id = bot_account_id
|
||||
self._stream_output_supported = stream_output_supported
|
||||
self._raise_error = raise_error
|
||||
|
||||
# Captured outbound messages
|
||||
self._outbound_messages: list[dict] = []
|
||||
self._outbound_chunks: list[dict] = []
|
||||
|
||||
# Registered listeners
|
||||
self._listeners: dict = {}
|
||||
|
||||
def raises(self, error: Exception) -> "FakePlatform":
|
||||
"""Configure platform to raise an error on send."""
|
||||
self._raise_error = error
|
||||
return self
|
||||
|
||||
def send_failure(self) -> "FakePlatform":
|
||||
"""Configure platform to simulate send failure."""
|
||||
return self.raises(Exception("Platform send failure"))
|
||||
|
||||
def supports_streaming(self, supported: bool = True) -> "FakePlatform":
|
||||
"""Configure whether streaming output is supported."""
|
||||
self._stream_output_supported = supported
|
||||
return self
|
||||
|
||||
def get_outbound_messages(self) -> list[dict]:
|
||||
"""Get all captured outbound messages for assertions."""
|
||||
return self._outbound_messages.copy()
|
||||
|
||||
def get_outbound_chunks(self) -> list[dict]:
|
||||
"""Get all captured outbound streaming chunks for assertions."""
|
||||
return self._outbound_chunks.copy()
|
||||
|
||||
def clear_outbound(self):
|
||||
"""Clear captured outbound messages."""
|
||||
self._outbound_messages.clear()
|
||||
self._outbound_chunks.clear()
|
||||
|
||||
def last_message(self) -> dict | None:
|
||||
"""Get the last captured outbound message."""
|
||||
return self._outbound_messages[-1] if self._outbound_messages else None
|
||||
|
||||
def last_chunk(self) -> dict | None:
|
||||
"""Get the last captured streaming chunk."""
|
||||
return self._outbound_chunks[-1] if self._outbound_chunks else None
|
||||
|
||||
# ============== Inbound Message Construction ==============
|
||||
|
||||
def create_friend_message(
|
||||
self,
|
||||
text: str,
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
nickname: str = "TestUser",
|
||||
) -> platform_events.FriendMessage:
|
||||
"""Create an inbound friend (private) message event."""
|
||||
sender = platform_entities.Friend(
|
||||
id=sender_id,
|
||||
nickname=nickname,
|
||||
remark=None,
|
||||
)
|
||||
chain = platform_message.MessageChain([
|
||||
platform_message.Plain(text=text),
|
||||
])
|
||||
return platform_events.FriendMessage(
|
||||
type="FriendMessage",
|
||||
sender=sender,
|
||||
message_chain=chain,
|
||||
time=1609459200,
|
||||
)
|
||||
|
||||
def create_group_message(
|
||||
self,
|
||||
text: str,
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
sender_name: str = "TestUser",
|
||||
group_id: typing.Union[int, str] = 99999,
|
||||
group_name: str = "TestGroup",
|
||||
mention_bot: bool = False,
|
||||
) -> platform_events.GroupMessage:
|
||||
"""Create an inbound group message event.
|
||||
|
||||
Args:
|
||||
text: Message text content
|
||||
sender_id: Sender user ID
|
||||
sender_name: Sender display name
|
||||
group_id: Group ID
|
||||
group_name: Group name
|
||||
mention_bot: If True, prepend @mention of bot account
|
||||
"""
|
||||
group = platform_entities.Group(
|
||||
id=group_id,
|
||||
name=group_name,
|
||||
permission=platform_entities.Permission.Member,
|
||||
)
|
||||
sender = platform_entities.GroupMember(
|
||||
id=sender_id,
|
||||
member_name=sender_name,
|
||||
permission=platform_entities.Permission.Member,
|
||||
group=group,
|
||||
)
|
||||
|
||||
# Build message chain with optional mention
|
||||
components = []
|
||||
if mention_bot:
|
||||
components.append(platform_message.At(target=self.bot_account_id))
|
||||
components.append(platform_message.Plain(text=" "))
|
||||
components.append(platform_message.Plain(text=text))
|
||||
|
||||
chain = platform_message.MessageChain(components)
|
||||
return platform_events.GroupMessage(
|
||||
type="GroupMessage",
|
||||
sender=sender,
|
||||
message_chain=chain,
|
||||
time=1609459200,
|
||||
)
|
||||
|
||||
def create_image_message(
|
||||
self,
|
||||
url: str = "https://example.com/image.png",
|
||||
text: str = "",
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
is_group: bool = False,
|
||||
group_id: typing.Union[int, str] = 99999,
|
||||
) -> platform_events.MessageEvent:
|
||||
"""Create an inbound image message event."""
|
||||
components = []
|
||||
if text:
|
||||
components.append(platform_message.Plain(text=text))
|
||||
components.append(platform_message.Image(url=url))
|
||||
chain = platform_message.MessageChain(components)
|
||||
|
||||
if is_group:
|
||||
return self.create_group_message("", sender_id, group_id=group_id)
|
||||
# Replace chain
|
||||
else:
|
||||
sender = platform_entities.Friend(id=sender_id, nickname="TestUser", remark=None)
|
||||
return platform_events.FriendMessage(
|
||||
type="FriendMessage",
|
||||
sender=sender,
|
||||
message_chain=chain,
|
||||
time=1609459200,
|
||||
)
|
||||
|
||||
# ============== Adapter Methods (Simulated) ==============
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
message: platform_message.MessageChain,
|
||||
):
|
||||
"""Simulate sending a message (captures for assertions)."""
|
||||
if self._raise_error:
|
||||
raise self._raise_error
|
||||
|
||||
self._outbound_messages.append({
|
||||
"type": "send",
|
||||
"target_type": target_type,
|
||||
"target_id": target_id,
|
||||
"message": message,
|
||||
})
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
):
|
||||
"""Simulate replying to a message (captures for assertions)."""
|
||||
if self._raise_error:
|
||||
raise self._raise_error
|
||||
|
||||
self._outbound_messages.append({
|
||||
"type": "reply",
|
||||
"source_type": message_source.type,
|
||||
"source": message_source,
|
||||
"message": message,
|
||||
"quote_origin": quote_origin,
|
||||
})
|
||||
|
||||
async def reply_message_chunk(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
bot_message: dict,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
is_final: bool = False,
|
||||
):
|
||||
"""Simulate streaming reply (captures for assertions)."""
|
||||
if self._raise_error:
|
||||
raise self._raise_error
|
||||
|
||||
self._outbound_chunks.append({
|
||||
"type": "reply_chunk",
|
||||
"source_type": message_source.type,
|
||||
"source": message_source,
|
||||
"bot_message": bot_message,
|
||||
"message": message,
|
||||
"quote_origin": quote_origin,
|
||||
"is_final": is_final,
|
||||
})
|
||||
|
||||
async def is_stream_output_supported(self) -> bool:
|
||||
"""Return whether streaming output is supported."""
|
||||
return self._stream_output_supported
|
||||
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable,
|
||||
):
|
||||
"""Register an event listener (stores for simulation)."""
|
||||
if event_type not in self._listeners:
|
||||
self._listeners[event_type] = []
|
||||
self._listeners[event_type].append(callback)
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable,
|
||||
):
|
||||
"""Unregister an event listener."""
|
||||
if event_type in self._listeners:
|
||||
self._listeners[event_type].remove(callback)
|
||||
|
||||
async def run_async(self):
|
||||
"""Simulate running the adapter (does nothing)."""
|
||||
pass
|
||||
|
||||
async def kill(self) -> bool:
|
||||
"""Simulate killing the adapter."""
|
||||
return True
|
||||
|
||||
async def is_muted(self, group_id: int) -> bool:
|
||||
"""Simulate checking mute status."""
|
||||
return False
|
||||
|
||||
async def create_message_card(
|
||||
self,
|
||||
message_id: typing.Type[str, int],
|
||||
event: platform_events.MessageEvent,
|
||||
) -> bool:
|
||||
"""Simulate creating a message card."""
|
||||
return False
|
||||
|
||||
# ============== Simulation Helpers ==============
|
||||
|
||||
async def simulate_inbound_event(
|
||||
self,
|
||||
event: platform_events.Event,
|
||||
):
|
||||
"""Simulate receiving an inbound event by calling registered listeners."""
|
||||
listeners = self._listeners.get(type(event), [])
|
||||
for callback in listeners:
|
||||
await callback(event, self)
|
||||
|
||||
|
||||
def fake_platform(
|
||||
bot_account_id: str = "test-bot",
|
||||
stream_output_supported: bool = False,
|
||||
) -> FakePlatform:
|
||||
"""Create a FakePlatform instance."""
|
||||
return FakePlatform(
|
||||
bot_account_id=bot_account_id,
|
||||
stream_output_supported=stream_output_supported,
|
||||
)
|
||||
|
||||
|
||||
def fake_platform_with_streaming() -> FakePlatform:
|
||||
"""Create a FakePlatform that supports streaming output."""
|
||||
return FakePlatform(stream_output_supported=True)
|
||||
|
||||
|
||||
def fake_platform_with_failure() -> FakePlatform:
|
||||
"""Create a FakePlatform that simulates send failure."""
|
||||
return FakePlatform().send_failure()
|
||||
|
||||
|
||||
# ============== Mock Adapter (for Query) ==============
|
||||
|
||||
|
||||
def mock_platform_adapter(platform: FakePlatform = None) -> Mock:
|
||||
"""Create a mock platform adapter using FakePlatform or a simple mock."""
|
||||
if platform is None:
|
||||
platform = FakePlatform()
|
||||
|
||||
adapter = Mock()
|
||||
adapter.bot_account_id = platform.bot_account_id
|
||||
adapter.reply_message = AsyncMock(side_effect=platform.reply_message)
|
||||
adapter.reply_message_chunk = AsyncMock(side_effect=platform.reply_message_chunk)
|
||||
adapter.send_message = AsyncMock(side_effect=platform.send_message)
|
||||
adapter.is_stream_output_supported = AsyncMock(
|
||||
return_value=platform._stream_output_supported
|
||||
)
|
||||
adapter._fake_platform = platform # Store for assertions
|
||||
|
||||
return adapter
|
||||
@@ -1,224 +0,0 @@
|
||||
"""
|
||||
Fake provider factory for tests.
|
||||
|
||||
Provides a deterministic fake provider that simulates LLM responses without real API calls.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import Mock
|
||||
import typing
|
||||
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
|
||||
|
||||
class FakeProvider:
|
||||
"""Deterministic fake provider for unit and integration tests.
|
||||
|
||||
Simulates various provider behaviors:
|
||||
- Normal text response
|
||||
- Streaming response
|
||||
- Timeout error
|
||||
- Auth error
|
||||
- Rate-limit error
|
||||
- Malformed response
|
||||
|
||||
Does not call real LLM vendors.
|
||||
Does not require API keys.
|
||||
"""
|
||||
|
||||
PONG_RESPONSE = "LANGBOT_FAKE_PONG"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
default_response: str = "fake response",
|
||||
streaming_chunks: list[str] = None,
|
||||
raise_error: Exception = None,
|
||||
captured_requests: list = None,
|
||||
):
|
||||
self._default_response = default_response
|
||||
self._streaming_chunks = streaming_chunks or ["fake ", "response"]
|
||||
self._raise_error = raise_error
|
||||
self._captured_requests = captured_requests if captured_requests is not None else []
|
||||
|
||||
def returns(self, text: str) -> "FakeProvider":
|
||||
"""Configure provider to return a specific text response."""
|
||||
self._default_response = text
|
||||
self._streaming_chunks = [text]
|
||||
return self
|
||||
|
||||
def returns_streaming(self, chunks: list[str]) -> "FakeProvider":
|
||||
"""Configure provider to return streaming chunks."""
|
||||
self._streaming_chunks = chunks
|
||||
self._default_response = "".join(chunks)
|
||||
return self
|
||||
|
||||
def raises(self, error: Exception) -> "FakeProvider":
|
||||
"""Configure provider to raise an error."""
|
||||
self._raise_error = error
|
||||
return self
|
||||
|
||||
def timeout(self) -> "FakeProvider":
|
||||
"""Configure provider to simulate timeout."""
|
||||
return self.raises(TimeoutError("Provider timeout"))
|
||||
|
||||
def auth_error(self) -> "FakeProvider":
|
||||
"""Configure provider to simulate auth error."""
|
||||
return self.raises(Exception("Invalid API key"))
|
||||
|
||||
def rate_limit(self) -> "FakeProvider":
|
||||
"""Configure provider to simulate rate limit."""
|
||||
return self.raises(Exception("Rate limit exceeded"))
|
||||
|
||||
def malformed(self) -> "FakeProvider":
|
||||
"""Configure provider to simulate malformed response."""
|
||||
self._default_response = None
|
||||
return self
|
||||
|
||||
def get_captured_requests(self) -> list:
|
||||
"""Get all captured request arguments for assertions."""
|
||||
return self._captured_requests.copy()
|
||||
|
||||
def clear_captured_requests(self):
|
||||
"""Clear captured requests."""
|
||||
self._captured_requests.clear()
|
||||
|
||||
def _create_message(self, content: str) -> provider_message.Message:
|
||||
"""Create a provider message from text content."""
|
||||
return provider_message.Message(
|
||||
role="assistant",
|
||||
content=content,
|
||||
)
|
||||
|
||||
def _create_chunk(
|
||||
self,
|
||||
content: str,
|
||||
is_final: bool = False,
|
||||
msg_sequence: int = 0,
|
||||
) -> provider_message.MessageChunk:
|
||||
"""Create a provider message chunk."""
|
||||
return provider_message.MessageChunk(
|
||||
role="assistant",
|
||||
content=content,
|
||||
is_final=is_final,
|
||||
msg_sequence=msg_sequence,
|
||||
)
|
||||
|
||||
async def invoke_llm(
|
||||
self,
|
||||
query,
|
||||
model,
|
||||
messages: list,
|
||||
funcs: list,
|
||||
extra_args: dict,
|
||||
remove_think: bool = False,
|
||||
) -> provider_message.Message:
|
||||
"""Simulate non-streaming LLM invocation."""
|
||||
# Capture request for assertions
|
||||
self._captured_requests.append({
|
||||
"query_id": query.query_id if query else None,
|
||||
"model": model.model_entity.name if model and hasattr(model, 'model_entity') else None,
|
||||
"messages": messages,
|
||||
"funcs": funcs,
|
||||
"extra_args": extra_args,
|
||||
})
|
||||
|
||||
# Simulate error if configured
|
||||
if self._raise_error:
|
||||
raise self._raise_error
|
||||
|
||||
# Return response
|
||||
if self._default_response is None:
|
||||
# Malformed response
|
||||
return provider_message.Message(role="assistant", content=None)
|
||||
|
||||
return self._create_message(self._default_response)
|
||||
|
||||
async def invoke_llm_stream(
|
||||
self,
|
||||
query,
|
||||
model,
|
||||
messages: list,
|
||||
funcs: list,
|
||||
extra_args: dict,
|
||||
remove_think: bool = False,
|
||||
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
|
||||
"""Simulate streaming LLM invocation."""
|
||||
# Capture request for assertions
|
||||
self._captured_requests.append({
|
||||
"query_id": query.query_id if query else None,
|
||||
"model": model.model_entity.name if model and hasattr(model, 'model_entity') else None,
|
||||
"messages": messages,
|
||||
"funcs": funcs,
|
||||
"extra_args": extra_args,
|
||||
"streaming": True,
|
||||
})
|
||||
|
||||
# Simulate error if configured
|
||||
if self._raise_error:
|
||||
raise self._raise_error
|
||||
|
||||
# Yield chunks
|
||||
for i, chunk in enumerate(self._streaming_chunks):
|
||||
is_final = (i == len(self._streaming_chunks) - 1)
|
||||
yield self._create_chunk(chunk, is_final=is_final, msg_sequence=i)
|
||||
|
||||
|
||||
def fake_provider(
|
||||
default_response: str = "fake response",
|
||||
) -> FakeProvider:
|
||||
"""Create a FakeProvider with optional default response."""
|
||||
return FakeProvider(default_response=default_response)
|
||||
|
||||
|
||||
def fake_provider_pong() -> FakeProvider:
|
||||
"""Create a FakeProvider that returns the pong response."""
|
||||
return FakeProvider(default_response=FakeProvider.PONG_RESPONSE)
|
||||
|
||||
|
||||
def fake_provider_timeout() -> FakeProvider:
|
||||
"""Create a FakeProvider that simulates timeout."""
|
||||
return FakeProvider().timeout()
|
||||
|
||||
|
||||
def fake_provider_auth_error() -> FakeProvider:
|
||||
"""Create a FakeProvider that simulates auth error."""
|
||||
return FakeProvider().auth_error()
|
||||
|
||||
|
||||
def fake_provider_rate_limit() -> FakeProvider:
|
||||
"""Create a FakeProvider that simulates rate limit."""
|
||||
return FakeProvider().rate_limit()
|
||||
|
||||
|
||||
def fake_provider_malformed() -> FakeProvider:
|
||||
"""Create a FakeProvider that simulates malformed response."""
|
||||
return FakeProvider().malformed()
|
||||
|
||||
|
||||
# ============== Mock Model Factory ==============
|
||||
|
||||
|
||||
def fake_model(
|
||||
*,
|
||||
uuid: str = "test-model-uuid",
|
||||
name: str = "test-model",
|
||||
abilities: list[str] = None,
|
||||
provider: FakeProvider = None,
|
||||
) -> Mock:
|
||||
"""Create a mock model with a fake provider."""
|
||||
model = Mock()
|
||||
model.model_entity = Mock()
|
||||
model.model_entity.uuid = uuid
|
||||
model.model_entity.name = name
|
||||
model.model_entity.abilities = abilities or ["func_call", "vision"]
|
||||
model.model_entity.extra_args = {}
|
||||
|
||||
# Attach fake provider
|
||||
if provider is None:
|
||||
provider = FakeProvider()
|
||||
|
||||
model.provider = provider
|
||||
|
||||
return model
|
||||
@@ -1,6 +0,0 @@
|
||||
"""
|
||||
Integration tests package.
|
||||
|
||||
These tests validate real system behavior with actual database/network resources.
|
||||
Run with: uv run pytest tests/integration/ -m "not slow" -q
|
||||
"""
|
||||
@@ -1,5 +0,0 @@
|
||||
"""
|
||||
API integration tests package.
|
||||
|
||||
Tests for HTTP API endpoints using Quart test client.
|
||||
"""
|
||||
@@ -1,28 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def dedupe_preregistered_groups() -> None:
|
||||
"""Keep API integration route registration isolated across test modules."""
|
||||
from langbot.pkg.api.http.controller import group
|
||||
|
||||
seen: set[tuple[str, str]] = set()
|
||||
unique_groups = []
|
||||
for group_cls in group.preregistered_groups:
|
||||
key = (group_cls.name, group_cls.path)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
unique_groups.append(group_cls)
|
||||
|
||||
group.preregistered_groups[:] = unique_groups
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def http_controller_cls(mock_circular_import_chain):
|
||||
"""Import HTTPController under each module's circular-import isolation."""
|
||||
from langbot.pkg.api.http.controller.main import HTTPController
|
||||
|
||||
dedupe_preregistered_groups()
|
||||
return HTTPController
|
||||
@@ -1,253 +0,0 @@
|
||||
"""
|
||||
API integration tests for bot endpoints.
|
||||
|
||||
Tests real HTTP API behavior for bot management.
|
||||
|
||||
Run: uv run pytest tests/integration/api/test_bots.py -q
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, Mock
|
||||
|
||||
from tests.factories import FakeApp
|
||||
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def mock_circular_import_chain():
|
||||
"""Break circular import chain for API controller."""
|
||||
from tests.utils.import_isolation import isolated_sys_modules, MockLifecycleControlScope
|
||||
|
||||
class FakeMinimalApplication:
|
||||
pass
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.Application = FakeMinimalApplication
|
||||
|
||||
mock_entities = MagicMock()
|
||||
mock_entities.LifecycleControlScope = MockLifecycleControlScope
|
||||
|
||||
clear = [
|
||||
'langbot.pkg.api.http.controller.group',
|
||||
'langbot.pkg.api.http.controller.groups',
|
||||
'langbot.pkg.api.http.controller.groups.platform',
|
||||
'langbot.pkg.api.http.controller.groups.platform.bots',
|
||||
'langbot.pkg.api.http.controller.groups.platform.adapters',
|
||||
'langbot.pkg.api.http.controller.main',
|
||||
]
|
||||
|
||||
with isolated_sys_modules(
|
||||
mocks={
|
||||
'langbot.pkg.core.app': mock_app,
|
||||
'langbot.pkg.core.entities': mock_entities,
|
||||
},
|
||||
clear=clear,
|
||||
):
|
||||
import langbot.pkg.api.http.controller.groups.platform.bots as _bots # noqa: E402, F401
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def fake_bot_app():
|
||||
"""Create FakeApp with bot services (module scope for reuse)."""
|
||||
app = FakeApp()
|
||||
|
||||
app.instance_config.data.update({
|
||||
'api': {'port': 5300},
|
||||
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
||||
})
|
||||
|
||||
# Auth services
|
||||
app.user_service = Mock()
|
||||
app.user_service.is_initialized = AsyncMock(return_value=True)
|
||||
app.user_service.verify_jwt_token = AsyncMock(return_value='test@example.com')
|
||||
app.user_service.get_user_by_email = AsyncMock(return_value=Mock(email='test@example.com'))
|
||||
app.apikey_service = Mock()
|
||||
app.apikey_service.verify_api_key = AsyncMock(return_value=True)
|
||||
|
||||
# Bot service
|
||||
app.bot_service = Mock()
|
||||
app.bot_service.get_bots = AsyncMock(return_value=[
|
||||
{
|
||||
'uuid': 'test-bot-uuid',
|
||||
'name': 'Test Bot',
|
||||
'platform': 'telegram',
|
||||
'pipeline_uuid': 'test-pipeline-uuid',
|
||||
}
|
||||
])
|
||||
app.bot_service.get_runtime_bot_info = AsyncMock(return_value={
|
||||
'uuid': 'test-bot-uuid',
|
||||
'name': 'Test Bot',
|
||||
'platform': 'telegram',
|
||||
'pipeline_uuid': 'test-pipeline-uuid',
|
||||
'webhook_url': 'https://example.com/webhook/test-bot-uuid',
|
||||
})
|
||||
app.bot_service.create_bot = AsyncMock(return_value={'uuid': 'new-bot-uuid'})
|
||||
app.bot_service.update_bot = AsyncMock(return_value={})
|
||||
app.bot_service.delete_bot = AsyncMock()
|
||||
app.bot_service.list_event_logs = AsyncMock(return_value=(
|
||||
[{'uuid': 'log-1', 'message': 'test log'}],
|
||||
1
|
||||
))
|
||||
app.bot_service.send_message = AsyncMock()
|
||||
|
||||
# Platform manager
|
||||
app.platform_mgr = Mock()
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
async def quart_test_client(fake_bot_app, http_controller_cls):
|
||||
"""Create Quart test client (module scope to avoid route re-registration)."""
|
||||
controller = http_controller_cls(fake_bot_app)
|
||||
await controller.initialize()
|
||||
|
||||
client = controller.quart_app.test_client()
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestBotEndpoints:
|
||||
"""Tests for /api/v1/platform/bots endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_bots_success(self, quart_test_client):
|
||||
"""GET /api/v1/platform/bots returns bot list."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/platform/bots',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
assert 'data' in data
|
||||
assert 'bots' in data['data']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_bot_success(self, quart_test_client):
|
||||
"""POST /api/v1/platform/bots creates new bot."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/platform/bots',
|
||||
headers={'Authorization': 'Bearer test_token'},
|
||||
json={'name': 'New Bot', 'platform': 'telegram', 'pipeline_uuid': 'test-pipeline'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
assert 'uuid' in data['data']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_single_bot_success(self, quart_test_client):
|
||||
"""GET /api/v1/platform/bots/{uuid} returns bot with runtime info."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/platform/bots/test-bot-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
assert 'bot' in data['data']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_bot_success(self, quart_test_client):
|
||||
"""PUT /api/v1/platform/bots/{uuid} updates bot."""
|
||||
response = await quart_test_client.put(
|
||||
'/api/v1/platform/bots/test-bot-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'},
|
||||
json={'name': 'Updated Bot'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_bot_success(self, quart_test_client):
|
||||
"""DELETE /api/v1/platform/bots/{uuid} deletes bot."""
|
||||
response = await quart_test_client.delete(
|
||||
'/api/v1/platform/bots/test-bot-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestBotLogsEndpoint:
|
||||
"""Tests for bot logs endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_bot_logs_success(self, quart_test_client):
|
||||
"""POST /api/v1/platform/bots/{uuid}/logs returns logs."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/platform/bots/test-bot-uuid/logs',
|
||||
headers={'Authorization': 'Bearer test_token'},
|
||||
json={'from_index': -1, 'max_count': 10}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
assert 'logs' in data['data']
|
||||
assert 'total_count' in data['data']
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestBotSendMessageEndpoint:
|
||||
"""Tests for bot send message endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_success(self, quart_test_client):
|
||||
"""POST /api/v1/platform/bots/{uuid}/send_message sends message."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/platform/bots/test-bot-uuid/send_message',
|
||||
headers={'Authorization': 'Bearer test_api_key'},
|
||||
json={
|
||||
'target_type': 'person',
|
||||
'target_id': 'user123',
|
||||
'message_chain': [{'type': 'text', 'text': 'Hello'}]
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
assert data['data']['sent'] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_missing_target_type(self, quart_test_client):
|
||||
"""POST send_message without target_type returns 400."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/platform/bots/test-bot-uuid/send_message',
|
||||
headers={'Authorization': 'Bearer test_api_key'},
|
||||
json={'target_id': 'user123', 'message_chain': [{'type': 'text', 'text': 'Hello'}]}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
data = await response.get_json()
|
||||
assert data['code'] == -1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_invalid_target_type(self, quart_test_client):
|
||||
"""POST send_message with invalid target_type returns 400."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/platform/bots/test-bot-uuid/send_message',
|
||||
headers={'Authorization': 'Bearer test_api_key'},
|
||||
json={
|
||||
'target_type': 'invalid',
|
||||
'target_id': 'user123',
|
||||
'message_chain': [{'type': 'text', 'text': 'Hello'}]
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
data = await response.get_json()
|
||||
assert data['code'] == -1
|
||||
@@ -1,300 +0,0 @@
|
||||
"""
|
||||
API integration tests for embed widget endpoints.
|
||||
|
||||
Tests real HTTP API behavior for embed widget functionality.
|
||||
|
||||
Run: uv run pytest tests/integration/api/test_embed.py -q
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, Mock
|
||||
|
||||
from tests.factories import FakeApp
|
||||
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def mock_circular_import_chain():
|
||||
"""Break circular import chain for API controller."""
|
||||
from tests.utils.import_isolation import isolated_sys_modules, MockLifecycleControlScope
|
||||
|
||||
class FakeMinimalApplication:
|
||||
pass
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.Application = FakeMinimalApplication
|
||||
|
||||
mock_entities = MagicMock()
|
||||
mock_entities.LifecycleControlScope = MockLifecycleControlScope
|
||||
|
||||
clear = [
|
||||
'langbot.pkg.api.http.controller.group',
|
||||
'langbot.pkg.api.http.controller.groups',
|
||||
'langbot.pkg.api.http.controller.groups.pipelines',
|
||||
'langbot.pkg.api.http.controller.groups.pipelines.embed',
|
||||
'langbot.pkg.api.http.controller.main',
|
||||
]
|
||||
|
||||
with isolated_sys_modules(
|
||||
mocks={
|
||||
'langbot.pkg.core.app': mock_app,
|
||||
'langbot.pkg.core.entities': mock_entities,
|
||||
},
|
||||
clear=clear,
|
||||
):
|
||||
import langbot.pkg.api.http.controller.groups.pipelines.embed as _embed # noqa: E402, F401
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def fake_embed_app():
|
||||
"""Create FakeApp with embed widget services (module scope)."""
|
||||
app = FakeApp()
|
||||
|
||||
app.instance_config.data.update({
|
||||
'api': {'port': 5300},
|
||||
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
||||
})
|
||||
|
||||
# Create mock web_page_bot with valid UUID format
|
||||
mock_bot_entity = Mock()
|
||||
mock_bot_entity.uuid = 'a1b2c3d4-5678-90ab-cdef-123456789abc'
|
||||
mock_bot_entity.adapter = 'web_page_bot'
|
||||
mock_bot_entity.enable = True
|
||||
mock_bot_entity.use_pipeline_uuid = 'test-pipeline-uuid'
|
||||
mock_bot_entity.name = 'Test Web Bot'
|
||||
mock_bot_entity.adapter_config = {
|
||||
'turnstile_secret_key': '',
|
||||
'turnstile_site_key': '',
|
||||
'language': 'en_US',
|
||||
'bubble_icon': 'logo',
|
||||
}
|
||||
|
||||
mock_runtime_bot = Mock()
|
||||
mock_runtime_bot.bot_entity = mock_bot_entity
|
||||
|
||||
# Platform manager with bots
|
||||
app.platform_mgr = Mock()
|
||||
app.platform_mgr.bots = [mock_runtime_bot]
|
||||
|
||||
# WebSocket proxy bot with adapter
|
||||
mock_websocket_adapter = Mock()
|
||||
mock_websocket_adapter.get_websocket_messages = Mock(return_value=[
|
||||
{'id': 'msg-1', 'content': 'test message'}
|
||||
])
|
||||
mock_websocket_adapter.reset_session = Mock()
|
||||
mock_websocket_adapter.handle_websocket_message = AsyncMock()
|
||||
|
||||
mock_ws_proxy_bot = Mock()
|
||||
mock_ws_proxy_bot.adapter = mock_websocket_adapter
|
||||
app.platform_mgr.websocket_proxy_bot = mock_ws_proxy_bot
|
||||
|
||||
# Monitoring service for feedback
|
||||
app.monitoring_service = Mock()
|
||||
app.monitoring_service.record_feedback = AsyncMock()
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
async def quart_test_client(fake_embed_app, http_controller_cls):
|
||||
"""Create Quart test client (module scope)."""
|
||||
controller = http_controller_cls(fake_embed_app)
|
||||
await controller.initialize()
|
||||
|
||||
client = controller.quart_app.test_client()
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestEmbedWidgetEndpoint:
|
||||
"""Tests for widget.js endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_widget_js_success(self, quart_test_client):
|
||||
"""GET /api/v1/embed/{bot_uuid}/widget.js returns JS."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/widget.js'
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert 'javascript' in response.content_type
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_widget_js_invalid_uuid(self, quart_test_client):
|
||||
"""GET widget.js with invalid UUID returns 400."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/embed/invalid-uuid/widget.js'
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_widget_js_bot_not_found(self, quart_test_client):
|
||||
"""GET widget.js for non-existent bot returns 404."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/embed/00000000-0000-0000-0000-000000000000/widget.js'
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestEmbedLogoEndpoint:
|
||||
"""Tests for logo endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_logo_success(self, quart_test_client):
|
||||
"""GET /api/v1/embed/logo returns image."""
|
||||
response = await quart_test_client.get('/api/v1/embed/logo')
|
||||
|
||||
assert response.status_code == 200
|
||||
assert 'image/webp' in response.content_type
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestEmbedTurnstileVerifyEndpoint:
|
||||
"""Tests for Turnstile verification endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_turnstile_verify_no_secret(self, quart_test_client):
|
||||
"""POST turnstile verify without secret returns dummy token."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify',
|
||||
json={'token': 'test-token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
assert 'token' in data['data']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_turnstile_verify_invalid_uuid(self, quart_test_client):
|
||||
"""POST turnstile verify with invalid UUID returns 400."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/embed/invalid-uuid/turnstile/verify',
|
||||
json={'token': 'test-token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_turnstile_verify_missing_token(self, quart_test_client):
|
||||
"""POST turnstile verify without token returns 400."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify',
|
||||
json={}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestEmbedMessagesEndpoint:
|
||||
"""Tests for messages endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_messages_person_success(self, quart_test_client):
|
||||
"""GET messages/person returns messages."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/person',
|
||||
headers={'Authorization': 'Bearer 1234567890.dummy'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
assert 'messages' in data['data']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_messages_group_success(self, quart_test_client):
|
||||
"""GET messages/group returns messages."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/group',
|
||||
headers={'Authorization': 'Bearer 1234567890.dummy'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_messages_invalid_session_type(self, quart_test_client):
|
||||
"""GET messages with invalid session_type returns 400."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/invalid',
|
||||
headers={'Authorization': 'Bearer 1234567890.dummy'}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestEmbedResetEndpoint:
|
||||
"""Tests for session reset endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_session_person_success(self, quart_test_client):
|
||||
"""POST reset/person resets session."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/reset/person',
|
||||
headers={'Authorization': 'Bearer 1234567890.dummy'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_session_invalid_uuid(self, quart_test_client):
|
||||
"""POST reset with invalid UUID returns 400."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/embed/invalid-uuid/reset/person',
|
||||
headers={'Authorization': 'Bearer 1234567890.dummy'}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestEmbedFeedbackEndpoint:
|
||||
"""Tests for feedback submission endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_submit_feedback_like(self, quart_test_client):
|
||||
"""POST feedback with type=1 (like) succeeds."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
|
||||
headers={'Authorization': 'Bearer 1234567890.dummy'},
|
||||
json={'message_id': 'msg-123', 'feedback_type': 1}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
assert 'feedback_id' in data['data']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_submit_feedback_dislike(self, quart_test_client):
|
||||
"""POST feedback with type=2 (dislike) succeeds."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
|
||||
headers={'Authorization': 'Bearer 1234567890.dummy'},
|
||||
json={'message_id': 'msg-123', 'feedback_type': 2}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_submit_feedback_invalid_type(self, quart_test_client):
|
||||
"""POST feedback with invalid type returns 400."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
|
||||
headers={'Authorization': 'Bearer 1234567890.dummy'},
|
||||
json={'message_id': 'msg-123', 'feedback_type': 99}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
@@ -1,259 +0,0 @@
|
||||
"""
|
||||
API integration tests for knowledge base endpoints.
|
||||
|
||||
Tests real HTTP API behavior for knowledge base management.
|
||||
|
||||
Run: uv run pytest tests/integration/api/test_knowledge.py -q
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, Mock
|
||||
|
||||
from tests.factories import FakeApp
|
||||
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def mock_circular_import_chain():
|
||||
"""Break circular import chain for API controller."""
|
||||
from tests.utils.import_isolation import isolated_sys_modules, MockLifecycleControlScope
|
||||
|
||||
class FakeMinimalApplication:
|
||||
pass
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.Application = FakeMinimalApplication
|
||||
|
||||
mock_entities = MagicMock()
|
||||
mock_entities.LifecycleControlScope = MockLifecycleControlScope
|
||||
|
||||
clear = [
|
||||
'langbot.pkg.api.http.controller.group',
|
||||
'langbot.pkg.api.http.controller.groups',
|
||||
'langbot.pkg.api.http.controller.groups.knowledge',
|
||||
'langbot.pkg.api.http.controller.groups.knowledge.base',
|
||||
'langbot.pkg.api.http.controller.groups.knowledge.engines',
|
||||
'langbot.pkg.api.http.controller.groups.knowledge.parsers',
|
||||
'langbot.pkg.api.http.controller.main',
|
||||
]
|
||||
|
||||
with isolated_sys_modules(
|
||||
mocks={
|
||||
'langbot.pkg.core.app': mock_app,
|
||||
'langbot.pkg.core.entities': mock_entities,
|
||||
},
|
||||
clear=clear,
|
||||
):
|
||||
import langbot.pkg.api.http.controller.groups.knowledge.base as _knowledge # noqa: E402, F401
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def fake_knowledge_app():
|
||||
"""Create FakeApp with knowledge services (module scope for reuse)."""
|
||||
app = FakeApp()
|
||||
|
||||
app.instance_config.data.update({
|
||||
'api': {'port': 5300},
|
||||
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
||||
})
|
||||
|
||||
# Auth services
|
||||
app.user_service = Mock()
|
||||
app.user_service.is_initialized = AsyncMock(return_value=True)
|
||||
app.user_service.verify_jwt_token = AsyncMock(return_value='test@example.com')
|
||||
app.user_service.get_user_by_email = AsyncMock(return_value=Mock(email='test@example.com'))
|
||||
app.apikey_service = Mock()
|
||||
app.apikey_service.verify_api_key = AsyncMock(return_value=True)
|
||||
|
||||
# Knowledge service
|
||||
app.knowledge_service = Mock()
|
||||
app.knowledge_service.get_knowledge_bases = AsyncMock(return_value=[
|
||||
{
|
||||
'uuid': 'test-kb-uuid',
|
||||
'name': 'Test Knowledge Base',
|
||||
'description': 'Test KB description',
|
||||
'engine_plugin_id': 'test/engine',
|
||||
'created_at': '2024-01-01T00:00:00',
|
||||
'updated_at': '2024-01-01T00:00:00',
|
||||
}
|
||||
])
|
||||
app.knowledge_service.get_knowledge_base = AsyncMock(return_value={
|
||||
'uuid': 'test-kb-uuid',
|
||||
'name': 'Test Knowledge Base',
|
||||
'description': 'Test KB description',
|
||||
'engine_plugin_id': 'test/engine',
|
||||
})
|
||||
app.knowledge_service.create_knowledge_base = AsyncMock(return_value={'uuid': 'new-kb-uuid'})
|
||||
app.knowledge_service.update_knowledge_base = AsyncMock(return_value={})
|
||||
app.knowledge_service.delete_knowledge_base = AsyncMock()
|
||||
app.knowledge_service.get_files_by_knowledge_base = AsyncMock(return_value=[
|
||||
{'uuid': 'test-file-uuid', 'filename': 'test.pdf'}
|
||||
])
|
||||
app.knowledge_service.store_file = AsyncMock(return_value={'task_id': 'test-task-id'})
|
||||
app.knowledge_service.delete_file = AsyncMock()
|
||||
app.knowledge_service.retrieve_knowledge_base = AsyncMock(return_value=[
|
||||
{'content': 'test result', 'score': 0.95}
|
||||
])
|
||||
|
||||
# RAG manager
|
||||
app.rag_mgr = Mock()
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
async def quart_test_client(fake_knowledge_app, http_controller_cls):
|
||||
"""Create Quart test client (module scope to avoid route re-registration)."""
|
||||
controller = http_controller_cls(fake_knowledge_app)
|
||||
await controller.initialize()
|
||||
|
||||
client = controller.quart_app.test_client()
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestKnowledgeBaseEndpoints:
|
||||
"""Tests for /api/v1/knowledge/bases endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_knowledge_bases_success(self, quart_test_client):
|
||||
"""GET /api/v1/knowledge/bases returns knowledge base list."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/knowledge/bases',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
assert 'data' in data
|
||||
assert 'bases' in data['data']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_knowledge_base_success(self, quart_test_client):
|
||||
"""POST /api/v1/knowledge/bases creates new knowledge base."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/knowledge/bases',
|
||||
headers={'Authorization': 'Bearer test_token'},
|
||||
json={'name': 'New KB', 'engine_plugin_id': 'test/engine'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
assert 'uuid' in data['data']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_single_knowledge_base_success(self, quart_test_client):
|
||||
"""GET /api/v1/knowledge/bases/{uuid} returns knowledge base."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/knowledge/bases/test-kb-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
assert 'base' in data['data']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_knowledge_base_success(self, quart_test_client):
|
||||
"""PUT /api/v1/knowledge/bases/{uuid} updates knowledge base."""
|
||||
response = await quart_test_client.put(
|
||||
'/api/v1/knowledge/bases/test-kb-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'},
|
||||
json={'name': 'Updated KB'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_knowledge_base_success(self, quart_test_client):
|
||||
"""DELETE /api/v1/knowledge/bases/{uuid} deletes knowledge base."""
|
||||
response = await quart_test_client.delete(
|
||||
'/api/v1/knowledge/bases/test-kb-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestKnowledgeBaseFilesEndpoints:
|
||||
"""Tests for knowledge base files endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_files_success(self, quart_test_client):
|
||||
"""GET /api/v1/knowledge/bases/{uuid}/files returns files."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/knowledge/bases/test-kb-uuid/files',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
assert 'files' in data['data']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_file_to_knowledge_base(self, quart_test_client):
|
||||
"""POST /api/v1/knowledge/bases/{uuid}/files adds file."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/knowledge/bases/test-kb-uuid/files',
|
||||
headers={'Authorization': 'Bearer test_token'},
|
||||
json={'file_id': 'test-file-id', 'parser_plugin_id': 'test/parser'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
assert 'task_id' in data['data']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_file_from_knowledge_base(self, quart_test_client):
|
||||
"""DELETE /api/v1/knowledge/bases/{uuid}/files/{file_id}."""
|
||||
response = await quart_test_client.delete(
|
||||
'/api/v1/knowledge/bases/test-kb-uuid/files/test-file-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestKnowledgeBaseRetrieveEndpoint:
|
||||
"""Tests for knowledge base retrieval endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_knowledge_success(self, quart_test_client):
|
||||
"""POST /api/v1/knowledge/bases/{uuid}/retrieve."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/knowledge/bases/test-kb-uuid/retrieve',
|
||||
headers={'Authorization': 'Bearer test_token'},
|
||||
json={'query': 'test query', 'retrieval_settings': {'top_k': 5}}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
assert 'results' in data['data']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_without_query_returns_error(self, quart_test_client):
|
||||
"""POST retrieve without query returns 400."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/knowledge/bases/test-kb-uuid/retrieve',
|
||||
headers={'Authorization': 'Bearer test_token'},
|
||||
json={}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
data = await response.get_json()
|
||||
assert data['code'] == -1
|
||||
@@ -1,330 +0,0 @@
|
||||
"""
|
||||
API integration tests for monitoring endpoints.
|
||||
|
||||
Tests real HTTP API behavior for monitoring data retrieval.
|
||||
|
||||
Run: uv run pytest tests/integration/api/test_monitoring.py -q
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, Mock
|
||||
|
||||
from tests.factories import FakeApp
|
||||
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def mock_circular_import_chain():
|
||||
"""Break circular import chain for API controller."""
|
||||
from tests.utils.import_isolation import isolated_sys_modules, MockLifecycleControlScope
|
||||
|
||||
class FakeMinimalApplication:
|
||||
pass
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.Application = FakeMinimalApplication
|
||||
|
||||
mock_entities = MagicMock()
|
||||
mock_entities.LifecycleControlScope = MockLifecycleControlScope
|
||||
|
||||
clear = [
|
||||
'langbot.pkg.api.http.controller.group',
|
||||
'langbot.pkg.api.http.controller.groups',
|
||||
'langbot.pkg.api.http.controller.groups.monitoring',
|
||||
'langbot.pkg.api.http.controller.main',
|
||||
]
|
||||
|
||||
with isolated_sys_modules(
|
||||
mocks={
|
||||
'langbot.pkg.core.app': mock_app,
|
||||
'langbot.pkg.core.entities': mock_entities,
|
||||
},
|
||||
clear=clear,
|
||||
):
|
||||
import langbot.pkg.api.http.controller.groups.monitoring as _monitoring # noqa: E402, F401
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def fake_monitoring_app():
|
||||
"""Create FakeApp with monitoring services (module scope)."""
|
||||
app = FakeApp()
|
||||
|
||||
app.instance_config.data.update({
|
||||
'api': {'port': 5300},
|
||||
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
||||
})
|
||||
|
||||
# Auth services - USER_TOKEN auth requires jwt verification AND get_user_by_email
|
||||
app.user_service = Mock()
|
||||
app.user_service.is_initialized = AsyncMock(return_value=True)
|
||||
app.user_service.verify_jwt_token = AsyncMock(return_value='test@example.com')
|
||||
app.user_service.get_user_by_email = AsyncMock(return_value=Mock(email='test@example.com'))
|
||||
|
||||
# Monitoring service
|
||||
app.monitoring_service = Mock()
|
||||
app.monitoring_service.get_overview_metrics = AsyncMock(return_value={
|
||||
'total_messages': 100,
|
||||
'total_llm_calls': 50,
|
||||
'total_sessions': 20,
|
||||
'active_sessions': 5,
|
||||
'total_errors': 2,
|
||||
})
|
||||
app.monitoring_service.get_messages = AsyncMock(return_value=(
|
||||
[{'id': 'msg-1', 'content': 'test'}], 100
|
||||
))
|
||||
app.monitoring_service.get_llm_calls = AsyncMock(return_value=(
|
||||
[{'id': 'llm-1'}], 50
|
||||
))
|
||||
app.monitoring_service.get_embedding_calls = AsyncMock(return_value=(
|
||||
[{'id': 'emb-1'}], 10
|
||||
))
|
||||
app.monitoring_service.get_sessions = AsyncMock(return_value=(
|
||||
[{'session_id': 'sess-1'}], 20
|
||||
))
|
||||
app.monitoring_service.get_errors = AsyncMock(return_value=(
|
||||
[{'id': 'err-1'}], 2
|
||||
))
|
||||
app.monitoring_service.get_session_analysis = AsyncMock(return_value={
|
||||
'found': True,
|
||||
'session_id': 'sess-1',
|
||||
})
|
||||
app.monitoring_service.get_message_details = AsyncMock(return_value={
|
||||
'found': True,
|
||||
'message_id': 'msg-1',
|
||||
})
|
||||
app.monitoring_service.get_feedback_stats = AsyncMock(return_value={'like_count': 10})
|
||||
app.monitoring_service.get_feedback_list = AsyncMock(return_value=(
|
||||
[{'feedback_id': 'fb-1'}], 12
|
||||
))
|
||||
app.monitoring_service.export_messages = AsyncMock(return_value=[{'id': 'msg-1'}])
|
||||
app.monitoring_service.export_llm_calls = AsyncMock(return_value=[{'id': 'llm-1'}])
|
||||
app.monitoring_service.export_errors = AsyncMock(return_value=[{'id': 'err-1'}])
|
||||
app.monitoring_service.export_sessions = AsyncMock(return_value=[{'session_id': 'sess-1'}])
|
||||
app.monitoring_service.export_feedback = AsyncMock(return_value=[{'id': 'fb-1'}])
|
||||
app.monitoring_service.export_embedding_calls = AsyncMock(return_value=[{'id': 'emb-1'}])
|
||||
app.monitoring_service._escape_csv_field = Mock(return_value='escaped')
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
async def quart_test_client(fake_monitoring_app, http_controller_cls):
|
||||
"""Create Quart test client (module scope)."""
|
||||
controller = http_controller_cls(fake_monitoring_app)
|
||||
await controller.initialize()
|
||||
|
||||
client = controller.quart_app.test_client()
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestMonitoringOverviewEndpoint:
|
||||
"""Tests for /api/v1/monitoring/overview endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_overview_success(self, quart_test_client):
|
||||
"""GET /api/v1/monitoring/overview returns metrics."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/overview',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestMonitoringMessagesEndpoint:
|
||||
"""Tests for /api/v1/monitoring/messages endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_messages_success(self, quart_test_client):
|
||||
"""GET /api/v1/monitoring/messages returns message list."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/messages',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
assert 'messages' in data['data']
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestMonitoringLLMCallsEndpoint:
|
||||
"""Tests for /api/v1/monitoring/llm-calls endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_llm_calls_success(self, quart_test_client):
|
||||
"""GET /api/v1/monitoring/llm-calls."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/llm-calls',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestMonitoringEmbeddingCallsEndpoint:
|
||||
"""Tests for /api/v1/monitoring/embedding-calls endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_embedding_calls_success(self, quart_test_client):
|
||||
"""GET /api/v1/monitoring/embedding-calls."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/embedding-calls',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestMonitoringSessionsEndpoint:
|
||||
"""Tests for /api/v1/monitoring/sessions endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_sessions_success(self, quart_test_client):
|
||||
"""GET /api/v1/monitoring/sessions."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/sessions',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestMonitoringErrorsEndpoint:
|
||||
"""Tests for /api/v1/monitoring/errors endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_errors_success(self, quart_test_client):
|
||||
"""GET /api/v1/monitoring/errors."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/errors',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestMonitoringAllDataEndpoint:
|
||||
"""Tests for /api/v1/monitoring/data endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_data_success(self, quart_test_client):
|
||||
"""GET /api/v1/monitoring/data returns all data."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/data',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert 'overview' in data['data']
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestMonitoringDetailsEndpoints:
|
||||
"""Tests for detail endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session_analysis(self, quart_test_client):
|
||||
"""GET /api/v1/monitoring/sessions/{id}/analysis."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/sessions/sess-1/analysis',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_message_details(self, quart_test_client):
|
||||
"""GET /api/v1/monitoring/messages/{id}/details."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/messages/msg-1/details',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestMonitoringFeedbackEndpoints:
|
||||
"""Tests for feedback endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_feedback_stats(self, quart_test_client):
|
||||
"""GET /api/v1/monitoring/feedback/stats."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/feedback/stats',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_feedback_list(self, quart_test_client):
|
||||
"""GET /api/v1/monitoring/feedback."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/feedback',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestMonitoringExportEndpoint:
|
||||
"""Tests for /api/v1/monitoring/export endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_export_messages(self, quart_test_client):
|
||||
"""GET export?type=messages returns CSV."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/export?type=messages',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert 'text/csv' in response.content_type
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_export_llm_calls(self, quart_test_client):
|
||||
"""GET export?type=llm-calls returns CSV."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/export?type=llm-calls',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_export_sessions(self, quart_test_client):
|
||||
"""GET export?type=sessions returns CSV."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/export?type=sessions',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_export_feedback(self, quart_test_client):
|
||||
"""GET export?type=feedback returns CSV."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/export?type=feedback',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -1,273 +0,0 @@
|
||||
"""
|
||||
API integration tests for pipeline endpoints.
|
||||
|
||||
Tests real HTTP API behavior using Quart test client with mocked services.
|
||||
Extends test_smoke.py coverage for pipeline-related endpoints.
|
||||
|
||||
Run: uv run pytest tests/integration/api/test_pipelines.py -q
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, Mock
|
||||
|
||||
from tests.factories import FakeApp
|
||||
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
# ============== FIXTURE FOR SYS.MODULES ISOLATION ==============
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def mock_circular_import_chain():
|
||||
"""Break circular import chain for API controller."""
|
||||
from tests.utils.import_isolation import isolated_sys_modules, MockLifecycleControlScope
|
||||
|
||||
class FakeMinimalApplication:
|
||||
pass
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.Application = FakeMinimalApplication
|
||||
|
||||
mock_entities = MagicMock()
|
||||
mock_entities.LifecycleControlScope = MockLifecycleControlScope
|
||||
|
||||
clear = [
|
||||
'langbot.pkg.api.http.controller.group',
|
||||
'langbot.pkg.api.http.controller.groups',
|
||||
'langbot.pkg.api.http.controller.groups.pipelines',
|
||||
'langbot.pkg.api.http.controller.groups.pipelines.pipelines',
|
||||
'langbot.pkg.api.http.controller.groups.pipelines.embed',
|
||||
'langbot.pkg.api.http.controller.groups.pipelines.websocket_chat',
|
||||
'langbot.pkg.api.http.controller.main',
|
||||
]
|
||||
|
||||
with isolated_sys_modules(
|
||||
mocks={
|
||||
'langbot.pkg.core.app': mock_app,
|
||||
'langbot.pkg.core.entities': mock_entities,
|
||||
},
|
||||
clear=clear,
|
||||
):
|
||||
# Import groups after mocking to populate preregistered_groups
|
||||
import langbot.pkg.api.http.controller.groups.pipelines.pipelines as _pipelines # noqa: E402, F401
|
||||
yield
|
||||
|
||||
|
||||
# ============== FAKE APPLICATION WITH PIPELINE SERVICES ==============
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def fake_pipeline_app():
|
||||
"""Create FakeApp with pipeline-specific services (module scope for reuse)."""
|
||||
app = FakeApp()
|
||||
|
||||
# Pipeline config
|
||||
app.instance_config.data.update({
|
||||
'api': {'port': 5300},
|
||||
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
||||
})
|
||||
|
||||
# Auth services
|
||||
app.user_service = Mock()
|
||||
app.user_service.is_initialized = AsyncMock(return_value=True)
|
||||
app.user_service.verify_jwt_token = AsyncMock(return_value='test@example.com')
|
||||
app.user_service.get_user_by_email = AsyncMock(return_value=Mock(email='test@example.com'))
|
||||
app.apikey_service = Mock()
|
||||
app.apikey_service.verify_api_key = AsyncMock(return_value=True)
|
||||
|
||||
# Pipeline service
|
||||
app.pipeline_service = Mock()
|
||||
app.pipeline_service.get_pipeline_metadata = AsyncMock(return_value=[
|
||||
{'name': 'trigger', 'stages': []},
|
||||
{'name': 'ai', 'stages': []},
|
||||
])
|
||||
app.pipeline_service.get_pipelines = AsyncMock(return_value=[
|
||||
{
|
||||
'uuid': 'test-pipeline-uuid',
|
||||
'name': 'Test Pipeline',
|
||||
'description': 'Test description',
|
||||
'created_at': '2024-01-01T00:00:00',
|
||||
'updated_at': '2024-01-01T00:00:00',
|
||||
'is_default': False,
|
||||
}
|
||||
])
|
||||
app.pipeline_service.get_pipeline = AsyncMock(return_value={
|
||||
'uuid': 'test-pipeline-uuid',
|
||||
'name': 'Test Pipeline',
|
||||
'config': {},
|
||||
})
|
||||
app.pipeline_service.create_pipeline = AsyncMock(return_value={'uuid': 'new-pipeline-uuid'})
|
||||
app.pipeline_service.update_pipeline = AsyncMock(return_value={})
|
||||
app.pipeline_service.delete_pipeline = AsyncMock()
|
||||
app.pipeline_service.copy_pipeline = AsyncMock(return_value={'uuid': 'copied-pipeline-uuid'})
|
||||
|
||||
# Bot service
|
||||
app.bot_service = Mock()
|
||||
app.bot_service.get_bots = AsyncMock(return_value=[])
|
||||
app.bot_service.create_bot = AsyncMock(return_value={'uuid': 'new-bot-uuid'})
|
||||
|
||||
# MCP service (for extensions endpoint)
|
||||
app.mcp_service = Mock()
|
||||
app.mcp_service.get_mcp_servers = AsyncMock(return_value=[])
|
||||
|
||||
# Plugin connector (for extensions endpoint)
|
||||
app.plugin_connector.list_plugins = AsyncMock(return_value=[])
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
async def quart_test_client(fake_pipeline_app, http_controller_cls):
|
||||
"""Create Quart test client (module scope to avoid route re-registration)."""
|
||||
controller = http_controller_cls(fake_pipeline_app)
|
||||
await controller.initialize()
|
||||
|
||||
client = controller.quart_app.test_client()
|
||||
yield client
|
||||
|
||||
|
||||
# ============== PIPELINE ENDPOINT TESTS ==============
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestPipelineMetadataEndpoint:
|
||||
"""Tests for /api/v1/pipelines/_/metadata endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pipeline_metadata_success(self, quart_test_client):
|
||||
"""GET /api/v1/pipelines/_/metadata returns metadata list."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/pipelines/_/metadata',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
assert 'data' in data
|
||||
assert isinstance(data['data'], dict)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pipeline_metadata_requires_auth(self, quart_test_client):
|
||||
"""Pipeline metadata endpoint requires authentication."""
|
||||
response = await quart_test_client.get('/api/v1/pipelines/_/metadata')
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestPipelinesListEndpoint:
|
||||
"""Tests for /api/v1/pipelines endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pipelines_success(self, quart_test_client):
|
||||
"""GET /api/v1/pipelines returns pipeline list."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/pipelines',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
assert 'data' in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pipelines_with_sort_param(self, quart_test_client):
|
||||
"""GET pipelines with sort parameter."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/pipelines?sort_by=created_at&sort_order=DESC',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestPipelinesCRUDEndpoints:
|
||||
"""Tests for pipeline CRUD operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_single_pipeline_success(self, quart_test_client):
|
||||
"""GET /api/v1/pipelines/{uuid} returns pipeline."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/pipelines/test-pipeline-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
assert 'data' in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_pipeline_success(self, quart_test_client):
|
||||
"""POST /api/v1/pipelines creates new pipeline."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/pipelines',
|
||||
headers={'Authorization': 'Bearer test_token'},
|
||||
json={'name': 'New Pipeline', 'config': {}}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
assert 'uuid' in data['data']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_pipeline_success(self, quart_test_client):
|
||||
"""PUT /api/v1/pipelines/{uuid} updates pipeline."""
|
||||
response = await quart_test_client.put(
|
||||
'/api/v1/pipelines/test-pipeline-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'},
|
||||
json={'name': 'Updated Pipeline'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_pipeline_success(self, quart_test_client):
|
||||
"""DELETE /api/v1/pipelines/{uuid} deletes pipeline."""
|
||||
response = await quart_test_client.delete(
|
||||
'/api/v1/pipelines/test-pipeline-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_copy_pipeline_success(self, quart_test_client):
|
||||
"""POST /api/v1/pipelines/{uuid}/copy copies pipeline."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/pipelines/test-pipeline-uuid/copy',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
assert 'uuid' in data['data']
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestPipelineExtensionsEndpoint:
|
||||
"""Tests for pipeline extensions."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_extensions(self, quart_test_client):
|
||||
"""GET /api/v1/pipelines/{uuid}/extensions."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/pipelines/test-pipeline-uuid/extensions',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
# Should return 200 if pipeline found
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
@@ -1,347 +0,0 @@
|
||||
"""
|
||||
API integration tests for provider/model endpoints.
|
||||
|
||||
Tests real HTTP API behavior for provider and model management.
|
||||
|
||||
Run: uv run pytest tests/integration/api/test_providers.py -q
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, Mock
|
||||
|
||||
from tests.factories import FakeApp
|
||||
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def mock_circular_import_chain():
|
||||
"""Break circular import chain for API controller."""
|
||||
from tests.utils.import_isolation import isolated_sys_modules, MockLifecycleControlScope
|
||||
|
||||
class FakeMinimalApplication:
|
||||
pass
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.Application = FakeMinimalApplication
|
||||
|
||||
mock_entities = MagicMock()
|
||||
mock_entities.LifecycleControlScope = MockLifecycleControlScope
|
||||
|
||||
clear = [
|
||||
'langbot.pkg.api.http.controller.group',
|
||||
'langbot.pkg.api.http.controller.groups',
|
||||
'langbot.pkg.api.http.controller.groups.provider',
|
||||
'langbot.pkg.api.http.controller.groups.provider.providers',
|
||||
'langbot.pkg.api.http.controller.groups.provider.models',
|
||||
'langbot.pkg.api.http.controller.main',
|
||||
]
|
||||
|
||||
with isolated_sys_modules(
|
||||
mocks={
|
||||
'langbot.pkg.core.app': mock_app,
|
||||
'langbot.pkg.core.entities': mock_entities,
|
||||
},
|
||||
clear=clear,
|
||||
):
|
||||
import langbot.pkg.api.http.controller.groups.provider.providers as _providers # noqa: E402, F401
|
||||
import langbot.pkg.api.http.controller.groups.provider.models as _models # noqa: E402, F401
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def fake_provider_app():
|
||||
"""Create FakeApp with provider/model services (module scope for reuse)."""
|
||||
app = FakeApp()
|
||||
|
||||
app.instance_config.data.update({
|
||||
'api': {'port': 5300},
|
||||
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
||||
})
|
||||
|
||||
# Auth services
|
||||
app.user_service = Mock()
|
||||
app.user_service.is_initialized = AsyncMock(return_value=True)
|
||||
app.user_service.verify_jwt_token = AsyncMock(return_value='test@example.com')
|
||||
app.user_service.get_user_by_email = AsyncMock(return_value=Mock(email='test@example.com'))
|
||||
app.apikey_service = Mock()
|
||||
app.apikey_service.verify_api_key = AsyncMock(return_value=True)
|
||||
|
||||
# Provider service
|
||||
app.provider_service = Mock()
|
||||
app.provider_service.get_providers = AsyncMock(return_value=[
|
||||
{'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'}
|
||||
])
|
||||
app.provider_service.get_provider = AsyncMock(return_value={
|
||||
'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'
|
||||
})
|
||||
app.provider_service.create_provider = AsyncMock(return_value='new-provider-uuid')
|
||||
app.provider_service.update_provider = AsyncMock(return_value={})
|
||||
app.provider_service.delete_provider = AsyncMock()
|
||||
app.provider_service.get_provider_model_counts = AsyncMock(return_value={
|
||||
'llm_count': 2, 'embedding_count': 1, 'rerank_count': 0
|
||||
})
|
||||
|
||||
# LLM model service
|
||||
app.llm_model_service = Mock()
|
||||
app.llm_model_service.get_llm_models = AsyncMock(return_value=[
|
||||
{'uuid': 'test-model-uuid', 'name': 'gpt-4'}
|
||||
])
|
||||
app.llm_model_service.get_llm_model = AsyncMock(return_value={
|
||||
'uuid': 'test-model-uuid', 'name': 'gpt-4'
|
||||
})
|
||||
app.llm_model_service.create_llm_model = AsyncMock(return_value={'uuid': 'new-model-uuid'})
|
||||
app.llm_model_service.update_llm_model = AsyncMock(return_value={})
|
||||
app.llm_model_service.delete_llm_model = AsyncMock()
|
||||
|
||||
# Embedding model service
|
||||
app.embedding_models_service = Mock()
|
||||
app.embedding_models_service.get_embedding_models = AsyncMock(return_value=[])
|
||||
app.embedding_models_service.create_embedding_model = AsyncMock(return_value={'uuid': 'new-embedding-uuid'})
|
||||
|
||||
# Rerank model service
|
||||
app.rerank_models_service = Mock()
|
||||
app.rerank_models_service.get_rerank_models = AsyncMock(return_value=[])
|
||||
app.rerank_models_service.create_rerank_model = AsyncMock(return_value={'uuid': 'new-rerank-uuid'})
|
||||
|
||||
# Model manager
|
||||
app.model_mgr = Mock()
|
||||
app.model_mgr.load_provider = AsyncMock()
|
||||
app.model_mgr.unload_provider = AsyncMock()
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
async def quart_test_client(fake_provider_app, http_controller_cls):
|
||||
"""Create Quart test client (module scope to avoid route re-registration)."""
|
||||
controller = http_controller_cls(fake_provider_app)
|
||||
await controller.initialize()
|
||||
|
||||
client = controller.quart_app.test_client()
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestProviderEndpoints:
|
||||
"""Tests for /api/v1/provider endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_providers_success(self, quart_test_client):
|
||||
"""GET /api/v1/provider/providers returns provider list with complete structure."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/provider/providers',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
assert 'data' in data
|
||||
# Verify response structure completeness
|
||||
providers = data['data']['providers']
|
||||
assert isinstance(providers, list)
|
||||
assert len(providers) == 1
|
||||
# Verify required fields in provider object
|
||||
provider = providers[0]
|
||||
assert 'uuid' in provider
|
||||
assert 'name' in provider
|
||||
assert 'requester' in provider
|
||||
assert provider['uuid'] == 'test-provider-uuid'
|
||||
assert provider['name'] == 'OpenAI'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_single_provider_success(self, quart_test_client):
|
||||
"""GET /api/v1/provider/providers/{uuid} returns complete provider structure."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/provider/providers/test-provider-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
# Verify response structure
|
||||
provider = data['data']['provider']
|
||||
assert 'uuid' in provider
|
||||
assert 'name' in provider
|
||||
assert 'requester' in provider
|
||||
assert provider['uuid'] == 'test-provider-uuid'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_provider_success(self, quart_test_client):
|
||||
"""POST /api/v1/provider/providers creates new provider with uuid returned."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/provider/providers',
|
||||
headers={'Authorization': 'Bearer test_token'},
|
||||
json={'name': 'New Provider', 'requester': 'chatcmpl'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
# Verify uuid is present and matches expected
|
||||
assert 'data' in data
|
||||
assert 'uuid' in data['data']
|
||||
assert data['data']['uuid'] == 'new-provider-uuid'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_provider_success(self, quart_test_client):
|
||||
"""PUT /api/v1/provider/providers/{uuid} updates provider."""
|
||||
response = await quart_test_client.put(
|
||||
'/api/v1/provider/providers/test-provider-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'},
|
||||
json={'name': 'Updated Provider'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_provider_success(self, quart_test_client):
|
||||
"""DELETE /api/v1/provider/providers/{uuid} deletes provider."""
|
||||
response = await quart_test_client.delete(
|
||||
'/api/v1/provider/providers/test-provider-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_provider_includes_model_counts(self, quart_test_client):
|
||||
"""GET provider response includes model counts."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/provider/providers/test-provider-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
# Model counts are embedded in provider response
|
||||
provider_data = data['data']['provider']
|
||||
assert 'llm_count' in provider_data
|
||||
assert 'embedding_count' in provider_data
|
||||
assert 'rerank_count' in provider_data
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestModelEndpoints:
|
||||
"""Tests for /api/v1/provider/models endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_llm_models_success(self, quart_test_client):
|
||||
"""GET /api/v1/provider/models/llm returns model list."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/provider/models/llm',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
assert 'data' in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_single_llm_model_success(self, quart_test_client):
|
||||
"""GET /api/v1/provider/models/llm/{uuid} returns model."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/provider/models/llm/test-model-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_llm_model_success(self, quart_test_client):
|
||||
"""POST /api/v1/provider/models/llm creates new model."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/provider/models/llm',
|
||||
headers={'Authorization': 'Bearer test_token'},
|
||||
json={'name': 'New Model', 'provider_uuid': 'test-provider-uuid'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
assert 'uuid' in data['data']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_llm_model_success(self, quart_test_client):
|
||||
"""DELETE /api/v1/provider/models/llm/{uuid} deletes model."""
|
||||
response = await quart_test_client.delete(
|
||||
'/api/v1/provider/models/llm/test-model-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestEmbeddingModelEndpoints:
|
||||
"""Tests for /api/v1/provider/models/embedding endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_embedding_models_success(self, quart_test_client):
|
||||
"""GET /api/v1/provider/models/embedding returns model list."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/provider/models/embedding',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
assert 'models' in data['data']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_embedding_model_success(self, quart_test_client):
|
||||
"""POST /api/v1/provider/models/embedding creates new model."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/provider/models/embedding',
|
||||
headers={'Authorization': 'Bearer test_token'},
|
||||
json={'name': 'New Embedding Model', 'provider_uuid': 'test-provider-uuid'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
assert 'uuid' in data['data']
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestRerankModelEndpoints:
|
||||
"""Tests for /api/v1/provider/models/rerank endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_rerank_models_success(self, quart_test_client):
|
||||
"""GET /api/v1/provider/models/rerank returns model list."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/provider/models/rerank',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
assert 'models' in data['data']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_rerank_model_success(self, quart_test_client):
|
||||
"""POST /api/v1/provider/models/rerank creates new model."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/provider/models/rerank',
|
||||
headers={'Authorization': 'Bearer test_token'},
|
||||
json={'name': 'New Rerank Model', 'provider_uuid': 'test-provider-uuid'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['code'] == 0
|
||||
assert 'uuid' in data['data']
|
||||
@@ -1,345 +0,0 @@
|
||||
"""
|
||||
API smoke integration tests.
|
||||
|
||||
Tests real HTTP API behavior using Quart test client.
|
||||
Validates controller/service/routing wiring without real provider/platform.
|
||||
|
||||
Run: uv run pytest tests/integration/api/test_smoke.py -q
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, Mock
|
||||
|
||||
from tests.factories import FakeApp
|
||||
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
# ============== FIXTURE FOR SYS.MODULES ISOLATION ==============
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def mock_circular_import_chain():
|
||||
"""
|
||||
Break circular import chain for API controller using isolated_sys_modules.
|
||||
|
||||
Chain: http_controller → groups/plugins → core.app → pipeline entities
|
||||
|
||||
We need to mock core.app to prevent the circular chain when importing HTTPController.
|
||||
But we must allow groups to be imported to populate preregistered_groups.
|
||||
"""
|
||||
from tests.utils.import_isolation import isolated_sys_modules, MockLifecycleControlScope
|
||||
|
||||
# Mock core.app with minimal Application that groups can reference
|
||||
class FakeMinimalApplication:
|
||||
pass
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.Application = FakeMinimalApplication
|
||||
|
||||
# Mock core.entities with proper Enum
|
||||
mock_entities = MagicMock()
|
||||
mock_entities.LifecycleControlScope = MockLifecycleControlScope
|
||||
|
||||
# Modules to clear (force re-import after mocking)
|
||||
clear = [
|
||||
'langbot.pkg.api.http.controller.group',
|
||||
'langbot.pkg.api.http.controller.groups',
|
||||
'langbot.pkg.api.http.controller.groups.system',
|
||||
'langbot.pkg.api.http.controller.groups.user',
|
||||
'langbot.pkg.api.http.controller.main',
|
||||
]
|
||||
|
||||
with isolated_sys_modules(
|
||||
mocks={
|
||||
'langbot.pkg.core.app': mock_app,
|
||||
'langbot.pkg.core.entities': mock_entities,
|
||||
},
|
||||
clear=clear,
|
||||
):
|
||||
# Import groups after mocking core.app/core.entities
|
||||
import langbot.pkg.api.http.controller.group as _group_module # noqa: E402, F401
|
||||
import langbot.pkg.api.http.controller.groups.system as _system_group # noqa: E402, F401
|
||||
import langbot.pkg.api.http.controller.groups.user as _user_group # noqa: E402, F401
|
||||
|
||||
yield
|
||||
|
||||
|
||||
# ============== FAKE APPLICATION FOR API TESTS ==============
|
||||
|
||||
@pytest.fixture
|
||||
def fake_api_app():
|
||||
"""
|
||||
Create minimal FakeApp for API smoke tests with all required services.
|
||||
|
||||
Uses tests.factories.FakeApp as base and adds API-specific services.
|
||||
"""
|
||||
app = FakeApp()
|
||||
|
||||
# API-specific config
|
||||
app.instance_config.data.update({
|
||||
'api': {'port': 5300},
|
||||
'plugin': {'enable_marketplace': True},
|
||||
'space': {'url': 'https://space.langbot.app'},
|
||||
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
||||
})
|
||||
|
||||
# API-specific services
|
||||
app.user_service = Mock()
|
||||
app.user_service.is_initialized = AsyncMock(return_value=False)
|
||||
app.user_service.authenticate = AsyncMock(return_value='fake_token')
|
||||
app.user_service.create_user = AsyncMock()
|
||||
app.user_service.verify_jwt_token = AsyncMock(side_effect=ValueError('Invalid token'))
|
||||
app.user_service.get_user_by_email = AsyncMock(return_value=Mock())
|
||||
app.user_service.generate_jwt_token = AsyncMock(return_value='fake_token')
|
||||
|
||||
app.apikey_service = Mock()
|
||||
app.apikey_service.verify_api_key = AsyncMock(return_value=True)
|
||||
|
||||
app.maintenance_service = Mock()
|
||||
app.maintenance_service.get_storage_analysis = AsyncMock(return_value={})
|
||||
|
||||
app.plugin_connector.is_enable_plugin = False
|
||||
app.plugin_connector.ping_plugin_runtime = AsyncMock()
|
||||
|
||||
app.task_mgr.get_tasks_dict = Mock(return_value={'tasks': []})
|
||||
app.task_mgr.get_task_by_id = Mock(return_value=None)
|
||||
|
||||
# Required by controller groups
|
||||
app.model_mgr = Mock()
|
||||
app.platform_mgr = Mock()
|
||||
app.pipeline_pool = Mock()
|
||||
app.pipeline_mgr = Mock()
|
||||
|
||||
return app
|
||||
|
||||
|
||||
# ============== QUART TEST CLIENT FIXTURE ==============
|
||||
|
||||
@pytest.fixture
|
||||
async def quart_test_client(fake_api_app, http_controller_cls):
|
||||
"""
|
||||
Create Quart test client with real HTTPController and route registration.
|
||||
|
||||
Requires mock_circular_import_chain fixture to run first (usefixtures).
|
||||
"""
|
||||
controller = http_controller_cls(fake_api_app)
|
||||
await controller.initialize()
|
||||
|
||||
client = controller.quart_app.test_client()
|
||||
|
||||
yield client
|
||||
|
||||
|
||||
# ============== API SMOKE TESTS ==============
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestHealthEndpoint:
|
||||
"""Tests for /healthz endpoint - simplest smoke test."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_healthz_returns_ok(self, quart_test_client):
|
||||
"""
|
||||
/healthz endpoint returns {'code': 0, 'msg': 'ok'}.
|
||||
|
||||
This tests:
|
||||
- HTTPController instantiation
|
||||
- Quart app creation
|
||||
- Route registration
|
||||
- Basic response handling
|
||||
"""
|
||||
response = await quart_test_client.get('/healthz')
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data == {'code': 0, 'msg': 'ok'}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_healthz_no_auth_required(self, quart_test_client):
|
||||
"""
|
||||
/healthz doesn't require authentication.
|
||||
|
||||
Tests that AuthType.NONE endpoints work without headers.
|
||||
"""
|
||||
response = await quart_test_client.get('/healthz')
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestSystemEndpoint:
|
||||
"""Tests for /api/v1/system endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_system_info_no_auth(self, quart_test_client):
|
||||
"""
|
||||
/api/v1/system/info returns system information without auth.
|
||||
|
||||
AuthType.NONE endpoint.
|
||||
"""
|
||||
response = await quart_test_client.get('/api/v1/system/info')
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
|
||||
# Verify response structure
|
||||
assert data['code'] == 0
|
||||
assert data['msg'] == 'ok'
|
||||
assert 'data' in data
|
||||
|
||||
# Verify expected fields
|
||||
system_data = data['data']
|
||||
assert 'version' in system_data
|
||||
assert 'debug' in system_data
|
||||
assert 'edition' in system_data
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestProtectedEndpoints:
|
||||
"""Tests for authentication/authorization behavior."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_protected_endpoint_rejects_no_token(self, quart_test_client):
|
||||
"""
|
||||
Protected endpoint (USER_TOKEN) returns 401 without auth.
|
||||
|
||||
Tests that AuthType.USER_TOKEN properly rejects unauthorized requests.
|
||||
"""
|
||||
# /api/v1/user/check-token requires USER_TOKEN
|
||||
response = await quart_test_client.get('/api/v1/user/check-token')
|
||||
|
||||
assert response.status_code == 401
|
||||
data = await response.get_json()
|
||||
|
||||
# Verify error response structure
|
||||
assert data['code'] == -1
|
||||
assert 'msg' in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_protected_endpoint_with_invalid_token(self, quart_test_client):
|
||||
"""
|
||||
Protected endpoint returns 401 with invalid token.
|
||||
"""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/user/check-token',
|
||||
headers={'Authorization': 'Bearer invalid_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestInvalidPayload:
|
||||
"""Tests for error handling with invalid payloads."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_json_body(self, quart_test_client):
|
||||
"""
|
||||
POST endpoint without JSON body handles gracefully.
|
||||
"""
|
||||
# /api/v1/user/auth expects JSON with 'user' and 'password'
|
||||
response = await quart_test_client.post('/api/v1/user/auth')
|
||||
|
||||
# Should return error (500, 400, or 401) with stable JSON structure
|
||||
assert response.status_code in (400, 500, 401)
|
||||
data = await response.get_json()
|
||||
|
||||
# Verify error response has expected structure
|
||||
assert 'code' in data
|
||||
assert 'msg' in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_json_structure(self, quart_test_client):
|
||||
"""
|
||||
POST with wrong JSON structure returns stable error.
|
||||
"""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/user/auth',
|
||||
json={'wrong_field': 'value'}
|
||||
)
|
||||
|
||||
# Should return error with stable JSON structure
|
||||
assert response.status_code in (400, 500, 401)
|
||||
data = await response.get_json()
|
||||
assert 'code' in data
|
||||
assert 'msg' in data
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestUserInitEndpoint:
|
||||
"""Tests for /api/v1/user/init endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_init_get_returns_not_initialized(self, quart_test_client):
|
||||
"""
|
||||
GET /api/v1/user/init returns initialized status.
|
||||
|
||||
Uses fake user_service.is_initialized() = False.
|
||||
"""
|
||||
response = await quart_test_client.get('/api/v1/user/init')
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
|
||||
assert data['code'] == 0
|
||||
assert data['msg'] == 'ok'
|
||||
assert data['data']['initialized'] is False
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestRealImports:
|
||||
"""Tests that verify real production code is imported."""
|
||||
|
||||
def test_http_controller_real_import(self):
|
||||
"""
|
||||
Verify HTTPController is real production class, not mock.
|
||||
"""
|
||||
from langbot.pkg.api.http.controller.main import HTTPController
|
||||
|
||||
assert HTTPController.__name__ == 'HTTPController'
|
||||
assert hasattr(HTTPController, 'initialize')
|
||||
assert hasattr(HTTPController, 'register_routes')
|
||||
|
||||
def test_group_real_import(self):
|
||||
"""
|
||||
Verify RouterGroup and AuthType are real production classes.
|
||||
"""
|
||||
from langbot.pkg.api.http.controller.group import RouterGroup, AuthType, preregistered_groups
|
||||
|
||||
assert RouterGroup.__name__ == 'RouterGroup'
|
||||
assert hasattr(AuthType, 'NONE')
|
||||
assert hasattr(AuthType, 'USER_TOKEN')
|
||||
assert isinstance(preregistered_groups, list)
|
||||
|
||||
def test_system_group_registered(self):
|
||||
"""
|
||||
Verify SystemRouterGroup is registered in preregistered_groups.
|
||||
"""
|
||||
from langbot.pkg.api.http.controller.group import preregistered_groups
|
||||
|
||||
# Find system group
|
||||
system_group = None
|
||||
for g in preregistered_groups:
|
||||
if g.name == 'system':
|
||||
system_group = g
|
||||
break
|
||||
|
||||
assert system_group is not None
|
||||
assert system_group.path == '/api/v1/system'
|
||||
|
||||
def test_user_group_registered(self):
|
||||
"""
|
||||
Verify UserRouterGroup is registered in preregistered_groups.
|
||||
"""
|
||||
from langbot.pkg.api.http.controller.group import preregistered_groups
|
||||
|
||||
# Find user group
|
||||
user_group = None
|
||||
for g in preregistered_groups:
|
||||
if g.name == 'user':
|
||||
user_group = g
|
||||
break
|
||||
|
||||
assert user_group is not None
|
||||
assert user_group.path == '/api/v1/user'
|
||||
@@ -1,5 +0,0 @@
|
||||
"""
|
||||
Persistence integration tests package.
|
||||
|
||||
Tests for database migrations and storage behavior.
|
||||
"""
|
||||
@@ -1,251 +0,0 @@
|
||||
"""
|
||||
SQLite migration integration tests.
|
||||
|
||||
Tests real Alembic migration behavior using temporary SQLite databases.
|
||||
Validates the migration workflow from .github/workflows/test-migrations.yml.
|
||||
|
||||
Run: uv run pytest tests/integration/persistence/test_migrations.py -q
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
from langbot.pkg.entity.persistence.base import Base
|
||||
from langbot.pkg.persistence.alembic_runner import (
|
||||
run_alembic_upgrade,
|
||||
run_alembic_stamp,
|
||||
get_alembic_current,
|
||||
)
|
||||
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sqlite_db_url(tmp_path):
|
||||
"""Create SQLite URL with temporary database file."""
|
||||
db_file = tmp_path / "test_migrations.db"
|
||||
return f"sqlite+aiosqlite:///{db_file}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def sqlite_engine(sqlite_db_url):
|
||||
"""Create async SQLite engine."""
|
||||
engine = create_async_engine(sqlite_db_url)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
class TestSQLiteMigrationBaseline:
|
||||
"""Tests for baseline stamp workflow."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_baseline_stamp_sets_revision(self, sqlite_engine):
|
||||
"""
|
||||
Stamp baseline on existing tables sets correct revision.
|
||||
|
||||
Workflow:
|
||||
1. Create tables via Base.metadata.create_all
|
||||
2. Stamp with '0001_baseline'
|
||||
3. Verify current revision is '0001_baseline'
|
||||
"""
|
||||
# Create all tables (simulates existing DB created by ORM)
|
||||
async with sqlite_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
# Stamp baseline
|
||||
await run_alembic_stamp(sqlite_engine, '0001_baseline')
|
||||
|
||||
# Verify revision
|
||||
rev = await get_alembic_current(sqlite_engine)
|
||||
assert rev == '0001_baseline', f"Expected '0001_baseline', got {rev}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_baseline_stamp_on_empty_db(self, sqlite_engine):
|
||||
"""
|
||||
Stamp on empty database (no tables) still sets revision.
|
||||
|
||||
This is an edge case - stamping without tables.
|
||||
"""
|
||||
# Don't create tables - stamp directly
|
||||
await run_alembic_stamp(sqlite_engine, '0001_baseline')
|
||||
|
||||
rev = await get_alembic_current(sqlite_engine)
|
||||
assert rev == '0001_baseline'
|
||||
|
||||
|
||||
class TestSQLiteMigrationUpgrade:
|
||||
"""Tests for upgrade to head workflow."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upgrade_from_baseline_to_head(self, sqlite_engine):
|
||||
"""
|
||||
Upgrade from baseline to head applies all migrations.
|
||||
|
||||
Workflow:
|
||||
1. Create tables
|
||||
2. Stamp baseline
|
||||
3. Upgrade to head
|
||||
4. Verify current revision is head
|
||||
"""
|
||||
# Create tables
|
||||
async with sqlite_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
# Stamp baseline
|
||||
await run_alembic_stamp(sqlite_engine, '0001_baseline')
|
||||
|
||||
# Upgrade to head
|
||||
await run_alembic_upgrade(sqlite_engine, 'head')
|
||||
|
||||
# Verify revision
|
||||
rev = await get_alembic_current(sqlite_engine)
|
||||
assert rev is not None, "Expected a revision after upgrade"
|
||||
# Head should be the latest migration
|
||||
assert rev.startswith('0003'), f"Expected head to be 0003_*, got {rev}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upgrade_idempotent(self, sqlite_engine):
|
||||
"""
|
||||
Running upgrade to head multiple times is idempotent.
|
||||
|
||||
Workflow:
|
||||
1. Upgrade to head
|
||||
2. Get revision
|
||||
3. Upgrade to head again
|
||||
4. Verify same revision
|
||||
"""
|
||||
# Create tables
|
||||
async with sqlite_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
# Stamp and upgrade
|
||||
await run_alembic_stamp(sqlite_engine, '0001_baseline')
|
||||
await run_alembic_upgrade(sqlite_engine, 'head')
|
||||
|
||||
rev1 = await get_alembic_current(sqlite_engine)
|
||||
|
||||
# Upgrade again - should be idempotent
|
||||
await run_alembic_upgrade(sqlite_engine, 'head')
|
||||
|
||||
rev2 = await get_alembic_current(sqlite_engine)
|
||||
assert rev2 == rev1, f"Expected {rev1}, got {rev2}"
|
||||
|
||||
|
||||
class TestSQLiteMigrationFreshDatabase:
|
||||
"""Tests for fresh database workflow."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fresh_db_upgrade_from_scratch(self, tmp_path):
|
||||
"""
|
||||
Fresh database (no tables) can be upgraded directly to head.
|
||||
|
||||
Workflow:
|
||||
1. Create fresh engine with new DB file
|
||||
2. Create tables
|
||||
3. Upgrade to head
|
||||
4. Verify revision
|
||||
"""
|
||||
# Use different DB file for fresh test
|
||||
fresh_db_file = tmp_path / "test_migrations_fresh.db"
|
||||
fresh_url = f"sqlite+aiosqlite:///{fresh_db_file}"
|
||||
fresh_engine = create_async_engine(fresh_url)
|
||||
|
||||
# Create tables on fresh DB
|
||||
async with fresh_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
# Upgrade to head directly (no baseline stamp)
|
||||
await run_alembic_upgrade(fresh_engine, 'head')
|
||||
|
||||
# Verify revision
|
||||
rev = await get_alembic_current(fresh_engine)
|
||||
assert rev is not None, "Expected a revision on fresh DB"
|
||||
|
||||
await fresh_engine.dispose()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fresh_db_without_create_all_behavior(self, tmp_path):
|
||||
"""
|
||||
Fresh database without create_all - test actual behavior.
|
||||
|
||||
This tests what happens when migrations run on truly empty DB.
|
||||
The behavior is determined by Alembic and migration scripts.
|
||||
|
||||
EXPECTED: Either:
|
||||
1. Migration succeeds (if scripts handle empty DB)
|
||||
2. Migration fails with specific error (if scripts require tables)
|
||||
|
||||
IMPORTANT: This test verifies the ACTUAL behavior, not accepting
|
||||
any arbitrary failure with try-except pass.
|
||||
"""
|
||||
fresh_db_file = tmp_path / "test_empty_migrations.db"
|
||||
fresh_url = f"sqlite+aiosqlite:///{fresh_db_file}"
|
||||
fresh_engine = create_async_engine(fresh_url)
|
||||
|
||||
# Capture the actual behavior
|
||||
actual_result = None
|
||||
actual_error = None
|
||||
|
||||
try:
|
||||
await run_alembic_upgrade(fresh_engine, 'head')
|
||||
rev = await get_alembic_current(fresh_engine)
|
||||
actual_result = rev
|
||||
except Exception as e:
|
||||
actual_error = e
|
||||
|
||||
await fresh_engine.dispose()
|
||||
|
||||
# Verify specific behavior - one of two outcomes is expected
|
||||
if actual_result is not None:
|
||||
# Migration succeeded - verify revision exists
|
||||
assert actual_result is not None, "Revision should exist after successful migration"
|
||||
else:
|
||||
# Migration failed - verify the error type is known
|
||||
# Alembic typically raises specific errors for missing tables
|
||||
assert actual_error is not None, "Error should be captured if migration failed"
|
||||
# Log the error type for documentation (don't silently pass)
|
||||
error_type = type(actual_error).__name__
|
||||
# Acceptable error types for empty DB scenarios
|
||||
acceptable_errors = [
|
||||
'OperationalError', # SQLite table not found
|
||||
'ProgrammingError', # SQLAlchemy errors
|
||||
'CommandError', # Alembic command errors
|
||||
]
|
||||
assert error_type in acceptable_errors, (
|
||||
f"Unexpected error type: {error_type}. "
|
||||
f"This may indicate a regression in migration behavior. "
|
||||
f"Error: {actual_error}"
|
||||
)
|
||||
|
||||
|
||||
class TestSQLiteMigrationGetCurrent:
|
||||
"""Tests for get_alembic_current behavior."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_on_unstamped_db_returns_none(self, sqlite_engine):
|
||||
"""
|
||||
get_alembic_current returns None for unstamped database.
|
||||
"""
|
||||
# Create tables but don't stamp
|
||||
async with sqlite_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
# No stamp - should return None
|
||||
rev = await get_alembic_current(sqlite_engine)
|
||||
assert rev is None, f"Expected None for unstamped DB, got {rev}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_after_stamp_returns_revision(self, sqlite_engine):
|
||||
"""
|
||||
get_alembic_current returns correct revision after stamp.
|
||||
"""
|
||||
async with sqlite_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
await run_alembic_stamp(sqlite_engine, '0001_baseline')
|
||||
|
||||
rev = await get_alembic_current(sqlite_engine)
|
||||
assert rev == '0001_baseline'
|
||||
@@ -1,217 +0,0 @@
|
||||
"""
|
||||
PostgreSQL migration integration tests.
|
||||
|
||||
Tests real Alembic migration behavior using PostgreSQL database.
|
||||
Marked as slow - requires external PostgreSQL service.
|
||||
|
||||
Run locally (requires PostgreSQL):
|
||||
TEST_POSTGRES_URL=postgresql+asyncpg://user:pass@localhost:5432/test_db \
|
||||
uv run pytest tests/integration/persistence/test_migrations_postgres.py -q
|
||||
|
||||
CI runs automatically with PostgreSQL service container.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy import text
|
||||
|
||||
from langbot.pkg.entity.persistence.base import Base
|
||||
from langbot.pkg.persistence.alembic_runner import (
|
||||
run_alembic_upgrade,
|
||||
run_alembic_stamp,
|
||||
get_alembic_current,
|
||||
)
|
||||
|
||||
|
||||
pytestmark = [pytest.mark.integration, pytest.mark.slow]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def postgres_url():
|
||||
"""Get PostgreSQL URL from environment."""
|
||||
url = os.environ.get('TEST_POSTGRES_URL')
|
||||
if not url:
|
||||
pytest.skip("TEST_POSTGRES_URL not set")
|
||||
return url
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def postgres_engine(postgres_url):
|
||||
"""Create async PostgreSQL engine."""
|
||||
engine = create_async_engine(postgres_url, isolation_level="AUTOCOMMIT")
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def clean_tables(postgres_engine):
|
||||
"""Drop all tables before and after each test for isolation."""
|
||||
# Drop all tables before test
|
||||
async with postgres_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
|
||||
yield
|
||||
|
||||
# Drop all tables after test
|
||||
async with postgres_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def clean_alembic_version(postgres_engine):
|
||||
"""Drop alembic_version table before and after each test."""
|
||||
async with postgres_engine.begin() as conn:
|
||||
# Drop alembic_version table if exists
|
||||
try:
|
||||
await conn.execute(text("DROP TABLE IF EXISTS alembic_version"))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
yield
|
||||
|
||||
async with postgres_engine.begin() as conn:
|
||||
try:
|
||||
await conn.execute(text("DROP TABLE IF EXISTS alembic_version"))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class TestPostgreSQLMigrationBaseline:
|
||||
"""Tests for baseline stamp workflow on PostgreSQL."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_postgres_baseline_stamp_sets_revision(
|
||||
self, postgres_engine, clean_tables, clean_alembic_version
|
||||
):
|
||||
"""
|
||||
Stamp baseline on existing tables sets correct revision.
|
||||
|
||||
Workflow:
|
||||
1. Create tables via Base.metadata.create_all
|
||||
2. Stamp with '0001_baseline'
|
||||
3. Verify current revision is '0001_baseline'
|
||||
"""
|
||||
# Create all tables (simulates existing DB created by ORM)
|
||||
async with postgres_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
# Stamp baseline
|
||||
await run_alembic_stamp(postgres_engine, '0001_baseline')
|
||||
|
||||
# Verify revision
|
||||
rev = await get_alembic_current(postgres_engine)
|
||||
assert rev == '0001_baseline', f"Expected '0001_baseline', got {rev}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_postgres_baseline_stamp_on_empty_db(
|
||||
self, postgres_engine, clean_tables, clean_alembic_version
|
||||
):
|
||||
"""
|
||||
Stamp on empty database (no tables) still sets revision.
|
||||
|
||||
This is an edge case - stamping without tables.
|
||||
"""
|
||||
# Don't create tables - stamp directly
|
||||
await run_alembic_stamp(postgres_engine, '0001_baseline')
|
||||
|
||||
rev = await get_alembic_current(postgres_engine)
|
||||
assert rev == '0001_baseline'
|
||||
|
||||
|
||||
class TestPostgreSQLMigrationUpgrade:
|
||||
"""Tests for upgrade to head workflow on PostgreSQL."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_postgres_upgrade_from_baseline_to_head(
|
||||
self, postgres_engine, clean_tables, clean_alembic_version
|
||||
):
|
||||
"""
|
||||
Upgrade from baseline to head applies all migrations.
|
||||
|
||||
Workflow:
|
||||
1. Create tables
|
||||
2. Stamp baseline
|
||||
3. Upgrade to head
|
||||
4. Verify current revision is head
|
||||
"""
|
||||
# Create tables
|
||||
async with postgres_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
# Stamp baseline
|
||||
await run_alembic_stamp(postgres_engine, '0001_baseline')
|
||||
|
||||
# Upgrade to head
|
||||
await run_alembic_upgrade(postgres_engine, 'head')
|
||||
|
||||
# Verify revision
|
||||
rev = await get_alembic_current(postgres_engine)
|
||||
assert rev is not None, "Expected a revision after upgrade"
|
||||
# Head should be the latest migration (0003 for current state)
|
||||
assert rev.startswith('0003'), f"Expected head to be 0003_*, got {rev}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_postgres_upgrade_idempotent(
|
||||
self, postgres_engine, clean_tables, clean_alembic_version
|
||||
):
|
||||
"""
|
||||
Running upgrade to head multiple times is idempotent.
|
||||
|
||||
Workflow:
|
||||
1. Upgrade to head
|
||||
2. Get revision
|
||||
3. Upgrade to head again
|
||||
4. Verify same revision
|
||||
"""
|
||||
# Create tables
|
||||
async with postgres_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
# Stamp and upgrade
|
||||
await run_alembic_stamp(postgres_engine, '0001_baseline')
|
||||
await run_alembic_upgrade(postgres_engine, 'head')
|
||||
|
||||
rev1 = await get_alembic_current(postgres_engine)
|
||||
|
||||
# Upgrade again - should be idempotent
|
||||
await run_alembic_upgrade(postgres_engine, 'head')
|
||||
|
||||
rev2 = await get_alembic_current(postgres_engine)
|
||||
assert rev2 == rev1, f"Expected {rev1}, got {rev2}"
|
||||
|
||||
|
||||
class TestPostgreSQLMigrationGetCurrent:
|
||||
"""Tests for get_alembic_current behavior on PostgreSQL."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_postgres_get_current_on_unstamped_db_returns_none(
|
||||
self, postgres_engine, clean_tables, clean_alembic_version
|
||||
):
|
||||
"""
|
||||
get_alembic_current returns None for unstamped database.
|
||||
"""
|
||||
# Create tables but don't stamp
|
||||
async with postgres_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
# No stamp - should return None
|
||||
rev = await get_alembic_current(postgres_engine)
|
||||
assert rev is None, f"Expected None for unstamped DB, got {rev}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_postgres_get_current_after_stamp_returns_revision(
|
||||
self, postgres_engine, clean_tables, clean_alembic_version
|
||||
):
|
||||
"""
|
||||
get_alembic_current returns correct revision after stamp.
|
||||
"""
|
||||
async with postgres_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
await run_alembic_stamp(postgres_engine, '0001_baseline')
|
||||
|
||||
rev = await get_alembic_current(postgres_engine)
|
||||
assert rev == '0001_baseline'
|
||||
@@ -1,5 +0,0 @@
|
||||
"""
|
||||
Pipeline integration tests package.
|
||||
|
||||
Tests for full pipeline flow using fake provider/runner.
|
||||
"""
|
||||
@@ -1,778 +0,0 @@
|
||||
"""
|
||||
Pipeline full-flow integration tests.
|
||||
|
||||
Tests real pipeline stages with fake runner/provider.
|
||||
Validates message processing through PreProcessor, Processor, and SendResponseBackStage.
|
||||
|
||||
Uses RuntimePipeline directly (not PipelineManager) to avoid DB dependency.
|
||||
|
||||
Run: uv run pytest tests/integration/pipeline -q --tb=short
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
import sys
|
||||
|
||||
from tests.factories import FakeApp, text_query, mock_platform_adapter
|
||||
from tests.factories.provider import FakeProvider
|
||||
from tests.factories.platform import FakePlatform
|
||||
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
# ============== FIXTURE FOR SYS.MODULES ISOLATION ==============
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def mock_circular_import_chain():
|
||||
"""
|
||||
Break circular import chain for pipeline modules using isolated_sys_modules.
|
||||
|
||||
Chain: pipeline → core.app → provider.runner → http_controller → groups/plugins
|
||||
|
||||
We mock minimal modules to allow importing RuntimePipeline, StageInstContainer,
|
||||
and stage classes without triggering full application initialization.
|
||||
|
||||
After mocking, we import the stage modules so decorators register them.
|
||||
"""
|
||||
from tests.utils.import_isolation import isolated_sys_modules, MockLifecycleControlScope
|
||||
|
||||
# Mock core.entities with LifecycleControlScope enum
|
||||
mock_core_entities = Mock()
|
||||
mock_core_entities.LifecycleControlScope = MockLifecycleControlScope
|
||||
|
||||
# Mock core.app - Application class is referenced but not instantiated
|
||||
mock_core_app = Mock()
|
||||
|
||||
# Mock provider.runner with preregistered_runners list
|
||||
mock_runner = Mock()
|
||||
mock_runner.preregistered_runners = [] # Will be populated in tests
|
||||
|
||||
# Mock utils.importutil - prevents auto-import of runners
|
||||
mock_importutil = Mock()
|
||||
mock_importutil.import_modules_in_pkg = lambda pkg: None
|
||||
mock_importutil.import_modules_in_pkgs = lambda pkgs: None
|
||||
|
||||
# Modules to clear (force re-import after mocking)
|
||||
clear = [
|
||||
'langbot.pkg.pipeline.stage',
|
||||
'langbot.pkg.pipeline.entities',
|
||||
'langbot.pkg.pipeline.pipelinemgr',
|
||||
'langbot.pkg.pipeline.preproc.preproc',
|
||||
'langbot.pkg.pipeline.process.process',
|
||||
'langbot.pkg.pipeline.process.handler',
|
||||
'langbot.pkg.pipeline.process.handlers.chat',
|
||||
'langbot.pkg.pipeline.process.handlers.command',
|
||||
'langbot.pkg.pipeline.respback.respback',
|
||||
'langbot.pkg.provider.runner',
|
||||
]
|
||||
|
||||
with isolated_sys_modules(
|
||||
mocks={
|
||||
'langbot.pkg.core.entities': mock_core_entities,
|
||||
'langbot.pkg.core.app': mock_core_app,
|
||||
'langbot.pkg.provider.runner': mock_runner,
|
||||
'langbot.pkg.utils.importutil': mock_importutil,
|
||||
'langbot.pkg.pipeline.controller': Mock(),
|
||||
'langbot.pkg.pipeline.pipelinemgr': Mock(),
|
||||
},
|
||||
clear=clear,
|
||||
):
|
||||
# Import stage modules AFTER clearing so decorators register them
|
||||
from importlib import import_module
|
||||
|
||||
# Import stage base first
|
||||
import_module('langbot.pkg.pipeline.stage')
|
||||
|
||||
# Import entities
|
||||
import_module('langbot.pkg.pipeline.entities')
|
||||
|
||||
# Import specific stages to register them
|
||||
import_module('langbot.pkg.pipeline.preproc.preproc')
|
||||
import_module('langbot.pkg.pipeline.process.process')
|
||||
import_module('langbot.pkg.pipeline.respback.respback')
|
||||
|
||||
# Import pipelinemgr for RuntimePipeline
|
||||
import_module('langbot.pkg.pipeline.pipelinemgr')
|
||||
|
||||
yield
|
||||
|
||||
|
||||
# ============== FAKE RUNNER ==============
|
||||
|
||||
class FakeRunner:
|
||||
"""Minimal fake runner class for pipeline integration tests.
|
||||
|
||||
Note: preregistered_runners expects a CLASS, not an instance.
|
||||
The handler calls runner_cls(self.ap, query.pipeline_config) to instantiate.
|
||||
"""
|
||||
|
||||
name = 'local-agent'
|
||||
|
||||
def __init__(self, app=None, config=None):
|
||||
self.app = app
|
||||
self.config = config or {}
|
||||
self._provider = FakeProvider()
|
||||
# Instance-level configuration set via class attribute
|
||||
self._response_text = "fake response"
|
||||
self._raise_error = None
|
||||
|
||||
@classmethod
|
||||
def returns(cls, text: str):
|
||||
"""Create a runner class configured to return specific text."""
|
||||
# We create a subclass with configured response
|
||||
class ConfiguredRunner(cls):
|
||||
name = cls.name
|
||||
_response_text = text
|
||||
_raise_error = None
|
||||
|
||||
def __init__(self, app=None, config=None):
|
||||
super().__init__(app, config)
|
||||
self._response_text = text
|
||||
return ConfiguredRunner
|
||||
|
||||
@classmethod
|
||||
def raises(cls, error: Exception):
|
||||
"""Create a runner class configured to raise an error."""
|
||||
class ConfiguredRunner(cls):
|
||||
name = cls.name
|
||||
_response_text = None
|
||||
_raise_error = error
|
||||
|
||||
def __init__(self, app=None, config=None):
|
||||
super().__init__(app, config)
|
||||
self._raise_error = error
|
||||
return ConfiguredRunner
|
||||
|
||||
async def run(self, query):
|
||||
"""Run the fake provider and yield messages."""
|
||||
from langbot_plugin.api.entities.builtin.provider.message import Message
|
||||
|
||||
# Use the configured response/error
|
||||
if self._raise_error:
|
||||
raise self._raise_error
|
||||
|
||||
# Yield a simple message
|
||||
yield Message(role='assistant', content=self._response_text)
|
||||
|
||||
|
||||
# ============== PIPELINE APP FIXTURE ==============
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline_app():
|
||||
"""
|
||||
Create FakeApp with all dependencies required by pipeline stages.
|
||||
|
||||
PreProcessor needs: sess_mgr, model_mgr, tool_mgr, plugin_connector
|
||||
Processor needs: instance_config, plugin_connector
|
||||
SendResponseBackStage needs: logger
|
||||
ChatMessageHandler needs: telemetry, survey
|
||||
"""
|
||||
app = FakeApp()
|
||||
|
||||
# Session/conversation mocks for PreProcessor
|
||||
mock_session = Mock()
|
||||
mock_session.launcher_type = Mock()
|
||||
mock_session.launcher_type.value = 'person'
|
||||
mock_session.launcher_id = 12345
|
||||
mock_session.sender_id = 12345
|
||||
mock_session.use_prompt_name = 'default'
|
||||
mock_session.using_conversation = None
|
||||
|
||||
# Create a simple class to mimic Prompt behavior
|
||||
class MockPrompt:
|
||||
def __init__(self, name, messages):
|
||||
self.name = name
|
||||
self.messages = messages
|
||||
def copy(self):
|
||||
return MockPrompt(self.name, list(self.messages))
|
||||
|
||||
# Create real lists for messages
|
||||
prompt_messages_list = []
|
||||
messages_list = []
|
||||
|
||||
mock_prompt = MockPrompt('default', prompt_messages_list)
|
||||
mock_conversation = Mock()
|
||||
mock_conversation.prompt = mock_prompt
|
||||
mock_conversation.messages = messages_list
|
||||
mock_conversation.uuid = 'test-conversation-uuid'
|
||||
mock_conversation.update_time = None
|
||||
mock_conversation.create_time = None
|
||||
|
||||
app.sess_mgr.get_session = AsyncMock(return_value=mock_session)
|
||||
app.sess_mgr.get_conversation = AsyncMock(return_value=mock_conversation)
|
||||
|
||||
# Model mock for PreProcessor
|
||||
mock_model = Mock()
|
||||
mock_model.model_entity = Mock()
|
||||
mock_model.model_entity.uuid = 'test-model-uuid'
|
||||
mock_model.model_entity.name = 'test-model'
|
||||
mock_model.model_entity.abilities = ['func_call', 'vision']
|
||||
app.model_mgr.get_model_by_uuid = AsyncMock(return_value=mock_model)
|
||||
|
||||
# Tool manager mock
|
||||
app.tool_mgr.get_all_tools = AsyncMock(return_value=[])
|
||||
|
||||
# Telemetry mock (required by ChatMessageHandler)
|
||||
app.telemetry = Mock()
|
||||
app.telemetry.start_send_task = AsyncMock()
|
||||
|
||||
# Survey mock
|
||||
app.survey = None
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_platform_adapter():
|
||||
"""Create a fake platform adapter for outbound capture."""
|
||||
platform = FakePlatform(stream_output_supported=False)
|
||||
adapter = mock_platform_adapter(platform)
|
||||
return adapter, platform
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def set_fake_runner():
|
||||
"""Factory fixture to set a fake runner CLASS in preregistered_runners."""
|
||||
def _set_runner(runner_cls):
|
||||
# preregistered_runners expects a list of runner classes
|
||||
sys.modules['langbot.pkg.provider.runner'].preregistered_runners = [runner_cls]
|
||||
return _set_runner
|
||||
|
||||
|
||||
# ============== PIPELINE CONFIGURATION ==============
|
||||
|
||||
def create_minimal_pipeline_config():
|
||||
"""Create minimal pipeline configuration for tests."""
|
||||
return {
|
||||
'ai': {
|
||||
'runner': {'runner': 'local-agent', 'expire-time': None},
|
||||
'local-agent': {
|
||||
'model': {'primary': 'test-model-uuid', 'fallbacks': []},
|
||||
'prompt': 'default',
|
||||
'knowledge-bases': [],
|
||||
},
|
||||
},
|
||||
'output': {
|
||||
'force-delay': {'min': 0.0, 'max': 0.0},
|
||||
'misc': {
|
||||
'at-sender': False,
|
||||
'quote-origin': False,
|
||||
'exception-handling': 'show-hint',
|
||||
'failure-hint': 'Request failed.',
|
||||
},
|
||||
},
|
||||
'trigger': {
|
||||
'misc': {'combine-quote-message': False},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ============== HELPER TO PROCESS COROUTINE/GENERATOR ==============
|
||||
|
||||
async def collect_processor_results(processor, query, stage_name):
|
||||
"""
|
||||
Helper to handle the coroutine -> async_generator pattern.
|
||||
|
||||
Processor.process() returns a coroutine that yields an async_generator.
|
||||
This helper handles both cases like RuntimePipeline does.
|
||||
"""
|
||||
result = processor.process(query, stage_name)
|
||||
|
||||
# Handle coroutine (await it to get async_generator)
|
||||
if asyncio.iscoroutine(result):
|
||||
result = await result
|
||||
|
||||
# Now iterate over async_generator
|
||||
results = []
|
||||
async for item in result:
|
||||
results.append(item)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# ============== TESTS ==============
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestPipelineStageChainReal:
|
||||
"""Tests for real pipeline stage chain."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_import_pipeline_modules(self):
|
||||
"""Verify we can import real pipeline modules."""
|
||||
from langbot.pkg.pipeline import stage, entities
|
||||
from langbot.pkg.pipeline import pipelinemgr
|
||||
|
||||
assert hasattr(stage, 'PipelineStage')
|
||||
assert hasattr(stage, 'preregistered_stages')
|
||||
assert hasattr(entities, 'ResultType')
|
||||
assert hasattr(entities, 'StageProcessResult')
|
||||
assert hasattr(pipelinemgr, 'RuntimePipeline')
|
||||
assert hasattr(pipelinemgr, 'StageInstContainer')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stage_preregistration(self):
|
||||
"""Verify stages are preregistered after fixture imports them."""
|
||||
from langbot.pkg.pipeline import stage
|
||||
|
||||
# Check that our target stages are registered
|
||||
assert 'PreProcessor' in stage.preregistered_stages
|
||||
assert 'MessageProcessor' in stage.preregistered_stages
|
||||
assert 'SendResponseBackStage' in stage.preregistered_stages
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestPreProcessorStage:
|
||||
"""Tests for PreProcessor stage alone."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preproc_continues_on_valid_query(self, pipeline_app, fake_platform_adapter):
|
||||
"""PreProcessor should return CONTINUE for valid text query."""
|
||||
from langbot.pkg.pipeline import entities
|
||||
from langbot.pkg.pipeline.preproc import preproc
|
||||
|
||||
adapter, platform = fake_platform_adapter
|
||||
|
||||
# Create query with adapter
|
||||
query = text_query("hello")
|
||||
query.adapter = adapter
|
||||
query.pipeline_config = create_minimal_pipeline_config()
|
||||
|
||||
# Mock plugin_connector for PromptPreProcessing event
|
||||
mock_event_ctx = Mock()
|
||||
mock_event_ctx.event = Mock()
|
||||
mock_event_ctx.event.default_prompt = [] # Real list
|
||||
mock_event_ctx.event.prompt = [] # Real list
|
||||
pipeline_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
|
||||
# Create PreProcessor stage
|
||||
preproc_stage = preproc.PreProcessor(pipeline_app)
|
||||
|
||||
result = await preproc_stage.process(query, 'PreProcessor')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
assert result.new_query.session is not None
|
||||
assert result.new_query.user_message is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preproc_sets_user_message(self, pipeline_app, fake_platform_adapter):
|
||||
"""PreProcessor should set user_message from message_chain."""
|
||||
from langbot.pkg.pipeline import entities
|
||||
from langbot.pkg.pipeline.preproc import preproc
|
||||
|
||||
adapter, platform = fake_platform_adapter
|
||||
|
||||
query = text_query("test message content")
|
||||
query.adapter = adapter
|
||||
query.pipeline_config = create_minimal_pipeline_config()
|
||||
|
||||
# Mock plugin_connector for PromptPreProcessing event
|
||||
mock_event_ctx = Mock()
|
||||
mock_event_ctx.event = Mock()
|
||||
mock_event_ctx.event.default_prompt = []
|
||||
mock_event_ctx.event.prompt = []
|
||||
pipeline_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
|
||||
preproc_stage = preproc.PreProcessor(pipeline_app)
|
||||
|
||||
result = await preproc_stage.process(query, 'PreProcessor')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
# Check user_message content
|
||||
assert result.new_query.user_message is not None
|
||||
assert result.new_query.user_message.role == 'user'
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestProcessorStage:
|
||||
"""Tests for MessageProcessor stage."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_processor_calls_chat_handler(self, pipeline_app, fake_platform_adapter, set_fake_runner):
|
||||
"""Processor should route to ChatMessageHandler for non-command messages."""
|
||||
adapter, platform = fake_platform_adapter
|
||||
|
||||
# Set fake runner that returns pong
|
||||
fake_runner = FakeRunner().returns("LANGBOT_FAKE_PONG")
|
||||
set_fake_runner(fake_runner)
|
||||
|
||||
# Create query
|
||||
query = text_query("hello")
|
||||
query.adapter = adapter
|
||||
query.pipeline_config = create_minimal_pipeline_config()
|
||||
query.resp_messages = []
|
||||
|
||||
# Mock plugin_connector to not prevent default
|
||||
mock_event_ctx = Mock()
|
||||
mock_event_ctx.is_prevented_default = Mock(return_value=False)
|
||||
mock_event_ctx.event = Mock()
|
||||
mock_event_ctx.event.user_message_alter = None
|
||||
pipeline_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
|
||||
# Create Processor stage
|
||||
from langbot.pkg.pipeline.process import process
|
||||
processor_stage = process.Processor(pipeline_app)
|
||||
await processor_stage.initialize(query.pipeline_config)
|
||||
|
||||
# Collect results using helper
|
||||
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
|
||||
|
||||
assert len(results) >= 1
|
||||
# Check that resp_messages was populated
|
||||
assert len(query.resp_messages) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_processor_prevent_default_without_reply_interrupts(self, pipeline_app, fake_platform_adapter):
|
||||
"""Processor should INTERRUPT when plugin prevents default without reply."""
|
||||
from langbot.pkg.pipeline import entities
|
||||
|
||||
adapter, platform = fake_platform_adapter
|
||||
|
||||
# Create query
|
||||
query = text_query("hello")
|
||||
query.adapter = adapter
|
||||
query.pipeline_config = create_minimal_pipeline_config()
|
||||
|
||||
# Mock plugin_connector to prevent default without reply
|
||||
mock_event_ctx = Mock()
|
||||
mock_event_ctx.is_prevented_default = Mock(return_value=True)
|
||||
mock_event_ctx.event = Mock()
|
||||
mock_event_ctx.event.reply_message_chain = None
|
||||
pipeline_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
|
||||
# Create Processor stage
|
||||
from langbot.pkg.pipeline.process import process
|
||||
processor_stage = process.Processor(pipeline_app)
|
||||
await processor_stage.initialize(query.pipeline_config)
|
||||
|
||||
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].result_type == entities.ResultType.INTERRUPT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_processor_prevent_default_with_reply_continues(self, pipeline_app, fake_platform_adapter):
|
||||
"""Processor should CONTINUE when plugin prevents default with reply."""
|
||||
from langbot.pkg.pipeline import entities
|
||||
from tests.factories.message import text_chain
|
||||
|
||||
adapter, platform = fake_platform_adapter
|
||||
|
||||
# Create query
|
||||
query = text_query("hello")
|
||||
query.adapter = adapter
|
||||
query.pipeline_config = create_minimal_pipeline_config()
|
||||
query.resp_messages = []
|
||||
|
||||
# Create reply chain
|
||||
reply_chain = text_chain("plugin response")
|
||||
|
||||
# Mock plugin_connector to prevent default with reply
|
||||
mock_event_ctx = Mock()
|
||||
mock_event_ctx.is_prevented_default = Mock(return_value=True)
|
||||
mock_event_ctx.event = Mock()
|
||||
mock_event_ctx.event.reply_message_chain = reply_chain
|
||||
pipeline_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
|
||||
# Create Processor stage
|
||||
from langbot.pkg.pipeline.process import process
|
||||
processor_stage = process.Processor(pipeline_app)
|
||||
await processor_stage.initialize(query.pipeline_config)
|
||||
|
||||
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].result_type == entities.ResultType.CONTINUE
|
||||
assert len(query.resp_messages) == 1
|
||||
assert query.resp_messages[0] == reply_chain
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestRunnerExceptionFlow:
|
||||
"""Tests for runner exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_exception_yields_interrupt(self, pipeline_app, fake_platform_adapter, set_fake_runner):
|
||||
"""Runner exception should yield INTERRUPT with error notices."""
|
||||
from langbot.pkg.pipeline import entities
|
||||
|
||||
adapter, platform = fake_platform_adapter
|
||||
|
||||
# Set fake runner that raises exception
|
||||
fake_runner = FakeRunner().raises(ValueError("API Error: rate limit exceeded"))
|
||||
set_fake_runner(fake_runner)
|
||||
|
||||
# Create query with exception handling config
|
||||
config = create_minimal_pipeline_config()
|
||||
config['output']['misc']['exception-handling'] = 'show-hint'
|
||||
config['output']['misc']['failure-hint'] = 'Request failed.'
|
||||
|
||||
query = text_query("hello")
|
||||
query.adapter = adapter
|
||||
query.pipeline_config = config
|
||||
|
||||
# Mock plugin_connector to not prevent default
|
||||
mock_event_ctx = Mock()
|
||||
mock_event_ctx.is_prevented_default = Mock(return_value=False)
|
||||
mock_event_ctx.event = Mock()
|
||||
mock_event_ctx.event.user_message_alter = None
|
||||
pipeline_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
|
||||
# Create Processor stage
|
||||
from langbot.pkg.pipeline.process import process
|
||||
processor_stage = process.Processor(pipeline_app)
|
||||
await processor_stage.initialize(query.pipeline_config)
|
||||
|
||||
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].result_type == entities.ResultType.INTERRUPT
|
||||
assert results[0].user_notice == 'Request failed.'
|
||||
assert results[0].error_notice is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_exception_show_error_mode(self, pipeline_app, fake_platform_adapter, set_fake_runner):
|
||||
"""show-error mode should show actual exception message."""
|
||||
from langbot.pkg.pipeline import entities
|
||||
|
||||
adapter, platform = fake_platform_adapter
|
||||
|
||||
# Set fake runner that raises specific exception
|
||||
fake_runner = FakeRunner().raises(RuntimeError("Custom runtime error"))
|
||||
set_fake_runner(fake_runner)
|
||||
|
||||
# Create query with show-error mode
|
||||
config = create_minimal_pipeline_config()
|
||||
config['output']['misc']['exception-handling'] = 'show-error'
|
||||
|
||||
query = text_query("hello")
|
||||
query.adapter = adapter
|
||||
query.pipeline_config = config
|
||||
|
||||
# Mock plugin_connector to not prevent default
|
||||
mock_event_ctx = Mock()
|
||||
mock_event_ctx.is_prevented_default = Mock(return_value=False)
|
||||
mock_event_ctx.event = Mock()
|
||||
mock_event_ctx.event.user_message_alter = None
|
||||
pipeline_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
|
||||
# Create Processor stage
|
||||
from langbot.pkg.pipeline.process import process
|
||||
processor_stage = process.Processor(pipeline_app)
|
||||
await processor_stage.initialize(query.pipeline_config)
|
||||
|
||||
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].result_type == entities.ResultType.INTERRUPT
|
||||
assert 'Custom runtime error' in results[0].user_notice
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_exception_hide_mode(self, pipeline_app, fake_platform_adapter, set_fake_runner):
|
||||
"""hide mode should not show user notice."""
|
||||
from langbot.pkg.pipeline import entities
|
||||
|
||||
adapter, platform = fake_platform_adapter
|
||||
|
||||
# Set fake runner that raises exception
|
||||
fake_runner = FakeRunner().raises(Exception("Hidden error"))
|
||||
set_fake_runner(fake_runner)
|
||||
|
||||
# Create query with hide mode
|
||||
config = create_minimal_pipeline_config()
|
||||
config['output']['misc']['exception-handling'] = 'hide'
|
||||
|
||||
query = text_query("hello")
|
||||
query.adapter = adapter
|
||||
query.pipeline_config = config
|
||||
|
||||
# Mock plugin_connector to not prevent default
|
||||
mock_event_ctx = Mock()
|
||||
mock_event_ctx.is_prevented_default = Mock(return_value=False)
|
||||
mock_event_ctx.event = Mock()
|
||||
mock_event_ctx.event.user_message_alter = None
|
||||
pipeline_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
|
||||
# Create Processor stage
|
||||
from langbot.pkg.pipeline.process import process
|
||||
processor_stage = process.Processor(pipeline_app)
|
||||
await processor_stage.initialize(query.pipeline_config)
|
||||
|
||||
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].result_type == entities.ResultType.INTERRUPT
|
||||
assert results[0].user_notice is None
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestSendResponseBackStage:
|
||||
"""Tests for SendResponseBackStage."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_response_calls_adapter(self, pipeline_app, fake_platform_adapter):
|
||||
"""SendResponseBackStage should call adapter.reply_message."""
|
||||
from langbot.pkg.pipeline import entities
|
||||
from langbot.pkg.pipeline.respback import respback
|
||||
from tests.factories.message import text_chain
|
||||
from langbot_plugin.api.entities.builtin.provider.message import Message
|
||||
|
||||
adapter, platform = fake_platform_adapter
|
||||
|
||||
# Create query with response message
|
||||
query = text_query("hello")
|
||||
query.adapter = adapter
|
||||
query.pipeline_config = create_minimal_pipeline_config()
|
||||
|
||||
# Add response message
|
||||
query.resp_messages = [Message(role='assistant', content='test response')]
|
||||
query.resp_message_chain = [text_chain('test response')]
|
||||
|
||||
# Create SendResponseBackStage
|
||||
respback_stage = respback.SendResponseBackStage(pipeline_app)
|
||||
|
||||
result = await respback_stage.process(query, 'SendResponseBackStage')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
|
||||
# Check that adapter was called
|
||||
outbound = platform.get_outbound_messages()
|
||||
assert len(outbound) == 1
|
||||
assert outbound[0]['type'] == 'reply'
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestStageChainIntegration:
|
||||
"""Tests for full stage chain (PreProcessor -> Processor -> SendResponseBackStage)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_chain_text_message_flow(self, pipeline_app, fake_platform_adapter, set_fake_runner):
|
||||
"""
|
||||
Full chain: text message -> PreProcessor -> Processor -> SendResponseBackStage.
|
||||
|
||||
Validates:
|
||||
- PreProcessor sets up session, user_message
|
||||
- Processor calls runner and populates resp_messages
|
||||
- SendResponseBackStage calls adapter.reply_message
|
||||
"""
|
||||
from langbot.pkg.pipeline import entities
|
||||
from langbot.pkg.pipeline.preproc import preproc
|
||||
from langbot.pkg.pipeline.process import process
|
||||
from langbot.pkg.pipeline.respback import respback
|
||||
|
||||
adapter, platform = fake_platform_adapter
|
||||
|
||||
# Set fake runner
|
||||
fake_runner = FakeRunner().returns("LANGBOT_FAKE_PONG")
|
||||
set_fake_runner(fake_runner)
|
||||
|
||||
# Create query
|
||||
config = create_minimal_pipeline_config()
|
||||
query = text_query("ping")
|
||||
query.adapter = adapter
|
||||
query.pipeline_config = config
|
||||
query.resp_messages = []
|
||||
query.resp_message_chain = []
|
||||
|
||||
# Mock plugin_connector for PreProcessor and Processor events
|
||||
mock_event_ctx_preproc = Mock()
|
||||
mock_event_ctx_preproc.event = Mock()
|
||||
mock_event_ctx_preproc.event.default_prompt = []
|
||||
mock_event_ctx_preproc.event.prompt = []
|
||||
|
||||
mock_event_ctx_processor = Mock()
|
||||
mock_event_ctx_processor.is_prevented_default = Mock(return_value=False)
|
||||
mock_event_ctx_processor.event = Mock()
|
||||
mock_event_ctx_processor.event.user_message_alter = None
|
||||
|
||||
pipeline_app.plugin_connector.emit_event = AsyncMock()
|
||||
pipeline_app.plugin_connector.emit_event.side_effect = [
|
||||
mock_event_ctx_preproc, # PreProcessor PromptPreProcessing
|
||||
mock_event_ctx_processor, # Processor NormalMessageReceived
|
||||
]
|
||||
|
||||
# Create stages
|
||||
preproc_stage = preproc.PreProcessor(pipeline_app)
|
||||
processor_stage = process.Processor(pipeline_app)
|
||||
await processor_stage.initialize(config)
|
||||
respback_stage = respback.SendResponseBackStage(pipeline_app)
|
||||
|
||||
# Run PreProcessor
|
||||
result1 = await preproc_stage.process(query, 'PreProcessor')
|
||||
assert result1.result_type == entities.ResultType.CONTINUE
|
||||
query = result1.new_query
|
||||
|
||||
# Run Processor
|
||||
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
|
||||
assert len(results) >= 1
|
||||
|
||||
# Build resp_message_chain from resp_messages
|
||||
from tests.factories.message import text_chain
|
||||
for resp_msg in query.resp_messages:
|
||||
if resp_msg.content:
|
||||
query.resp_message_chain.append(text_chain(resp_msg.content))
|
||||
|
||||
# Run SendResponseBackStage
|
||||
result3 = await respback_stage.process(query, 'SendResponseBackStage')
|
||||
assert result3.result_type == entities.ResultType.CONTINUE
|
||||
|
||||
# Verify adapter was called
|
||||
outbound = platform.get_outbound_messages()
|
||||
assert len(outbound) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chain_stops_on_interrupt(self, pipeline_app, fake_platform_adapter):
|
||||
"""
|
||||
Chain should stop when a stage returns INTERRUPT.
|
||||
|
||||
PreProcessor returns CONTINUE, Processor returns INTERRUPT (prevent_default).
|
||||
"""
|
||||
from langbot.pkg.pipeline import entities
|
||||
from langbot.pkg.pipeline.preproc import preproc
|
||||
from langbot.pkg.pipeline.process import process
|
||||
|
||||
adapter, platform = fake_platform_adapter
|
||||
|
||||
# Create query
|
||||
query = text_query("hello")
|
||||
query.adapter = adapter
|
||||
query.pipeline_config = create_minimal_pipeline_config()
|
||||
|
||||
# Mock plugin_connector - PreProcessor continues, Processor interrupts
|
||||
mock_event_ctx_preproc = Mock()
|
||||
mock_event_ctx_preproc.event = Mock()
|
||||
mock_event_ctx_preproc.event.default_prompt = []
|
||||
mock_event_ctx_preproc.event.prompt = []
|
||||
|
||||
mock_event_ctx_processor = Mock()
|
||||
mock_event_ctx_processor.is_prevented_default = Mock(return_value=True)
|
||||
mock_event_ctx_processor.event = Mock()
|
||||
mock_event_ctx_processor.event.reply_message_chain = None
|
||||
|
||||
pipeline_app.plugin_connector.emit_event = AsyncMock()
|
||||
pipeline_app.plugin_connector.emit_event.side_effect = [
|
||||
mock_event_ctx_preproc, # PreProcessor PromptPreProcessing
|
||||
mock_event_ctx_processor, # Processor NormalMessageReceived
|
||||
]
|
||||
|
||||
# Create stages
|
||||
preproc_stage = preproc.PreProcessor(pipeline_app)
|
||||
processor_stage = process.Processor(pipeline_app)
|
||||
await processor_stage.initialize(query.pipeline_config)
|
||||
|
||||
# Run PreProcessor
|
||||
result1 = await preproc_stage.process(query, 'PreProcessor')
|
||||
assert result1.result_type == entities.ResultType.CONTINUE
|
||||
query = result1.new_query
|
||||
|
||||
# Run Processor - should INTERRUPT
|
||||
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].result_type == entities.ResultType.INTERRUPT
|
||||
|
||||
# Chain stops here - no resp_messages
|
||||
assert len(query.resp_messages) == 0
|
||||
@@ -1,6 +0,0 @@
|
||||
"""
|
||||
Smoke tests package.
|
||||
|
||||
Smoke tests verify basic functionality works without testing edge cases.
|
||||
Run with: uv run pytest tests/smoke/ -q
|
||||
"""
|
||||
@@ -1,351 +0,0 @@
|
||||
"""
|
||||
Minimal fake flow smoke tests for LangBot.
|
||||
|
||||
These tests verify basic component interactions using fake providers and platforms.
|
||||
Not a full pipeline integration test - tests individual factory components.
|
||||
|
||||
For full pipeline tests, see tests/integration/ (planned).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.factories import (
|
||||
FakeApp,
|
||||
FakeProvider,
|
||||
FakePlatform,
|
||||
text_query,
|
||||
fake_provider_pong,
|
||||
fake_model,
|
||||
mock_platform_adapter,
|
||||
)
|
||||
|
||||
|
||||
class TestFakeMessageFlow:
|
||||
"""Smoke tests for fake message flow through pipeline."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fake_app_creation(self):
|
||||
"""Test FakeApp can be created with all dependencies."""
|
||||
app = FakeApp()
|
||||
|
||||
assert app.logger is not None
|
||||
assert app.sess_mgr is not None
|
||||
assert app.model_mgr is not None
|
||||
assert app.tool_mgr is not None
|
||||
assert app.persistence_mgr is not None
|
||||
assert app.query_pool is not None
|
||||
assert app.instance_config is not None
|
||||
|
||||
# Verify default config
|
||||
assert app.instance_config.data["command"]["prefix"] == ["/", "!"]
|
||||
assert app.instance_config.data["command"]["enable"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fake_provider_returns_text(self):
|
||||
"""Test FakeProvider returns configured response."""
|
||||
provider = FakeProvider(default_response="test response")
|
||||
|
||||
# Create mock model with provider
|
||||
model = fake_model(provider=provider)
|
||||
|
||||
# Create a simple query
|
||||
query = text_query("hello")
|
||||
|
||||
# Simulate invoke
|
||||
result = await provider.invoke_llm(
|
||||
query=query,
|
||||
model=model,
|
||||
messages=[],
|
||||
funcs=[],
|
||||
extra_args={},
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.role == "assistant"
|
||||
assert result.content == "test response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fake_provider_pong(self):
|
||||
"""Test FakeProvider returns LANGBOT_FAKE_PONG marker."""
|
||||
provider = fake_provider_pong()
|
||||
model = fake_model(provider=provider)
|
||||
query = text_query("ping")
|
||||
|
||||
result = await provider.invoke_llm(
|
||||
query=query,
|
||||
model=model,
|
||||
messages=[],
|
||||
funcs=[],
|
||||
extra_args={},
|
||||
)
|
||||
|
||||
assert result.content == FakeProvider.PONG_RESPONSE
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fake_provider_streaming(self):
|
||||
"""Test FakeProvider streaming response."""
|
||||
provider = FakeProvider().returns_streaming(["Hello", " World"])
|
||||
model = fake_model(provider=provider)
|
||||
query = text_query("hello")
|
||||
|
||||
chunks = []
|
||||
# invoke_llm_stream returns an async generator, don't await it
|
||||
async for chunk in provider.invoke_llm_stream(
|
||||
query=query,
|
||||
model=model,
|
||||
messages=[],
|
||||
funcs=[],
|
||||
extra_args={},
|
||||
):
|
||||
chunks.append(chunk)
|
||||
|
||||
assert len(chunks) == 2
|
||||
assert chunks[0].content == "Hello"
|
||||
assert chunks[1].content == " World"
|
||||
assert chunks[1].is_final is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fake_provider_timeout(self):
|
||||
"""Test FakeProvider simulates timeout error."""
|
||||
provider = FakeProvider().timeout()
|
||||
model = fake_model(provider=provider)
|
||||
query = text_query("hello")
|
||||
|
||||
with pytest.raises(TimeoutError, match="Provider timeout"):
|
||||
await provider.invoke_llm(
|
||||
query=query,
|
||||
model=model,
|
||||
messages=[],
|
||||
funcs=[],
|
||||
extra_args={},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fake_provider_rate_limit(self):
|
||||
"""Test FakeProvider simulates rate limit error."""
|
||||
provider = FakeProvider().rate_limit()
|
||||
model = fake_model(provider=provider)
|
||||
query = text_query("hello")
|
||||
|
||||
with pytest.raises(Exception, match="Rate limit exceeded"):
|
||||
await provider.invoke_llm(
|
||||
query=query,
|
||||
model=model,
|
||||
messages=[],
|
||||
funcs=[],
|
||||
extra_args={},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fake_provider_captures_requests(self):
|
||||
"""Test FakeProvider captures request arguments."""
|
||||
provider = FakeProvider()
|
||||
model = fake_model(name="gpt-4", provider=provider)
|
||||
query = text_query("hello")
|
||||
|
||||
await provider.invoke_llm(
|
||||
query=query,
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
funcs=[{"name": "test_func"}],
|
||||
extra_args={"temperature": 0.7},
|
||||
)
|
||||
|
||||
captured = provider.get_captured_requests()
|
||||
assert len(captured) == 1
|
||||
assert captured[0]["model"] == "gpt-4"
|
||||
assert captured[0]["messages"] == [{"role": "user", "content": "hello"}]
|
||||
assert captured[0]["funcs"] == [{"name": "test_func"}]
|
||||
assert captured[0]["extra_args"] == {"temperature": 0.7}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fake_platform_capture_outbound(self):
|
||||
"""Test FakePlatform captures outbound messages."""
|
||||
platform = FakePlatform(bot_account_id="test-bot")
|
||||
query = text_query("hello")
|
||||
|
||||
# Simulate sending reply
|
||||
from tests.factories.message import text_chain
|
||||
|
||||
reply_chain = text_chain("response text")
|
||||
event = query.message_event
|
||||
|
||||
await platform.reply_message(event, reply_chain, quote_origin=False)
|
||||
|
||||
# Verify captured
|
||||
outbound = platform.get_outbound_messages()
|
||||
assert len(outbound) == 1
|
||||
assert outbound[0]["type"] == "reply"
|
||||
assert outbound[0]["message"] == reply_chain
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fake_platform_friend_message(self):
|
||||
"""Test FakePlatform creates friend message events."""
|
||||
platform = FakePlatform(bot_account_id="test-bot")
|
||||
|
||||
event = platform.create_friend_message(
|
||||
text="hello bot",
|
||||
sender_id=12345,
|
||||
nickname="TestUser",
|
||||
)
|
||||
|
||||
assert event.type == "FriendMessage"
|
||||
assert event.sender.id == 12345
|
||||
assert event.sender.nickname == "TestUser"
|
||||
assert str(event.message_chain) == "hello bot"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fake_platform_group_message_with_mention(self):
|
||||
"""Test FakePlatform creates group message with @mention."""
|
||||
platform = FakePlatform(bot_account_id="test-bot")
|
||||
|
||||
event = platform.create_group_message(
|
||||
text="hello everyone",
|
||||
sender_id=12345,
|
||||
group_id=99999,
|
||||
mention_bot=True,
|
||||
)
|
||||
|
||||
assert event.type == "GroupMessage"
|
||||
assert event.sender.id == 12345
|
||||
assert event.group.id == 99999
|
||||
|
||||
# Check message chain has @mention
|
||||
chain = event.message_chain
|
||||
assert len(chain) >= 2 # At + Plain
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_factories_basic(self):
|
||||
"""Test basic query factory functions."""
|
||||
# Text query
|
||||
q1 = text_query("hello world")
|
||||
assert q1.launcher_type.value == "person"
|
||||
assert str(q1.message_chain) == "hello world"
|
||||
|
||||
# Group query
|
||||
from tests.factories import group_text_query
|
||||
q2 = group_text_query("hello group", group_id=88888)
|
||||
assert q2.launcher_type.value == "group"
|
||||
assert q2.launcher_id == 88888
|
||||
|
||||
# Command query
|
||||
from tests.factories import command_query
|
||||
q3 = command_query("help", prefix="/")
|
||||
assert str(q3.message_chain) == "/help"
|
||||
|
||||
# Mention query
|
||||
from tests.factories import mention_query
|
||||
q4 = mention_query("hi", target="test-bot", group_id=77777)
|
||||
assert q4.launcher_type.value == "group"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fake_platform_send_failure(self):
|
||||
"""Test FakePlatform simulates send failure."""
|
||||
platform = FakePlatform().send_failure()
|
||||
query = text_query("hello")
|
||||
|
||||
from tests.factories.message import text_chain
|
||||
|
||||
with pytest.raises(Exception, match="Platform send failure"):
|
||||
await platform.reply_message(
|
||||
query.message_event,
|
||||
text_chain("response"),
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mock_platform_adapter(self):
|
||||
"""Test mock_platform_adapter helper."""
|
||||
platform = FakePlatform(bot_account_id="bot-123")
|
||||
adapter = mock_platform_adapter(platform)
|
||||
|
||||
assert adapter.bot_account_id == "bot-123"
|
||||
assert adapter._fake_platform is platform
|
||||
|
||||
# Test reply_message is wired
|
||||
from tests.factories.message import text_chain
|
||||
|
||||
query = text_query("test")
|
||||
await adapter.reply_message(query.message_event, text_chain("response"))
|
||||
|
||||
# Verify platform captured it
|
||||
assert len(platform.get_outbound_messages()) == 1
|
||||
|
||||
|
||||
class TestMessageFlowIntegration:
|
||||
"""Minimal fake flow integration tests.
|
||||
|
||||
These tests verify component interactions but do NOT run full LangBot pipeline.
|
||||
For real pipeline tests, integration tests are needed (planned).
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_minimal_message_flow(self):
|
||||
"""Minimal fake flow test: fake query -> fake provider -> fake platform.
|
||||
|
||||
This test verifies:
|
||||
1. Fake text query is created
|
||||
2. Fake provider returns LANGBOT_FAKE_PONG
|
||||
3. Fake platform captures outbound response
|
||||
4. No unexpected exception
|
||||
|
||||
Note: This does NOT run actual LangBot pipeline stages.
|
||||
"""
|
||||
# Setup
|
||||
platform = FakePlatform(bot_account_id="test-bot")
|
||||
provider = fake_provider_pong()
|
||||
model = fake_model(provider=provider)
|
||||
|
||||
# Create inbound message
|
||||
query = text_query("ping")
|
||||
|
||||
# Simulate provider processing
|
||||
response = await provider.invoke_llm(
|
||||
query=query,
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": "ping"}],
|
||||
funcs=[],
|
||||
extra_args={},
|
||||
)
|
||||
|
||||
# Verify provider returned pong
|
||||
assert response.content == FakeProvider.PONG_RESPONSE
|
||||
|
||||
# Simulate platform sending response
|
||||
from tests.factories.message import text_chain
|
||||
|
||||
reply_chain = text_chain(response.content)
|
||||
await platform.reply_message(query.message_event, reply_chain)
|
||||
|
||||
# Verify platform captured outbound
|
||||
outbound = platform.get_outbound_messages()
|
||||
assert len(outbound) == 1
|
||||
assert outbound[0]["type"] == "reply"
|
||||
assert str(outbound[0]["message"]) == FakeProvider.PONG_RESPONSE
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_message_flow(self):
|
||||
"""Smoke test: streaming message flow."""
|
||||
platform = FakePlatform().supports_streaming()
|
||||
provider = FakeProvider().returns_streaming(["Hello", " there"])
|
||||
model = fake_model(provider=provider)
|
||||
query = text_query("hi")
|
||||
|
||||
chunks = []
|
||||
async for chunk in provider.invoke_llm_stream(
|
||||
query=query,
|
||||
model=model,
|
||||
messages=[],
|
||||
funcs=[],
|
||||
extra_args={},
|
||||
):
|
||||
chunks.append(chunk)
|
||||
|
||||
# Verify streaming worked
|
||||
assert len(chunks) == 2
|
||||
full_content = "".join(c.content for c in chunks)
|
||||
assert full_content == "Hello there"
|
||||
|
||||
# Verify platform supports streaming
|
||||
assert await platform.is_stream_output_supported() is True
|
||||
@@ -1,66 +0,0 @@
|
||||
"""
|
||||
PoC test for CWE-94: Authenticated RCE via exec() on user-supplied Python code.
|
||||
|
||||
The /api/v1/system/debug/exec endpoint passes raw HTTP body to exec(),
|
||||
allowing arbitrary code execution when debug_mode is True.
|
||||
|
||||
This test verifies that:
|
||||
1. The exec() endpoint is removed from the codebase entirely.
|
||||
2. No route matches /api/v1/system/debug/exec.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import pathlib
|
||||
|
||||
# Resolve project root (one level up from tests/)
|
||||
_PROJECT_ROOT = pathlib.Path(__file__).resolve().parent.parent
|
||||
|
||||
VULN_FILE = (
|
||||
_PROJECT_ROOT
|
||||
/ "src"
|
||||
/ "langbot"
|
||||
/ "pkg"
|
||||
/ "api"
|
||||
/ "http"
|
||||
/ "controller"
|
||||
/ "groups"
|
||||
/ "system.py"
|
||||
)
|
||||
|
||||
|
||||
def test_no_exec_call_in_system_controller():
|
||||
"""Verify there is no exec() call in system.py that takes user input."""
|
||||
with open(VULN_FILE, "r") as f:
|
||||
source = f.read()
|
||||
|
||||
tree = ast.parse(source)
|
||||
|
||||
exec_calls = []
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Call):
|
||||
func = node.func
|
||||
# Match bare exec() call
|
||||
if isinstance(func, ast.Name) and func.id == "exec":
|
||||
exec_calls.append(node.lineno)
|
||||
|
||||
assert len(exec_calls) == 0, (
|
||||
f"Found exec() call(s) at line(s) {exec_calls} in system.py. "
|
||||
"User-supplied code must never be passed to exec()."
|
||||
)
|
||||
|
||||
|
||||
def test_no_debug_exec_route():
|
||||
"""Verify the /debug/exec route is not registered."""
|
||||
with open(VULN_FILE, "r") as f:
|
||||
source = f.read()
|
||||
|
||||
assert "debug/exec" not in source, (
|
||||
"The /debug/exec route still exists in system.py. "
|
||||
"This endpoint allows arbitrary code execution and must be removed."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_no_exec_call_in_system_controller()
|
||||
test_no_debug_exec_route()
|
||||
print("All tests passed!")
|
||||
@@ -1,179 +0,0 @@
|
||||
# 单元测试覆盖率排除说明
|
||||
|
||||
## 排除范围
|
||||
|
||||
以下外部适配器模块不纳入测试覆盖目标,因为它们需要实际外部环境才能测试:
|
||||
|
||||
### 1. 消息平台适配器 (`platform/sources/`)
|
||||
- **路径**: `src/langbot/pkg/platform/sources/`
|
||||
- **模块**: aiocqhttp, dingtalk, discord, feishu, gestep, kook, lark, slack, telegram, wecom, wechatpv, wechatmp, qqbot
|
||||
- **排除原因**: 需要真实消息平台账号和 webhook 连接,无法纯单元测试
|
||||
- **测试方式**: 需要 mock 平台 API 或集成测试环境
|
||||
- **状态**: 后续可补充 mock 测试
|
||||
|
||||
### 2. LLM Requester (`provider/modelmgr/requesters/`)
|
||||
- **路径**: `src/langbot/pkg/provider/modelmgr/requesters/`
|
||||
- **模块**: deepseek, openai, anthropic, gemini, moonshot, ollama, zhipuai 等 20+ 个 requester
|
||||
- **排除原因**: 需要真实 LLM API 密钥和网络请求,涉及付费 API 调用
|
||||
- **测试方式**: 需要 mock HTTP 响应或使用 fake LLM server
|
||||
- **状态**: 后续可补充 mock HTTP 测试
|
||||
|
||||
### 3. Agent Runner (`provider/runners/`)
|
||||
- **路径**: `src/langbot/pkg/provider/runners/`
|
||||
- **模块**: cozeapi, difysvapi, n8nsvapi, langflowapi, dashscopeapi, localagent, tboxapi
|
||||
- **排除原因**: 需要真实 Agent 平台(Coze、Dify、n8n 等)的 API 连接
|
||||
- **测试方式**: 需要 mock Agent 平台响应
|
||||
- **状态**: 后续可补充 mock 测试
|
||||
|
||||
### 4. 向量数据库 (`vector/vdbs/`)
|
||||
- **路径**: `src/langbot/pkg/vector/vdbs/`
|
||||
- **模块**: chroma, milvus, pgvector, qdrant, seekdb
|
||||
- **排除原因**: 需要真实向量数据库实例运行
|
||||
- **测试方式**: 需要 Docker 启动测试数据库或 mock
|
||||
- **状态**: 后续可补充 mock 测试
|
||||
|
||||
---
|
||||
|
||||
## 覆盖率计算(排除外部适配器)
|
||||
|
||||
### 统计方法
|
||||
|
||||
```bash
|
||||
# 排除外部适配器后计算覆盖率
|
||||
pytest tests/unit_tests/ --cov=langbot.pkg \
|
||||
--cov-fail-under=0 \
|
||||
-o "cov_exclude_patterns=platform/sources/*,provider/modelmgr/requesters/*,provider/runners/*,vector/vdbs/*"
|
||||
```
|
||||
|
||||
### 当前覆盖率(排除后)
|
||||
|
||||
| 模块 | 覆盖率 | 状态 |
|
||||
|------|--------|------|
|
||||
| `command` | **99%** | ✅ 完成 |
|
||||
| `entity` | **99%** | ✅ 完成 |
|
||||
| `vector` | **76%** | ✅ 完成 |
|
||||
| `survey` | **84%** | ✅ 完成 |
|
||||
| `pipeline` | **72%** | ✅ 核心流程 |
|
||||
| `rag` | **66%** | ✅ 完成 |
|
||||
| `telemetry` | **87%** | ✅ 完成 |
|
||||
| `storage` | **80%** | ✅ 完成 |
|
||||
| `provider` | **83%** | ✅ 完成 |
|
||||
| `discover` | **61%** | ✅ 完成 |
|
||||
| `config` | **70%** | ✅ 完成 |
|
||||
| `utils` | **48%** | 🔄 部分完成 |
|
||||
| `api` | **34%** | 🔄 需补充 controller |
|
||||
| `platform` | **35%** | 🔄 需补充 adapter base |
|
||||
| `plugin` | **27%** | 🔄 需补充 handler |
|
||||
| `core` | **28%** | 🔄 需补充 app 启动 |
|
||||
| `persistence` | **24%** | 🔄 需补充 mgr |
|
||||
|
||||
---
|
||||
|
||||
## 后续计划
|
||||
|
||||
### 可补充的 Mock 测试(优先级排序)
|
||||
|
||||
1. **`provider/modelmgr/requesters/`** (优先级:中)
|
||||
- 使用 `httpx` mock 测试 API 响应解析
|
||||
- 测试重试逻辑、错误处理
|
||||
|
||||
2. **`provider/runners/`** (优先级:中)
|
||||
- Mock Agent 平台响应
|
||||
- 测试 session 管理、错误处理
|
||||
|
||||
3. **`platform/sources/`** (优先级:低)
|
||||
- Mock 平台 webhook 事件
|
||||
- 测试消息解析、事件处理
|
||||
|
||||
4. **`vector/vdbs/`** (优先级:低)
|
||||
- Mock 向量数据库操作
|
||||
- 测试 CRUD、查询逻辑
|
||||
|
||||
---
|
||||
|
||||
## 测试文件结构
|
||||
|
||||
```
|
||||
tests/unit_tests/
|
||||
├── api/
|
||||
│ └── service/
|
||||
│ ├── test_knowledge_service.py # 22 tests ✅
|
||||
│ └── ...
|
||||
├── core/
|
||||
│ ├── test_taskmgr.py # 21 tests ✅
|
||||
│ ├── test_load_config.py # 21 tests ✅ (含env override)
|
||||
│ └── ...
|
||||
├── plugin/
|
||||
│ ├── test_connector_static.py # 8 tests ✅
|
||||
│ ├── test_connector_pure.py # 7 tests ✅
|
||||
│ ├── test_connector_methods.py # 24 tests ✅
|
||||
│ ├── test_extract_deps.py # 7 tests ✅
|
||||
│ ├── test_handler_actions.py # 15 tests ✅ (新增)
|
||||
│ └── ...
|
||||
├── provider/
|
||||
│ ├── test_session_manager.py # 11 tests ✅ (新增)
|
||||
│ ├── test_tool_manager.py # 14 tests ✅ (新增)
|
||||
│ └── ...
|
||||
├── rag/
|
||||
│ ├── test_i18n_conversion.py # 8 tests ✅
|
||||
│ ├── test_kbmgr.py # 39 tests ✅
|
||||
│ ├── test_file_storage.py # 21 tests ✅ (新增)
|
||||
│ └── ...
|
||||
├── storage/
|
||||
│ ├── test_s3storage.py # 16 tests ✅ (新增)
|
||||
│ ├── test_localstorage_path_traversal.py # 11 tests ✅
|
||||
│ └── ...
|
||||
├── survey/
|
||||
│ └── test_survey_manager.py # 22 tests ✅
|
||||
├── telemetry/
|
||||
│ └── test_telemetry.py # 25 tests ✅ (重写)
|
||||
├── vector/
|
||||
│ ├── test_filter_utils.py # 21 tests ✅
|
||||
│ ├── test_vdb_filter_conversion.py # 30 tests ✅ (新增)
|
||||
│ └── ...
|
||||
├── utils/
|
||||
│ ├── test_platform.py # 7 tests ✅
|
||||
│ ├── test_funcschema.py # 9 tests ✅
|
||||
│ └── ...
|
||||
├── pipeline/
|
||||
│ ├── test_ratelimit.py # 12 tests ✅ (新增真实算法)
|
||||
│ ├── test_msgtrun.py # 9 tests ✅ (强化断言)
|
||||
│ └── ...
|
||||
└── persistence/
|
||||
├── test_serialize_model.py # 6 tests ✅
|
||||
├── test_database_decorator.py # 7 tests ✅
|
||||
└── ...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 总结
|
||||
|
||||
- **总测试数**: 1193 passed
|
||||
- **总体覆盖率**: 30%
|
||||
- **核心模块覆盖率**: **51.2%** (6549/12825 语句) - 排除外部适配器
|
||||
- **外部适配器覆盖率**: 5.6% (535/9483 语句) - 不纳入目标
|
||||
|
||||
### 核心模块覆盖率详情
|
||||
|
||||
| 模块 | 覆盖率 | 语句数 | 说明 |
|
||||
|------|--------|--------|------|
|
||||
| `command` | **99%** | 93 | ✅ 完成 |
|
||||
| `entity` | **99%** | 335 | ✅ 完成 |
|
||||
| `vector` | **76%** | 139 | ✅ 完成 (新增filter转换测试) |
|
||||
| `survey` | **84%** | 95 | ✅ 完成 |
|
||||
| `pipeline` | **72%** | 1761 | ✅ 核心流程 (新增算法测试) |
|
||||
| `rag` | **66%** | 347 | ✅ 完成 (新增ZIP处理测试) |
|
||||
| `telemetry` | **87%** | 70 | ✅ 完成 (重写假测试) |
|
||||
| `storage` | **80%** | 170 | ✅ 完成 (新增S3测试) |
|
||||
| `provider` | **83%** | 854 | ✅ 完成 (新增Session/Tool测试) |
|
||||
| `discover` | **61%** | 188 | ✅ 完成 |
|
||||
| `config` | **70%** | 198 | ✅ 完成 |
|
||||
| `utils` | **48%** | 478 | 🔄 部分完成 |
|
||||
| `api` | **34%** | 4061 | 🔄 需补充 controller |
|
||||
| `platform` | **35%** | 433 | 🔄 需补充 adapter base |
|
||||
| `plugin` | **27%** | 815 | 🔄 需补充 handler (新增action测试) |
|
||||
| `core` | **28%** | 1289 | 🔄 需补充 app 启动 |
|
||||
| `persistence` | **24%** | 1099 | 🔄 需补充 mgr |
|
||||
|
||||
外部适配器测试需要 mock 环境或集成测试,不属于纯单元测试范畴。
|
||||
@@ -1 +0,0 @@
|
||||
"""Unit tests for LangBot API HTTP service layer."""
|
||||
@@ -1,62 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from sqlalchemy.sql.dml import Update
|
||||
|
||||
from langbot.pkg.api.http.service.bot import BotService
|
||||
|
||||
|
||||
class _FakeResult:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def first(self):
|
||||
return self.value
|
||||
|
||||
|
||||
class _PersistenceManager:
|
||||
def __init__(self):
|
||||
self.update_values = None
|
||||
|
||||
async def execute_async(self, statement):
|
||||
if isinstance(statement, Update):
|
||||
self.update_values = {
|
||||
key: value for key, value in statement.compile().params.items() if not key.startswith('uuid_')
|
||||
}
|
||||
return None
|
||||
|
||||
return _FakeResult(SimpleNamespace(name='Updated Pipeline'))
|
||||
|
||||
|
||||
async def test_update_bot_copies_input_before_filtering_and_setting_pipeline_name():
|
||||
persistence_mgr = _PersistenceManager()
|
||||
runtime_bot = SimpleNamespace(enable=False)
|
||||
platform_mgr = SimpleNamespace(
|
||||
remove_bot=AsyncMock(),
|
||||
load_bot=AsyncMock(return_value=runtime_bot),
|
||||
)
|
||||
ap = SimpleNamespace(
|
||||
persistence_mgr=persistence_mgr,
|
||||
platform_mgr=platform_mgr,
|
||||
sess_mgr=SimpleNamespace(session_list=[]),
|
||||
)
|
||||
service = BotService(ap)
|
||||
service.get_bot = AsyncMock(return_value={'uuid': 'bot-1'})
|
||||
payload = {
|
||||
'uuid': 'caller-owned-uuid',
|
||||
'name': 'Test Bot',
|
||||
'use_pipeline_uuid': 'pipeline-1',
|
||||
}
|
||||
|
||||
await service.update_bot('bot-1', payload)
|
||||
|
||||
assert payload == {
|
||||
'uuid': 'caller-owned-uuid',
|
||||
'name': 'Test Bot',
|
||||
'use_pipeline_uuid': 'pipeline-1',
|
||||
}
|
||||
assert persistence_mgr.update_values == {
|
||||
'name': 'Test Bot',
|
||||
'use_pipeline_uuid': 'pipeline-1',
|
||||
'use_pipeline_name': 'Updated Pipeline',
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
"""Unit tests for API HTTP service layer.
|
||||
|
||||
Tests real service business logic with mocked dependencies:
|
||||
- persistence_mgr (database operations)
|
||||
- model_mgr (runtime model management)
|
||||
- platform_mgr (platform management)
|
||||
- plugin_connector (plugin runtime)
|
||||
- adjacent services (cross-service calls)
|
||||
|
||||
Does NOT:
|
||||
- Start real Quart server
|
||||
- Access real database
|
||||
- Call real provider/platform/network
|
||||
|
||||
Uses tests.factories.FakeApp as base mock application.
|
||||
"""
|
||||
@@ -1,429 +0,0 @@
|
||||
"""
|
||||
Unit tests for ApiKeyService.
|
||||
|
||||
Tests API key CRUD operations with mocked persistence layer.
|
||||
|
||||
Source: src/langbot/pkg/api/http/service/apikey.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from types import SimpleNamespace
|
||||
|
||||
from langbot.pkg.api.http.service.apikey import ApiKeyService
|
||||
from langbot.pkg.entity.persistence.apikey import ApiKey
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
class TestApiKeyServiceGetApiKeys:
|
||||
"""Tests for get_api_keys method."""
|
||||
|
||||
async def test_get_api_keys_empty_list(self):
|
||||
"""Returns empty list when no API keys exist."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = Mock()
|
||||
mock_result.all = Mock(return_value=[])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'id': entity.id,
|
||||
'name': entity.name,
|
||||
'key': entity.key,
|
||||
'description': entity.description,
|
||||
}
|
||||
if entity
|
||||
else {}
|
||||
)
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_api_keys()
|
||||
|
||||
# Verify
|
||||
assert result == []
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
async def test_get_api_keys_returns_serialized_list(self):
|
||||
"""Returns serialized list of API keys."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
# Create mock API key entities
|
||||
key1 = Mock(spec=ApiKey)
|
||||
key1.id = 1
|
||||
key1.name = 'Test Key 1'
|
||||
key1.key = 'lbk_test_key_1'
|
||||
key1.description = 'First test key'
|
||||
|
||||
key2 = Mock(spec=ApiKey)
|
||||
key2.id = 2
|
||||
key2.name = 'Test Key 2'
|
||||
key2.key = 'lbk_test_key_2'
|
||||
key2.description = 'Second test key'
|
||||
|
||||
mock_result = Mock()
|
||||
mock_result.all = Mock(return_value=[key1, key2])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'id': entity.id,
|
||||
'name': entity.name,
|
||||
'key': entity.key,
|
||||
'description': entity.description,
|
||||
}
|
||||
)
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_api_keys()
|
||||
|
||||
# Verify
|
||||
assert len(result) == 2
|
||||
assert result[0]['name'] == 'Test Key 1'
|
||||
assert result[1]['name'] == 'Test Key 2'
|
||||
|
||||
|
||||
class TestApiKeyServiceCreateApiKey:
|
||||
"""Tests for create_api_key method."""
|
||||
|
||||
async def test_create_api_key_generates_key_with_prefix(self):
|
||||
"""Creates API key with 'lbk_' prefix."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
created_key = Mock(spec=ApiKey)
|
||||
created_key.id = 1
|
||||
created_key.name = 'New Key'
|
||||
created_key.key = 'lbk_fixed-token'
|
||||
created_key.description = 'Test description'
|
||||
select_result = Mock()
|
||||
select_result.first = Mock(return_value=created_key)
|
||||
insert_params = []
|
||||
|
||||
async def mock_execute(query):
|
||||
params = query.compile().params
|
||||
if {'name', 'key', 'description'}.issubset(params):
|
||||
insert_params.append(params)
|
||||
return Mock()
|
||||
return select_result
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'id': 1,
|
||||
'name': entity.name,
|
||||
'key': entity.key,
|
||||
'description': entity.description,
|
||||
}
|
||||
)
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
with patch('langbot.pkg.api.http.service.apikey.secrets.token_urlsafe', return_value='fixed-token'):
|
||||
result = await service.create_api_key('New Key', 'Test description')
|
||||
|
||||
assert insert_params == [
|
||||
{'name': 'New Key', 'key': 'lbk_fixed-token', 'description': 'Test description'}
|
||||
]
|
||||
assert result['key'].startswith('lbk_')
|
||||
assert result['key'] == 'lbk_fixed-token'
|
||||
assert result['name'] == 'New Key'
|
||||
assert result['description'] == 'Test description'
|
||||
|
||||
async def test_create_api_key_without_description(self):
|
||||
"""Creates API key with empty description when not provided."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
created_key = Mock(spec=ApiKey)
|
||||
created_key.id = 1
|
||||
created_key.name = 'No Desc Key'
|
||||
created_key.key = 'lbk_no_desc_key'
|
||||
created_key.description = ''
|
||||
|
||||
select_result = Mock()
|
||||
select_result.first = Mock(return_value=created_key)
|
||||
insert_result = Mock()
|
||||
|
||||
async def mock_execute(query):
|
||||
if hasattr(query, 'values'):
|
||||
return insert_result
|
||||
return select_result
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'id': 1,
|
||||
'name': 'No Desc Key',
|
||||
'key': 'lbk_no_desc_key',
|
||||
'description': '',
|
||||
}
|
||||
)
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.create_api_key('No Desc Key')
|
||||
|
||||
# Verify
|
||||
assert result['description'] == ''
|
||||
|
||||
|
||||
class TestApiKeyServiceGetApiKey:
|
||||
"""Tests for get_api_key method."""
|
||||
|
||||
async def test_get_api_key_by_id_found(self):
|
||||
"""Returns API key when found by ID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
key = Mock(spec=ApiKey)
|
||||
key.id = 1
|
||||
key.name = 'Found Key'
|
||||
key.key = 'lbk_found_key'
|
||||
key.description = 'Found'
|
||||
|
||||
mock_result = Mock()
|
||||
mock_result.first = Mock(return_value=key)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'id': 1,
|
||||
'name': 'Found Key',
|
||||
'key': 'lbk_found_key',
|
||||
'description': 'Found',
|
||||
}
|
||||
)
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_api_key(1)
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result['id'] == 1
|
||||
assert result['name'] == 'Found Key'
|
||||
|
||||
async def test_get_api_key_by_id_not_found(self):
|
||||
"""Returns None when API key not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = Mock()
|
||||
mock_result.first = Mock(return_value=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_api_key(999)
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
async def test_get_api_key_by_id_zero(self):
|
||||
"""Handles ID=0 (edge case) correctly."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = Mock()
|
||||
mock_result.first = Mock(return_value=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_api_key(0)
|
||||
|
||||
# Verify - should return None (no key with ID 0)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestApiKeyServiceVerifyApiKey:
|
||||
"""Tests for verify_api_key method."""
|
||||
|
||||
async def test_verify_api_key_valid(self):
|
||||
"""Returns True for valid API key."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
key = Mock(spec=ApiKey)
|
||||
mock_result = Mock()
|
||||
mock_result.first = Mock(return_value=key)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.verify_api_key('lbk_valid_key')
|
||||
|
||||
# Verify
|
||||
assert result is True
|
||||
|
||||
async def test_verify_api_key_invalid(self):
|
||||
"""Returns False for invalid API key."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = Mock()
|
||||
mock_result.first = Mock(return_value=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.verify_api_key('lbk_invalid_key')
|
||||
|
||||
# Verify
|
||||
assert result is False
|
||||
|
||||
async def test_verify_api_key_empty_string(self):
|
||||
"""Returns False for empty key string."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = Mock()
|
||||
mock_result.first = Mock(return_value=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.verify_api_key('')
|
||||
|
||||
# Verify
|
||||
assert result is False
|
||||
|
||||
async def test_verify_api_key_unknown_key(self):
|
||||
"""Returns False when the key is not present in persistence."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = Mock()
|
||||
mock_result.first = Mock(return_value=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.verify_api_key('unknown_key')
|
||||
|
||||
# Verify
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestApiKeyServiceDeleteApiKey:
|
||||
"""Tests for delete_api_key method."""
|
||||
|
||||
async def test_delete_api_key_by_id(self):
|
||||
"""Deletes API key by ID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
await service.delete_api_key(1)
|
||||
|
||||
# Verify - execute_async was called (delete operation)
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
async def test_delete_api_key_nonexistent_id(self):
|
||||
"""Delete operation completes even for nonexistent ID (no error raised)."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute - should not raise error
|
||||
await service.delete_api_key(999)
|
||||
|
||||
# Verify - execute_async was called regardless
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
|
||||
class TestApiKeyServiceUpdateApiKey:
|
||||
"""Tests for update_api_key method."""
|
||||
|
||||
async def test_update_api_key_name_only(self):
|
||||
"""Updates only the name field."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
await service.update_api_key(1, name='Updated Name')
|
||||
|
||||
# Verify - execute_async was called with update
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
async def test_update_api_key_description_only(self):
|
||||
"""Updates only the description field."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
await service.update_api_key(1, description='Updated description')
|
||||
|
||||
# Verify
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
async def test_update_api_key_both_fields(self):
|
||||
"""Updates both name and description."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
await service.update_api_key(1, name='New Name', description='New description')
|
||||
|
||||
# Verify
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
async def test_update_api_key_no_fields(self):
|
||||
"""Does nothing when no fields provided."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
await service.update_api_key(1)
|
||||
|
||||
# Verify - no execute call since no update_data
|
||||
ap.persistence_mgr.execute_async.assert_not_called()
|
||||
@@ -1,662 +0,0 @@
|
||||
"""
|
||||
Unit tests for BotService.
|
||||
|
||||
Tests bot CRUD operations with mocked persistence and runtime managers.
|
||||
|
||||
Source: src/langbot/pkg/api/http/service/bot.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from types import SimpleNamespace
|
||||
import uuid
|
||||
|
||||
from langbot.pkg.api.http.service.bot import BotService
|
||||
from langbot.pkg.entity.persistence.bot import Bot
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
def _create_mock_bot(
|
||||
bot_uuid: str = None,
|
||||
name: str = 'Test Bot',
|
||||
description: str = 'Test Description',
|
||||
adapter: str = 'telegram',
|
||||
adapter_config: dict = None,
|
||||
enable: bool = True,
|
||||
use_pipeline_uuid: str = None,
|
||||
use_pipeline_name: str = None,
|
||||
) -> Mock:
|
||||
"""Helper to create mock Bot entity."""
|
||||
bot = Mock(spec=Bot)
|
||||
bot.uuid = bot_uuid or str(uuid.uuid4())
|
||||
bot.name = name
|
||||
bot.description = description
|
||||
bot.adapter = adapter
|
||||
bot.adapter_config = adapter_config or {'token': 'test_token'}
|
||||
bot.enable = enable
|
||||
bot.use_pipeline_uuid = use_pipeline_uuid
|
||||
bot.use_pipeline_name = use_pipeline_name
|
||||
bot.pipeline_routing_rules = []
|
||||
return bot
|
||||
|
||||
|
||||
def _create_mock_result(items: list = None, first_item=None):
|
||||
"""Create mock result object for persistence queries."""
|
||||
result = Mock()
|
||||
result.all = Mock(return_value=items or [])
|
||||
result.first = Mock(return_value=first_item)
|
||||
return result
|
||||
|
||||
|
||||
class TestBotServiceGetBots:
|
||||
"""Tests for get_bots method."""
|
||||
|
||||
async def test_get_bots_empty_list(self):
|
||||
"""Returns empty list when no bots exist."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity, masked_columns=None: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'adapter': entity.adapter,
|
||||
}
|
||||
)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_bots()
|
||||
|
||||
# Verify
|
||||
assert result == []
|
||||
|
||||
async def test_get_bots_returns_list_with_secrets(self):
|
||||
"""Returns bot list including adapter_config by default."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
bot1 = _create_mock_bot(bot_uuid='uuid-1', name='Bot 1')
|
||||
bot2 = _create_mock_bot(bot_uuid='uuid-2', name='Bot 2')
|
||||
|
||||
mock_result = _create_mock_result([bot1, bot2])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity, masked_columns=None: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'adapter': entity.adapter,
|
||||
'adapter_config': entity.adapter_config if 'adapter_config' not in (masked_columns or []) else None,
|
||||
}
|
||||
)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_bots(include_secret=True)
|
||||
|
||||
# Verify
|
||||
assert len(result) == 2
|
||||
assert result[0]['name'] == 'Bot 1'
|
||||
assert result[0]['adapter_config'] is not None
|
||||
|
||||
async def test_get_bots_masks_secrets(self):
|
||||
"""Returns bot list without adapter_config when include_secret=False."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
bot1 = _create_mock_bot(bot_uuid='uuid-1', name='Bot 1')
|
||||
|
||||
mock_result = _create_mock_result([bot1])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity, masked_columns=None: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'adapter': entity.adapter,
|
||||
'adapter_config': entity.adapter_config if 'adapter_config' not in (masked_columns or []) else None,
|
||||
}
|
||||
)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_bots(include_secret=False)
|
||||
|
||||
# Verify - adapter_config should be masked
|
||||
assert result[0]['adapter_config'] is None
|
||||
|
||||
|
||||
class TestBotServiceGetBot:
|
||||
"""Tests for get_bot method."""
|
||||
|
||||
async def test_get_bot_by_uuid_found(self):
|
||||
"""Returns bot when found by UUID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
bot = _create_mock_bot(bot_uuid='test-uuid', name='Found Bot')
|
||||
mock_result = _create_mock_result(first_item=bot)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'test-uuid',
|
||||
'name': 'Found Bot',
|
||||
'adapter': 'telegram',
|
||||
}
|
||||
)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_bot('test-uuid')
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result['uuid'] == 'test-uuid'
|
||||
assert result['name'] == 'Found Bot'
|
||||
|
||||
async def test_get_bot_by_uuid_not_found(self):
|
||||
"""Returns None when bot not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result(first_item=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_bot('nonexistent-uuid')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestBotServiceGetRuntimeBotInfo:
|
||||
"""Tests for get_runtime_bot_info method."""
|
||||
|
||||
async def test_get_runtime_bot_info_bot_not_found_raises(self):
|
||||
"""Raises Exception when bot not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result(first_item=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Mock get_bot to return None
|
||||
service.get_bot = AsyncMock(return_value=None)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(Exception, match='Bot not found'):
|
||||
await service.get_runtime_bot_info('nonexistent-uuid')
|
||||
|
||||
async def test_get_runtime_bot_info_returns_webhook_for_wecom(self):
|
||||
"""Returns webhook URL for wecom adapter."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'api': {
|
||||
'webhook_prefix': 'http://127.0.0.1:5300',
|
||||
'extra_webhook_prefix': 'http://extra.example.com',
|
||||
}
|
||||
}
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=None)
|
||||
|
||||
bot_data = {
|
||||
'uuid': 'wecom-uuid',
|
||||
'name': 'WeCom Bot',
|
||||
'adapter': 'wecom',
|
||||
'adapter_config': {'token': 'test'},
|
||||
}
|
||||
|
||||
service = BotService(ap)
|
||||
service.get_bot = AsyncMock(return_value=bot_data)
|
||||
|
||||
# Execute
|
||||
result = await service.get_runtime_bot_info('wecom-uuid')
|
||||
|
||||
# Verify
|
||||
assert result['adapter_runtime_values']['webhook_url'] == '/bots/wecom-uuid'
|
||||
assert result['adapter_runtime_values']['webhook_full_url'] == 'http://127.0.0.1:5300/bots/wecom-uuid'
|
||||
|
||||
async def test_get_runtime_bot_info_no_webhook_for_telegram(self):
|
||||
"""Returns no webhook URL for non-webhook adapters like telegram."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'api': {}}
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=None)
|
||||
|
||||
bot_data = {
|
||||
'uuid': 'telegram-uuid',
|
||||
'name': 'Telegram Bot',
|
||||
'adapter': 'telegram',
|
||||
'adapter_config': {'token': 'test'},
|
||||
}
|
||||
|
||||
service = BotService(ap)
|
||||
service.get_bot = AsyncMock(return_value=bot_data)
|
||||
|
||||
# Execute
|
||||
result = await service.get_runtime_bot_info('telegram-uuid')
|
||||
|
||||
# Verify - no webhook for telegram
|
||||
assert result['adapter_runtime_values']['webhook_url'] is None
|
||||
assert result['adapter_runtime_values']['webhook_full_url'] is None
|
||||
|
||||
async def test_get_runtime_bot_info_with_runtime_bot(self):
|
||||
"""Returns bot_account_id when runtime bot exists."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'api': {}}
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
|
||||
# Mock runtime bot with adapter
|
||||
runtime_bot = SimpleNamespace()
|
||||
runtime_bot.adapter = SimpleNamespace()
|
||||
runtime_bot.adapter.bot_account_id = 'runtime-account-123'
|
||||
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=runtime_bot)
|
||||
|
||||
bot_data = {
|
||||
'uuid': 'runtime-uuid',
|
||||
'name': 'Runtime Bot',
|
||||
'adapter': 'telegram',
|
||||
'adapter_config': {},
|
||||
}
|
||||
|
||||
service = BotService(ap)
|
||||
service.get_bot = AsyncMock(return_value=bot_data)
|
||||
|
||||
# Execute
|
||||
result = await service.get_runtime_bot_info('runtime-uuid')
|
||||
|
||||
# Verify
|
||||
assert result['adapter_runtime_values']['bot_account_id'] == 'runtime-account-123'
|
||||
|
||||
|
||||
class TestBotServiceCreateBot:
|
||||
"""Tests for create_bot method."""
|
||||
|
||||
async def test_create_bot_max_limit_reached_raises(self):
|
||||
"""Raises ValueError when max_bots limit reached."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'system': {
|
||||
'limitation': {
|
||||
'max_bots': 2
|
||||
}
|
||||
}
|
||||
}
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
ap.platform_mgr.load_bot = AsyncMock()
|
||||
|
||||
# Mock get_bots to return 2 bots already
|
||||
bot1 = _create_mock_bot(bot_uuid='uuid-1')
|
||||
bot2 = _create_mock_bot(bot_uuid='uuid-2')
|
||||
mock_result = _create_mock_result([bot1, bot2])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'uuid-1', 'name': 'Bot 1'}
|
||||
)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='Maximum number of bots'):
|
||||
await service.create_bot({'name': 'New Bot'})
|
||||
|
||||
async def test_create_bot_no_limit(self):
|
||||
"""Creates bot without limit check when max_bots=-1."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'system': {
|
||||
'limitation': {
|
||||
'max_bots': -1 # No limit
|
||||
}
|
||||
}
|
||||
}
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
ap.platform_mgr.load_bot = AsyncMock()
|
||||
|
||||
# Mock pipeline query
|
||||
pipeline_result = Mock()
|
||||
pipeline_result.first = Mock(return_value=None)
|
||||
# Mock bot query after insert
|
||||
bot_result = Mock()
|
||||
bot_result.first = Mock(return_value=_create_mock_bot())
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count <= 2:
|
||||
return pipeline_result # First call: check pipeline
|
||||
elif call_count == 3:
|
||||
return Mock() # Insert
|
||||
return bot_result # Get bot
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'new-uuid', 'name': 'New Bot'}
|
||||
)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute
|
||||
bot_uuid = await service.create_bot({'name': 'New Bot', 'adapter': 'telegram', 'adapter_config': {}})
|
||||
|
||||
# Verify
|
||||
assert bot_uuid is not None
|
||||
assert len(bot_uuid) == 36 # UUID format
|
||||
|
||||
async def test_create_bot_sets_default_pipeline(self):
|
||||
"""Sets default pipeline when one exists."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'system': {'limitation': {'max_bots': -1}}}
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
ap.platform_mgr.load_bot = AsyncMock()
|
||||
|
||||
# Mock default pipeline
|
||||
mock_pipeline = SimpleNamespace()
|
||||
mock_pipeline.uuid = 'default-pipeline-uuid'
|
||||
mock_pipeline.name = 'Default Pipeline'
|
||||
pipeline_result = Mock()
|
||||
pipeline_result.first = Mock(return_value=mock_pipeline)
|
||||
|
||||
# Mock bot after insert
|
||||
bot_result = Mock()
|
||||
bot_result.first = Mock(return_value=_create_mock_bot())
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return pipeline_result # Check default pipeline
|
||||
elif call_count == 2:
|
||||
return Mock() # Insert
|
||||
return bot_result # Get bot
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'new-uuid',
|
||||
'name': 'New Bot',
|
||||
'use_pipeline_uuid': 'default-pipeline-uuid',
|
||||
'use_pipeline_name': 'Default Pipeline',
|
||||
}
|
||||
)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute
|
||||
bot_data = {'name': 'New Bot', 'adapter': 'telegram', 'adapter_config': {}}
|
||||
bot_uuid = await service.create_bot(bot_data)
|
||||
|
||||
# Verify - pipeline uuid and name were set
|
||||
assert 'use_pipeline_uuid' in bot_data
|
||||
assert 'use_pipeline_name' in bot_data
|
||||
assert bot_uuid is not None # Verify UUID was returned
|
||||
|
||||
|
||||
class TestBotServiceUpdateBot:
|
||||
"""Tests for update_bot method."""
|
||||
|
||||
async def test_update_bot_removes_uuid_from_data(self):
|
||||
"""Does not persist caller-provided uuid in update payload."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
ap.platform_mgr.remove_bot = AsyncMock()
|
||||
|
||||
# Mock pipeline query - not updating pipeline
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
ap.sess_mgr = SimpleNamespace()
|
||||
ap.sess_mgr.session_list = []
|
||||
|
||||
service = BotService(ap)
|
||||
service.get_bot = AsyncMock(return_value={'uuid': 'test-uuid', 'name': 'Updated'})
|
||||
|
||||
# Create mock runtime bot
|
||||
runtime_bot = SimpleNamespace()
|
||||
runtime_bot.enable = False
|
||||
ap.platform_mgr.load_bot = AsyncMock(return_value=runtime_bot)
|
||||
|
||||
# Execute
|
||||
update_data = {'uuid': 'should-be-removed', 'name': 'Updated Name'}
|
||||
await service.update_bot('test-uuid', update_data)
|
||||
|
||||
update_params = ap.persistence_mgr.execute_async.await_args_list[0].args[0].compile().params
|
||||
assert update_params['name'] == 'Updated Name'
|
||||
assert 'should-be-removed' not in update_params.values()
|
||||
|
||||
async def test_update_bot_pipeline_not_found_raises(self):
|
||||
"""Raises Exception when updating with nonexistent pipeline UUID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
# Mock pipeline query returns None
|
||||
pipeline_result = Mock()
|
||||
pipeline_result.first = Mock(return_value=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=pipeline_result)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(Exception, match='Pipeline not found'):
|
||||
await service.update_bot('test-uuid', {'use_pipeline_uuid': 'nonexistent-pipeline'})
|
||||
|
||||
async def test_update_bot_sets_pipeline_name(self):
|
||||
"""Sets use_pipeline_name when updating use_pipeline_uuid."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
ap.platform_mgr.remove_bot = AsyncMock()
|
||||
|
||||
# Mock pipeline query
|
||||
mock_pipeline = SimpleNamespace()
|
||||
mock_pipeline.name = 'Updated Pipeline'
|
||||
pipeline_result = Mock()
|
||||
pipeline_result.first = Mock(return_value=mock_pipeline)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return pipeline_result
|
||||
return Mock()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.sess_mgr = SimpleNamespace()
|
||||
ap.sess_mgr.session_list = []
|
||||
|
||||
service = BotService(ap)
|
||||
service.get_bot = AsyncMock(return_value={'uuid': 'test-uuid'})
|
||||
|
||||
runtime_bot = SimpleNamespace()
|
||||
runtime_bot.enable = False
|
||||
ap.platform_mgr.load_bot = AsyncMock(return_value=runtime_bot)
|
||||
|
||||
# Execute
|
||||
await service.update_bot('test-uuid', {'use_pipeline_uuid': 'pipeline-uuid'})
|
||||
|
||||
update_params = ap.persistence_mgr.execute_async.await_args_list[1].args[0].compile().params
|
||||
assert update_params['use_pipeline_uuid'] == 'pipeline-uuid'
|
||||
assert update_params['use_pipeline_name'] == 'Updated Pipeline'
|
||||
|
||||
|
||||
class TestBotServiceDeleteBot:
|
||||
"""Tests for delete_bot method."""
|
||||
|
||||
async def test_delete_bot_calls_remove_and_delete(self):
|
||||
"""Calls both platform_mgr.remove_bot and persistence delete."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
ap.platform_mgr.remove_bot = AsyncMock()
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute
|
||||
await service.delete_bot('test-uuid')
|
||||
|
||||
# Verify
|
||||
ap.platform_mgr.remove_bot.assert_called_once_with('test-uuid')
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
async def test_delete_bot_nonexistent_uuid(self):
|
||||
"""Delete operation completes even for nonexistent UUID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
ap.platform_mgr.remove_bot = AsyncMock()
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute - should not raise
|
||||
await service.delete_bot('nonexistent-uuid')
|
||||
|
||||
# Verify - both called regardless
|
||||
ap.platform_mgr.remove_bot.assert_called_once()
|
||||
|
||||
|
||||
class TestBotServiceListEventLogs:
|
||||
"""Tests for list_event_logs method."""
|
||||
|
||||
async def test_list_event_logs_bot_not_found_raises(self):
|
||||
"""Raises Exception when runtime bot not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=None)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(Exception, match='Bot not found'):
|
||||
await service.list_event_logs('nonexistent-uuid', 0, 10)
|
||||
|
||||
async def test_list_event_logs_returns_logs(self):
|
||||
"""Returns logs from runtime bot logger."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
|
||||
# Mock runtime bot with logger
|
||||
runtime_bot = SimpleNamespace()
|
||||
runtime_bot.logger = SimpleNamespace()
|
||||
runtime_bot.logger.get_logs = AsyncMock(return_value=(
|
||||
[SimpleNamespace(to_json=Mock(return_value={'msg': 'log1'}))],
|
||||
5
|
||||
))
|
||||
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=runtime_bot)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute
|
||||
logs, total = await service.list_event_logs('bot-uuid', 0, 10)
|
||||
|
||||
# Verify
|
||||
assert len(logs) == 1
|
||||
assert logs[0] == {'msg': 'log1'}
|
||||
assert total == 5
|
||||
|
||||
|
||||
class TestBotServiceSendMessage:
|
||||
"""Tests for send_message method."""
|
||||
|
||||
async def test_send_message_bot_not_found_raises(self):
|
||||
"""Raises Exception when bot not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=None)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(Exception, match='Bot not found'):
|
||||
await service.send_message('nonexistent-uuid', 'group', '123', {'test': 'data'})
|
||||
|
||||
async def test_send_message_invalid_message_chain_raises(self):
|
||||
"""Raises Exception when message_chain_data is invalid."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
|
||||
runtime_bot = SimpleNamespace()
|
||||
runtime_bot.adapter = SimpleNamespace()
|
||||
runtime_bot.adapter.send_message = AsyncMock()
|
||||
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=runtime_bot)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute & Verify - invalid format should raise
|
||||
with pytest.raises(Exception, match='Invalid message_chain format'):
|
||||
await service.send_message('bot-uuid', 'group', '123', {'invalid': 'format'})
|
||||
|
||||
async def test_send_message_valid_call(self):
|
||||
"""Sends message through adapter when all valid."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
|
||||
runtime_bot = SimpleNamespace()
|
||||
runtime_bot.adapter = SimpleNamespace()
|
||||
runtime_bot.adapter.send_message = AsyncMock()
|
||||
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=runtime_bot)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute with valid message chain format
|
||||
message_chain_data = {
|
||||
'messages': [
|
||||
{'type': 'text', 'data': {'text': 'Hello'}}
|
||||
]
|
||||
}
|
||||
|
||||
# Patch the import location - the module imports inside the function
|
||||
with patch('langbot_plugin.api.entities.builtin.platform.message.MessageChain') as MockMessageChain:
|
||||
mock_chain = Mock()
|
||||
MockMessageChain.model_validate = Mock(return_value=mock_chain)
|
||||
await service.send_message('bot-uuid', 'group', '123', message_chain_data)
|
||||
|
||||
# Verify adapter.send_message was called
|
||||
runtime_bot.adapter.send_message.assert_called_once_with('group', '123', mock_chain)
|
||||
@@ -1,397 +0,0 @@
|
||||
"""Unit tests for API knowledge service.
|
||||
|
||||
Tests cover:
|
||||
- Knowledge base CRUD operations
|
||||
- Capability checking
|
||||
- Knowledge engine discovery
|
||||
- File operations
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock
|
||||
from importlib import import_module
|
||||
|
||||
|
||||
def get_knowledge_service_module():
|
||||
"""Lazy import to avoid circular import issues."""
|
||||
return import_module('langbot.pkg.api.http.service.knowledge')
|
||||
|
||||
|
||||
def create_mock_app():
|
||||
"""Create mock Application for testing."""
|
||||
mock_app = Mock()
|
||||
mock_app.logger = Mock()
|
||||
mock_app.rag_mgr = AsyncMock()
|
||||
mock_app.persistence_mgr = AsyncMock()
|
||||
mock_app.persistence_mgr.execute_async = AsyncMock()
|
||||
mock_app.persistence_mgr.serialize_model = Mock(return_value={})
|
||||
mock_app.plugin_connector = AsyncMock()
|
||||
mock_app.plugin_connector.is_enable_plugin = True
|
||||
return mock_app
|
||||
|
||||
|
||||
class TestKnowledgeServiceInit:
|
||||
"""Tests for KnowledgeService initialization."""
|
||||
|
||||
def test_init_stores_app_reference(self):
|
||||
"""Test that __init__ stores Application reference."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
|
||||
assert service.ap is mock_app
|
||||
|
||||
|
||||
class TestGetKnowledgeBases:
|
||||
"""Tests for get_knowledge_bases method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_all_kb_details(self):
|
||||
"""Test that it returns all knowledge base details."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.rag_mgr.get_all_knowledge_base_details = AsyncMock(
|
||||
return_value=[{'uuid': 'kb1', 'name': 'KB1'}]
|
||||
)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.get_knowledge_bases()
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]['uuid'] == 'kb1'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_empty_list_when_no_kbs(self):
|
||||
"""Test that it returns empty list when no knowledge bases."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.rag_mgr.get_all_knowledge_base_details = AsyncMock(return_value=[])
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.get_knowledge_bases()
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestGetKnowledgeBase:
|
||||
"""Tests for get_knowledge_base method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_kb_details_by_uuid(self):
|
||||
"""Test that it returns specific KB details."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(
|
||||
return_value={'uuid': 'kb1', 'name': 'KB1'}
|
||||
)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.get_knowledge_base('kb1')
|
||||
|
||||
assert result['uuid'] == 'kb1'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_not_found(self):
|
||||
"""Test that it returns None when KB not found."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(return_value=None)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.get_knowledge_base('nonexistent')
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestCreateKnowledgeBase:
|
||||
"""Tests for create_knowledge_base method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_creates_kb_with_required_fields(self):
|
||||
"""Test creating KB with required plugin ID."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_kb = Mock()
|
||||
mock_kb.uuid = 'new_kb_uuid'
|
||||
mock_app.rag_mgr.create_knowledge_base = AsyncMock(return_value=mock_kb)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
kb_data = {
|
||||
'name': 'Test KB',
|
||||
'knowledge_engine_plugin_id': 'author/engine',
|
||||
'description': 'Test description',
|
||||
}
|
||||
|
||||
result = await service.create_knowledge_base(kb_data)
|
||||
|
||||
assert result == 'new_kb_uuid'
|
||||
mock_app.rag_mgr.create_knowledge_base.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_when_missing_plugin_id(self):
|
||||
"""Test that ValueError is raised when plugin ID missing."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await service.create_knowledge_base({'name': 'Test'})
|
||||
|
||||
assert 'knowledge_engine_plugin_id is required' in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_creates_with_default_name(self):
|
||||
"""Test that KB is created with default name if not provided."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_kb = Mock()
|
||||
mock_kb.uuid = 'new_kb_uuid'
|
||||
mock_app.rag_mgr.create_knowledge_base = AsyncMock(return_value=mock_kb)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
|
||||
await service.create_knowledge_base({
|
||||
'knowledge_engine_plugin_id': 'author/engine'
|
||||
})
|
||||
|
||||
# Check that default name 'Untitled' was used
|
||||
call_args = mock_app.rag_mgr.create_knowledge_base.call_args
|
||||
assert call_args.kwargs['name'] == 'Untitled'
|
||||
|
||||
|
||||
class TestUpdateKnowledgeBase:
|
||||
"""Tests for update_knowledge_base method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_updates_mutable_fields_only(self):
|
||||
"""Test that only mutable fields are updated."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(
|
||||
return_value={'uuid': 'kb1', 'name': 'Updated'}
|
||||
)
|
||||
mock_app.rag_mgr.remove_knowledge_base_from_runtime = AsyncMock()
|
||||
mock_app.rag_mgr.load_knowledge_base = AsyncMock()
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
|
||||
# Pass both mutable and immutable fields
|
||||
await service.update_knowledge_base('kb1', {
|
||||
'name': 'New Name',
|
||||
'description': 'New desc',
|
||||
'uuid': 'should_be_filtered', # immutable
|
||||
})
|
||||
|
||||
# Check that only mutable fields were passed to update
|
||||
call_args = mock_app.persistence_mgr.execute_async.call_args
|
||||
assert call_args is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_early_when_no_mutable_fields(self):
|
||||
"""Test that update returns early when no mutable fields provided."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
|
||||
# Pass only immutable fields
|
||||
await service.update_knowledge_base('kb1', {'uuid': 'should_be_filtered'})
|
||||
|
||||
# No DB update should be called
|
||||
mock_app.persistence_mgr.execute_async.assert_not_called()
|
||||
|
||||
|
||||
class TestCheckDocCapability:
|
||||
"""Tests for _check_doc_capability method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_passes_when_capability_supported(self):
|
||||
"""Test that check passes when doc_ingestion capability exists."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(
|
||||
return_value={'knowledge_engine': {'capabilities': ['doc_ingestion']}}
|
||||
)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
|
||||
await service._check_doc_capability('kb1', 'document upload')
|
||||
|
||||
# No exception raised means success
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_when_kb_not_found(self):
|
||||
"""Test that Exception is raised when KB not found."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(return_value=None)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await service._check_doc_capability('nonexistent', 'test operation')
|
||||
|
||||
assert 'Knowledge base not found' in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_when_capability_not_supported(self):
|
||||
"""Test that Exception is raised when doc_ingestion not in capabilities."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(
|
||||
return_value={'knowledge_engine': {'capabilities': ['other_capability']}}
|
||||
)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await service._check_doc_capability('kb1', 'document upload')
|
||||
|
||||
assert 'does not support document upload' in str(exc_info.value)
|
||||
|
||||
|
||||
class TestListKnowledgeEngines:
|
||||
"""Tests for list_knowledge_engines method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_engines_from_plugin_connector(self):
|
||||
"""Test that it returns knowledge engines from plugin connector."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(
|
||||
return_value=[{'id': 'engine1', 'name': 'Engine 1'}]
|
||||
)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.list_knowledge_engines()
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]['id'] == 'engine1'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_empty_when_plugin_disabled(self):
|
||||
"""Test that it returns empty list when plugin disabled."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.plugin_connector.is_enable_plugin = False
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.list_knowledge_engines()
|
||||
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_empty_on_exception(self):
|
||||
"""Test that it returns empty list and logs warning on exception."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(
|
||||
side_effect=Exception('Connection error')
|
||||
)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.list_knowledge_engines()
|
||||
|
||||
assert result == []
|
||||
mock_app.logger.warning.assert_called_once()
|
||||
|
||||
|
||||
class TestListParsers:
|
||||
"""Tests for list_parsers method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_all_parsers(self):
|
||||
"""Test that it returns all parsers when no MIME type filter."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.plugin_connector.list_parsers = AsyncMock(
|
||||
return_value=[
|
||||
{'id': 'parser1', 'supported_mime_types': ['text/plain']},
|
||||
{'id': 'parser2', 'supported_mime_types': ['application/pdf']},
|
||||
]
|
||||
)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.list_parsers()
|
||||
|
||||
assert len(result) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_filters_by_mime_type(self):
|
||||
"""Test that it filters parsers by MIME type."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.plugin_connector.list_parsers = AsyncMock(
|
||||
return_value=[
|
||||
{'id': 'parser1', 'supported_mime_types': ['text/plain']},
|
||||
{'id': 'parser2', 'supported_mime_types': ['application/pdf']},
|
||||
]
|
||||
)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.list_parsers(mime_type='application/pdf')
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]['id'] == 'parser2'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_empty_when_plugin_disabled(self):
|
||||
"""Test that it returns empty list when plugin disabled."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.plugin_connector.is_enable_plugin = False
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.list_parsers()
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestGetEngineSchemas:
|
||||
"""Tests for get_engine_creation_schema and get_engine_retrieval_schema."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_creation_schema(self):
|
||||
"""Test that it returns creation schema for engine."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.plugin_connector.get_rag_creation_schema = AsyncMock(
|
||||
return_value={'properties': {'name': {'type': 'string'}}}
|
||||
)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.get_engine_creation_schema('author/engine')
|
||||
|
||||
assert 'properties' in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_retrieval_schema(self):
|
||||
"""Test that it returns retrieval schema for engine."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.plugin_connector.get_rag_retrieval_schema = AsyncMock(
|
||||
return_value={'properties': {'top_k': {'type': 'integer'}}}
|
||||
)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.get_engine_retrieval_schema('author/engine')
|
||||
|
||||
assert 'properties' in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_empty_dict_on_exception(self):
|
||||
"""Test that it returns empty dict and logs warning on exception."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.plugin_connector.get_rag_creation_schema = AsyncMock(
|
||||
side_effect=Exception('Plugin error')
|
||||
)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.get_engine_creation_schema('author/engine')
|
||||
|
||||
assert result == {}
|
||||
mock_app.logger.warning.assert_called_once()
|
||||
@@ -1,824 +0,0 @@
|
||||
"""
|
||||
Unit tests for MaintenanceService.
|
||||
|
||||
Tests storage maintenance and diagnostics including:
|
||||
- Cleanup expired files
|
||||
- Storage analysis
|
||||
- File counting and sizing
|
||||
- Monitoring counts
|
||||
- Binary storage stats
|
||||
|
||||
Source: src/langbot/pkg/api/http/service/maintenance.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock, patch, MagicMock
|
||||
from types import SimpleNamespace
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from langbot.pkg.api.http.service.maintenance import MaintenanceService
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
def _create_mock_result(scalar_value=None):
|
||||
"""Create mock result object for persistence queries."""
|
||||
result = Mock()
|
||||
result.scalar = Mock(return_value=scalar_value)
|
||||
return result
|
||||
|
||||
|
||||
class TestMaintenanceServiceCleanupExpiredFiles:
|
||||
"""Tests for cleanup_expired_files method."""
|
||||
|
||||
async def test_cleanup_expired_files_default_retention(self):
|
||||
"""Uses default retention days when config not set."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
ap.storage_mgr = SimpleNamespace()
|
||||
|
||||
# Create a proper mock object with __class__.__name__
|
||||
storage_provider = MagicMock()
|
||||
storage_provider.__class__.__name__ = 'LocalStorageProvider'
|
||||
ap.storage_mgr.storage_provider = storage_provider
|
||||
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Mock the internal cleanup methods - one is async, one is not
|
||||
service._cleanup_expired_uploaded_files = AsyncMock(return_value=0)
|
||||
service._cleanup_expired_log_files = Mock(return_value=0) # NOT async!
|
||||
|
||||
# Execute
|
||||
result = await service.cleanup_expired_files()
|
||||
|
||||
# Verify - returns counts
|
||||
assert 'uploaded_files' in result
|
||||
assert 'log_files' in result
|
||||
assert result['uploaded_files'] == 0
|
||||
assert result['log_files'] == 0
|
||||
|
||||
async def test_cleanup_expired_files_custom_retention(self):
|
||||
"""Uses custom retention days from config."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'storage': {
|
||||
'cleanup': {
|
||||
'uploaded_file_retention_days': 14,
|
||||
'log_retention_days': 7,
|
||||
}
|
||||
}
|
||||
}
|
||||
ap.storage_mgr = SimpleNamespace()
|
||||
|
||||
storage_provider = MagicMock()
|
||||
storage_provider.__class__.__name__ = 'LocalStorageProvider'
|
||||
ap.storage_mgr.storage_provider = storage_provider
|
||||
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Mock the internal cleanup methods
|
||||
service._cleanup_expired_uploaded_files = AsyncMock(return_value=2)
|
||||
service._cleanup_expired_log_files = Mock(return_value=3) # NOT async
|
||||
|
||||
# Execute
|
||||
result = await service.cleanup_expired_files()
|
||||
|
||||
# Verify
|
||||
assert result['uploaded_files'] == 2
|
||||
assert result['log_files'] == 3
|
||||
|
||||
async def test_cleanup_expired_files_s3_provider(self):
|
||||
"""Handles S3StorageProvider correctly."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
ap.storage_mgr = SimpleNamespace()
|
||||
|
||||
# Mock S3 provider
|
||||
s3_provider = MagicMock()
|
||||
s3_provider.__class__.__name__ = 'S3StorageProvider'
|
||||
s3_provider.delete = AsyncMock()
|
||||
ap.storage_mgr.storage_provider = s3_provider
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Mock the internal cleanup methods
|
||||
service._cleanup_expired_uploaded_files = AsyncMock(return_value=1)
|
||||
service._cleanup_expired_log_files = Mock(return_value=0) # NOT async
|
||||
|
||||
# Execute
|
||||
result = await service.cleanup_expired_files()
|
||||
|
||||
# Verify
|
||||
assert result['uploaded_files'] == 1
|
||||
assert result['log_files'] == 0
|
||||
|
||||
async def test_cleanup_expired_files_invalid_retention(self):
|
||||
"""Uses default for invalid retention config."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'storage': {
|
||||
'cleanup': {
|
||||
'uploaded_file_retention_days': 'invalid', # Invalid
|
||||
'log_retention_days': 0, # Invalid (less than 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
ap.storage_mgr = SimpleNamespace()
|
||||
|
||||
storage_provider = MagicMock()
|
||||
storage_provider.__class__.__name__ = 'LocalStorageProvider'
|
||||
ap.storage_mgr.storage_provider = storage_provider
|
||||
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Mock the internal cleanup methods
|
||||
service._cleanup_expired_uploaded_files = AsyncMock(return_value=0)
|
||||
service._cleanup_expired_log_files = Mock(return_value=0) # NOT async
|
||||
|
||||
# Execute
|
||||
result = await service.cleanup_expired_files()
|
||||
|
||||
# Verify - warning logged, defaults used
|
||||
assert ap.logger.warning.called
|
||||
assert 'uploaded_files' in result
|
||||
|
||||
|
||||
class TestMaintenanceServiceGetStorageAnalysis:
|
||||
"""Tests for get_storage_analysis method."""
|
||||
|
||||
async def test_get_storage_analysis_basic(self):
|
||||
"""Returns basic storage analysis."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'database': {'use': 'sqlite', 'sqlite': {'path': 'data/langbot.db'}}
|
||||
}
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
ap.task_mgr = SimpleNamespace()
|
||||
ap.task_mgr.get_stats = Mock(return_value={'running': 0})
|
||||
|
||||
# Mock monitoring counts
|
||||
count_result = _create_mock_result(scalar_value=10)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=count_result)
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Mock file operations
|
||||
service._path_size = Mock(return_value=1000)
|
||||
service._file_count = Mock(return_value=5)
|
||||
service._monitoring_counts = AsyncMock(return_value={'messages': 10, 'errors': 0})
|
||||
service._binary_storage_stats = AsyncMock(return_value={'count': 5, 'size_bytes': 500})
|
||||
service._expired_uploaded_candidates = AsyncMock(return_value=[])
|
||||
service._expired_log_candidates = Mock(return_value=[])
|
||||
|
||||
# Execute
|
||||
result = await service.get_storage_analysis()
|
||||
|
||||
# Verify
|
||||
assert 'generated_at' in result
|
||||
assert 'cleanup_policy' in result
|
||||
assert 'sections' in result
|
||||
assert 'database' in result
|
||||
assert 'cleanup_candidates' in result
|
||||
|
||||
async def test_get_storage_analysis_sections(self):
|
||||
"""Returns all storage sections."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'database': {'use': 'postgresql'}}
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
ap.task_mgr = None
|
||||
|
||||
count_result = _create_mock_result(scalar_value=0)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=count_result)
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
service._path_size = Mock(return_value=0)
|
||||
service._file_count = Mock(return_value=0)
|
||||
service._monitoring_counts = AsyncMock(return_value={})
|
||||
service._binary_storage_stats = AsyncMock(return_value={'count': 0, 'size_bytes': 0})
|
||||
service._expired_uploaded_candidates = AsyncMock(return_value=[])
|
||||
service._expired_log_candidates = Mock(return_value=[])
|
||||
|
||||
# Execute
|
||||
result = await service.get_storage_analysis()
|
||||
|
||||
# Verify - all sections present
|
||||
sections = {s['key'] for s in result['sections']}
|
||||
assert 'database' in sections
|
||||
assert 'logs' in sections
|
||||
assert 'storage' in sections
|
||||
assert 'vector_store' in sections
|
||||
assert 'plugins' in sections
|
||||
assert 'mcp' in sections
|
||||
assert 'temp' in sections
|
||||
|
||||
async def test_get_storage_analysis_postgresql(self):
|
||||
"""Handles PostgreSQL database type."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'database': {'use': 'postgresql'}}
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
ap.task_mgr = None
|
||||
|
||||
count_result = _create_mock_result(scalar_value=0)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=count_result)
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
service._path_size = Mock(return_value=0)
|
||||
service._file_count = Mock(return_value=0)
|
||||
service._monitoring_counts = AsyncMock(return_value={})
|
||||
service._binary_storage_stats = AsyncMock(return_value={'count': 0, 'size_bytes': None})
|
||||
service._expired_uploaded_candidates = AsyncMock(return_value=[])
|
||||
service._expired_log_candidates = Mock(return_value=[])
|
||||
|
||||
# Execute
|
||||
result = await service.get_storage_analysis()
|
||||
|
||||
# Verify
|
||||
assert result['database']['type'] == 'postgresql'
|
||||
|
||||
async def test_get_storage_analysis_with_cleanup_candidates(self):
|
||||
"""Returns cleanup candidates in analysis."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
ap.task_mgr = None
|
||||
|
||||
count_result = _create_mock_result(scalar_value=0)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=count_result)
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
service._path_size = Mock(return_value=0)
|
||||
service._file_count = Mock(return_value=0)
|
||||
service._monitoring_counts = AsyncMock(return_value={})
|
||||
service._binary_storage_stats = AsyncMock(return_value={'count': 0, 'size_bytes': 0})
|
||||
service._expired_uploaded_candidates = AsyncMock(return_value=[
|
||||
{'key': 'old_file', 'size_bytes': 100}
|
||||
])
|
||||
service._expired_log_candidates = Mock(return_value=[
|
||||
{'name': 'old_log', 'size_bytes': 50}
|
||||
])
|
||||
|
||||
# Execute
|
||||
result = await service.get_storage_analysis()
|
||||
|
||||
# Verify
|
||||
assert len(result['cleanup_candidates']['uploaded_files']) == 1
|
||||
assert len(result['cleanup_candidates']['log_files']) == 1
|
||||
|
||||
|
||||
class TestMaintenanceServiceMonitoringCounts:
|
||||
"""Tests for _monitoring_counts method."""
|
||||
|
||||
async def test_monitoring_counts_returns_counts(self):
|
||||
"""Returns counts for all monitoring tables."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
count_result = _create_mock_result(scalar_value=42)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=count_result)
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service._monitoring_counts()
|
||||
|
||||
# Verify - all table keys present
|
||||
assert 'messages' in result
|
||||
assert 'llm_calls' in result
|
||||
assert 'embedding_calls' in result
|
||||
assert 'errors' in result
|
||||
assert 'sessions' in result
|
||||
assert 'feedback' in result
|
||||
|
||||
async def test_monitoring_counts_zero_results(self):
|
||||
"""Returns zero counts when tables empty."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
count_result = _create_mock_result(scalar_value=0)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=count_result)
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service._monitoring_counts()
|
||||
|
||||
# Verify - all zero
|
||||
assert all(v == 0 for v in result.values())
|
||||
|
||||
|
||||
class TestMaintenanceServiceBinaryStorageStats:
|
||||
"""Tests for _binary_storage_stats method."""
|
||||
|
||||
async def test_binary_storage_stats_returns_stats(self):
|
||||
"""Returns count and size for binary storage."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
|
||||
# Mock count result
|
||||
count_result = _create_mock_result(scalar_value=10)
|
||||
# Mock size result
|
||||
size_result = _create_mock_result(scalar_value=5000)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return count_result
|
||||
return size_result
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service._binary_storage_stats()
|
||||
|
||||
# Verify
|
||||
assert result['count'] == 10
|
||||
assert result['size_bytes'] == 5000
|
||||
|
||||
async def test_binary_storage_stats_size_error(self):
|
||||
"""Handles error when calculating size."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
|
||||
count_result = _create_mock_result(scalar_value=5)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return count_result
|
||||
raise Exception('Size calculation error')
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service._binary_storage_stats()
|
||||
|
||||
# Verify - warning logged, size_bytes None or 0
|
||||
assert ap.logger.warning.called
|
||||
assert result['count'] == 5
|
||||
|
||||
|
||||
class TestMaintenanceServicePathSize:
|
||||
"""Tests for _path_size method."""
|
||||
|
||||
def test_path_size_nonexistent_path(self):
|
||||
"""Returns 0 for nonexistent path."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Execute
|
||||
result = service._path_size(Path('/nonexistent/path'))
|
||||
|
||||
# Verify
|
||||
assert result == 0
|
||||
|
||||
def test_path_size_single_file(self):
|
||||
"""Returns size for single file."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Mock file
|
||||
mock_stat = Mock()
|
||||
mock_stat.st_size = 100
|
||||
|
||||
with patch.object(Path, 'exists', return_value=True):
|
||||
with patch.object(Path, 'is_file', return_value=True):
|
||||
with patch.object(Path, 'stat', return_value=mock_stat):
|
||||
result = service._path_size(Path('test.txt'))
|
||||
|
||||
# Verify
|
||||
assert result == 100
|
||||
|
||||
def test_path_size_directory(self):
|
||||
"""Returns total size for directory."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Mock os.walk
|
||||
with patch.object(Path, 'exists', return_value=True):
|
||||
with patch.object(Path, 'is_file', return_value=False):
|
||||
with patch('os.walk') as mock_walk:
|
||||
mock_walk.return_value = [
|
||||
('/test_dir', [], ['file1.txt', 'file2.txt']),
|
||||
]
|
||||
|
||||
# Mock file stat
|
||||
mock_stat = Mock()
|
||||
mock_stat.st_size = 50
|
||||
|
||||
with patch.object(Path, 'stat', return_value=mock_stat):
|
||||
result = service._path_size(Path('/test_dir'))
|
||||
|
||||
# Verify - 2 files * 50 bytes
|
||||
assert result == 100
|
||||
|
||||
|
||||
class TestMaintenanceServiceFileCount:
|
||||
"""Tests for _file_count method."""
|
||||
|
||||
def test_file_count_nonexistent_path(self):
|
||||
"""Returns 0 for nonexistent path."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Execute
|
||||
result = service._file_count(Path('/nonexistent/path'))
|
||||
|
||||
# Verify
|
||||
assert result == 0
|
||||
|
||||
def test_file_count_single_file(self):
|
||||
"""Returns 1 for single file."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
with patch.object(Path, 'exists', return_value=True):
|
||||
with patch.object(Path, 'is_file', return_value=True):
|
||||
result = service._file_count(Path('test.txt'))
|
||||
|
||||
# Verify
|
||||
assert result == 1
|
||||
|
||||
def test_file_count_directory(self):
|
||||
"""Returns file count for directory."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
with patch.object(Path, 'exists', return_value=True):
|
||||
with patch.object(Path, 'is_file', return_value=False):
|
||||
with patch('os.walk') as mock_walk:
|
||||
mock_walk.return_value = [
|
||||
('/test_dir', [], ['file1.txt', 'file2.txt', 'file3.txt']),
|
||||
]
|
||||
result = service._file_count(Path('/test_dir'))
|
||||
|
||||
# Verify
|
||||
assert result == 3
|
||||
|
||||
|
||||
class TestMaintenanceServicePositiveInt:
|
||||
"""Tests for _positive_int helper method."""
|
||||
|
||||
def test_positive_int_valid_value(self):
|
||||
"""Returns valid positive integer."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Execute
|
||||
result = service._positive_int(7, 5, 'test_param')
|
||||
|
||||
# Verify
|
||||
assert result == 7
|
||||
assert not ap.logger.warning.called
|
||||
|
||||
def test_positive_int_invalid_string(self):
|
||||
"""Returns default for invalid string."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Execute
|
||||
result = service._positive_int('invalid', 5, 'test_param')
|
||||
|
||||
# Verify
|
||||
assert result == 5
|
||||
assert ap.logger.warning.called
|
||||
|
||||
def test_positive_int_invalid_none(self):
|
||||
"""Returns default for None."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Execute
|
||||
result = service._positive_int(None, 5, 'test_param')
|
||||
|
||||
# Verify
|
||||
assert result == 5
|
||||
assert ap.logger.warning.called
|
||||
|
||||
def test_positive_int_negative_value(self):
|
||||
"""Returns default for negative value."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Execute
|
||||
result = service._positive_int(-1, 5, 'test_param')
|
||||
|
||||
# Verify
|
||||
assert result == 5
|
||||
assert ap.logger.warning.called
|
||||
|
||||
def test_positive_int_zero_value(self):
|
||||
"""Returns default for zero value."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Execute
|
||||
result = service._positive_int(0, 5, 'test_param')
|
||||
|
||||
# Verify
|
||||
assert result == 5
|
||||
assert ap.logger.warning.called
|
||||
|
||||
|
||||
class TestMaintenanceServiceIsUploadedFileKey:
|
||||
"""Tests for _is_uploaded_file_key helper method."""
|
||||
|
||||
def test_is_uploaded_file_key_valid(self):
|
||||
"""Returns True for valid upload file key."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Execute - simple filename without path
|
||||
result = service._is_uploaded_file_key('uploaded_file.txt')
|
||||
|
||||
# Verify
|
||||
assert result is True
|
||||
|
||||
def test_is_uploaded_file_key_with_path(self):
|
||||
"""Returns False for key with path separator."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Execute - key with path
|
||||
result = service._is_uploaded_file_key('path/to/file.txt')
|
||||
|
||||
# Verify
|
||||
assert result is False
|
||||
|
||||
def test_is_uploaded_file_key_plugin_config(self):
|
||||
"""Returns False for plugin config prefix."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Execute - plugin config file
|
||||
result = service._is_uploaded_file_key('plugin_config_some_plugin.json')
|
||||
|
||||
# Verify
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestMaintenanceServiceExpiredLogCandidates:
|
||||
"""Tests for _expired_log_candidates method."""
|
||||
|
||||
def test_expired_log_candidates_nonexistent_dir(self):
|
||||
"""Returns empty list when logs dir not exists."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
with patch.object(Path, 'exists', return_value=False):
|
||||
result = service._expired_log_candidates(3)
|
||||
|
||||
# Verify
|
||||
assert result == []
|
||||
|
||||
def test_expired_log_candidates_matches_pattern(self):
|
||||
"""Matches log file pattern correctly."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Mock directory with log files
|
||||
old_date = datetime.date.today() - datetime.timedelta(days=10)
|
||||
old_log_name = f'langbot-{old_date.isoformat()}.log'
|
||||
recent_log_name = f'langbot-{datetime.date.today().isoformat()}.log'
|
||||
|
||||
mock_entry_old = Mock(spec=Path)
|
||||
mock_entry_old.is_file = Mock(return_value=True)
|
||||
mock_entry_old.name = old_log_name
|
||||
mock_stat = Mock()
|
||||
mock_stat.st_size = 1000
|
||||
mock_entry_old.stat = Mock(return_value=mock_stat)
|
||||
|
||||
mock_entry_recent = Mock(spec=Path)
|
||||
mock_entry_recent.is_file = Mock(return_value=True)
|
||||
mock_entry_recent.name = recent_log_name
|
||||
mock_stat2 = Mock()
|
||||
mock_stat2.st_size = 500
|
||||
mock_entry_recent.stat = Mock(return_value=mock_stat2)
|
||||
|
||||
# Non-log file
|
||||
mock_entry_other = Mock(spec=Path)
|
||||
mock_entry_other.is_file = Mock(return_value=True)
|
||||
mock_entry_other.name = 'other_file.txt'
|
||||
|
||||
with patch.object(Path, 'exists', return_value=True):
|
||||
with patch.object(Path, 'iterdir') as mock_iterdir:
|
||||
mock_iterdir.return_value = [mock_entry_old, mock_entry_recent, mock_entry_other]
|
||||
result = service._expired_log_candidates(3)
|
||||
|
||||
# Verify - only old log included
|
||||
assert len(result) == 1
|
||||
assert result[0]['name'] == old_log_name
|
||||
|
||||
def test_expired_log_candidates_includes_path(self):
|
||||
"""Includes path when include_paths=True."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
old_date = datetime.date.today() - datetime.timedelta(days=10)
|
||||
old_log_name = f'langbot-{old_date.isoformat()}.log'
|
||||
|
||||
mock_entry = Mock(spec=Path)
|
||||
mock_entry.is_file = Mock(return_value=True)
|
||||
mock_entry.name = old_log_name
|
||||
mock_entry.__str__ = Mock(return_value='/data/logs/' + old_log_name)
|
||||
mock_stat = Mock()
|
||||
mock_stat.st_size = 1000
|
||||
mock_entry.stat = Mock(return_value=mock_stat)
|
||||
|
||||
with patch.object(Path, 'exists', return_value=True):
|
||||
with patch.object(Path, 'iterdir') as mock_iterdir:
|
||||
mock_iterdir.return_value = [mock_entry]
|
||||
result = service._expired_log_candidates(3, include_paths=True)
|
||||
|
||||
# Verify - path included
|
||||
assert 'path' in result[0]
|
||||
|
||||
|
||||
class TestMaintenanceServiceExpiredLocalUploadCandidates:
|
||||
"""Tests for _expired_local_upload_candidates method."""
|
||||
|
||||
def test_expired_local_upload_candidates_nonexistent_dir(self):
|
||||
"""Returns empty list when storage dir not exists."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
with patch.object(Path, 'exists', return_value=False):
|
||||
result = service._expired_local_upload_candidates(7)
|
||||
|
||||
# Verify
|
||||
assert result == []
|
||||
|
||||
def test_expired_local_upload_candidates_filters_uploaded(self):
|
||||
"""Only returns uploaded files matching pattern."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
# Mock _is_uploaded_file_key
|
||||
service._is_uploaded_file_key = Mock(side_effect=lambda key: 'plugin_config_' not in key and '/' not in key)
|
||||
|
||||
# Create mock files - one valid, one plugin config
|
||||
mock_entry_valid = Mock(spec=Path)
|
||||
mock_entry_valid.is_file = Mock(return_value=True)
|
||||
mock_entry_valid.name = 'valid_upload.txt'
|
||||
mock_stat = Mock()
|
||||
mock_stat.st_size = 100
|
||||
mock_stat.st_mtime = 0 # Very old
|
||||
mock_entry_valid.stat = Mock(return_value=mock_stat)
|
||||
|
||||
mock_entry_plugin = Mock(spec=Path)
|
||||
mock_entry_plugin.is_file = Mock(return_value=True)
|
||||
mock_entry_plugin.name = 'plugin_config_test.json'
|
||||
mock_stat2 = Mock()
|
||||
mock_stat2.st_size = 200
|
||||
mock_stat2.st_mtime = 0
|
||||
mock_entry_plugin.stat = Mock(return_value=mock_stat2)
|
||||
|
||||
with patch.object(Path, 'exists', return_value=True):
|
||||
with patch.object(Path, 'iterdir') as mock_iterdir:
|
||||
mock_iterdir.return_value = [mock_entry_valid, mock_entry_plugin]
|
||||
result = service._expired_local_upload_candidates(7)
|
||||
|
||||
# Verify - only valid upload included
|
||||
assert len(result) == 1
|
||||
assert result[0]['key'] == 'valid_upload.txt'
|
||||
|
||||
def test_expired_local_upload_candidates_includes_path(self):
|
||||
"""Includes path when include_paths=True."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
service._is_uploaded_file_key = Mock(return_value=True)
|
||||
|
||||
mock_entry = Mock(spec=Path)
|
||||
mock_entry.is_file = Mock(return_value=True)
|
||||
mock_entry.name = 'old_file.txt'
|
||||
mock_entry.__str__ = Mock(return_value='/data/storage/old_file.txt')
|
||||
mock_stat = Mock()
|
||||
mock_stat.st_size = 100
|
||||
mock_stat.st_mtime = 0
|
||||
mock_entry.stat = Mock(return_value=mock_stat)
|
||||
|
||||
with patch.object(Path, 'exists', return_value=True):
|
||||
with patch.object(Path, 'iterdir') as mock_iterdir:
|
||||
mock_iterdir.return_value = [mock_entry]
|
||||
result = service._expired_local_upload_candidates(7, include_paths=True)
|
||||
|
||||
# Verify - path included
|
||||
assert 'path' in result[0]
|
||||
@@ -1,648 +0,0 @@
|
||||
"""
|
||||
Unit tests for MCPService.
|
||||
|
||||
Tests MCP server CRUD operations including:
|
||||
- MCP server listing with runtime info
|
||||
- MCP server creation with limitations
|
||||
- MCP server update with enable/disable
|
||||
- MCP server deletion
|
||||
- MCP server connection testing
|
||||
|
||||
Source: src/langbot/pkg/api/http/service/mcp.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock, MagicMock
|
||||
from types import SimpleNamespace
|
||||
import uuid
|
||||
|
||||
from langbot.pkg.api.http.service.mcp import MCPService
|
||||
from langbot.pkg.entity.persistence.mcp import MCPServer
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
def _create_mock_mcp_server(
|
||||
server_uuid: str = None,
|
||||
name: str = 'Test MCP Server',
|
||||
enable: bool = True,
|
||||
mode: str = 'stdio',
|
||||
extra_args: dict = None,
|
||||
) -> Mock:
|
||||
"""Helper to create mock MCPServer entity."""
|
||||
server = Mock(spec=MCPServer)
|
||||
server.uuid = server_uuid or str(uuid.uuid4())
|
||||
server.name = name
|
||||
server.enable = enable
|
||||
server.mode = mode
|
||||
server.extra_args = extra_args or {}
|
||||
return server
|
||||
|
||||
|
||||
def _create_mock_result(items: list = None, first_item=None):
|
||||
"""Create mock result object for persistence queries."""
|
||||
result = Mock()
|
||||
result.all = Mock(return_value=items or [])
|
||||
result.first = Mock(return_value=first_item)
|
||||
return result
|
||||
|
||||
|
||||
class TestMCPServiceGetRuntimeInfo:
|
||||
"""Tests for get_runtime_info method."""
|
||||
|
||||
async def test_get_runtime_info_session_exists(self):
|
||||
"""Returns runtime info when session exists."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.tool_mgr = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
|
||||
|
||||
mock_session = SimpleNamespace()
|
||||
mock_session.get_runtime_info_dict = Mock(return_value={'status': 'running', 'tools': 5})
|
||||
ap.tool_mgr.mcp_tool_loader.get_session = Mock(return_value=mock_session)
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_runtime_info('test-server')
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result['status'] == 'running'
|
||||
|
||||
async def test_get_runtime_info_session_not_exists(self):
|
||||
"""Returns None when session not exists."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.tool_mgr = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader.get_session = Mock(return_value=None)
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_runtime_info('nonexistent-server')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestMCPServiceGetMCPServers:
|
||||
"""Tests for get_mcp_servers method."""
|
||||
|
||||
async def test_get_mcp_servers_empty_list(self):
|
||||
"""Returns empty list when no MCP servers exist."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
}
|
||||
)
|
||||
ap.tool_mgr = None
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_mcp_servers()
|
||||
|
||||
# Verify
|
||||
assert result == []
|
||||
|
||||
async def test_get_mcp_servers_returns_serialized_list(self):
|
||||
"""Returns serialized list of MCP servers."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
server1 = _create_mock_mcp_server(server_uuid='uuid-1', name='Server 1')
|
||||
server2 = _create_mock_mcp_server(server_uuid='uuid-2', name='Server 2')
|
||||
|
||||
mock_result = _create_mock_result([server1, server2])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'enable': entity.enable,
|
||||
'mode': entity.mode,
|
||||
}
|
||||
)
|
||||
ap.tool_mgr = None
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_mcp_servers()
|
||||
|
||||
# Verify
|
||||
assert len(result) == 2
|
||||
assert result[0]['name'] == 'Server 1'
|
||||
assert result[1]['name'] == 'Server 2'
|
||||
|
||||
async def test_get_mcp_servers_with_runtime_info(self):
|
||||
"""Returns MCP servers with runtime info when requested."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
server1 = _create_mock_mcp_server(server_uuid='uuid-1', name='Server 1')
|
||||
|
||||
mock_result = _create_mock_result([server1])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
}
|
||||
)
|
||||
ap.tool_mgr = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader.get_session = Mock(return_value=None)
|
||||
|
||||
service = MCPService(ap)
|
||||
service.get_runtime_info = AsyncMock(return_value={'status': 'connected'})
|
||||
|
||||
# Execute
|
||||
result = await service.get_mcp_servers(contain_runtime_info=True)
|
||||
|
||||
# Verify - runtime info included
|
||||
assert result[0]['runtime_info'] == {'status': 'connected'}
|
||||
|
||||
|
||||
class TestMCPServiceCreateMCPServer:
|
||||
"""Tests for create_mcp_server method."""
|
||||
|
||||
async def test_create_mcp_server_max_extensions_reached_raises(self):
|
||||
"""Raises ValueError when max_extensions limit reached."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'system': {
|
||||
'limitation': {
|
||||
'max_extensions': 2
|
||||
}
|
||||
}
|
||||
}
|
||||
ap.plugin_connector = SimpleNamespace()
|
||||
ap.plugin_connector.list_plugins = AsyncMock(return_value=[Mock(), Mock()]) # 2 plugins
|
||||
|
||||
# Mock get_mcp_servers to return 0 servers (2 plugins already)
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(return_value={})
|
||||
ap.tool_mgr = None
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute & Verify - 2 plugins + new server would exceed limit
|
||||
with pytest.raises(ValueError, match='Maximum number of extensions'):
|
||||
await service.create_mcp_server({'name': 'New Server'})
|
||||
|
||||
async def test_create_mcp_server_no_limit(self):
|
||||
"""Creates MCP server without limit when max_extensions=-1."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'system': {
|
||||
'limitation': {
|
||||
'max_extensions': -1 # No limit
|
||||
}
|
||||
}
|
||||
}
|
||||
ap.tool_mgr = None
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'new-uuid'})
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute
|
||||
server_uuid = await service.create_mcp_server({'name': 'New Server'})
|
||||
|
||||
# Verify
|
||||
assert server_uuid is not None
|
||||
assert len(server_uuid) == 36 # UUID format
|
||||
|
||||
async def test_create_mcp_server_loads_server(self):
|
||||
"""Loads server into tool_mgr when enabled."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'system': {'limitation': {'max_extensions': -1}}}
|
||||
ap.tool_mgr = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader.host_mcp_server = AsyncMock()
|
||||
ap.tool_mgr.mcp_tool_loader._hosted_mcp_tasks = []
|
||||
|
||||
# Create mock server entity
|
||||
server_entity = _create_mock_mcp_server(server_uuid='new-uuid', enable=True)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return _create_mock_result([]) # Empty list for limit check
|
||||
elif call_count == 2:
|
||||
return Mock() # Insert
|
||||
return _create_mock_result(first_item=server_entity) # Select created
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'new-uuid', 'name': 'New Server', 'enable': True}
|
||||
)
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute
|
||||
await service.create_mcp_server({'name': 'New Server', 'enable': True})
|
||||
|
||||
# Verify - host_mcp_server was called
|
||||
ap.tool_mgr.mcp_tool_loader.host_mcp_server.assert_called_once()
|
||||
|
||||
async def test_create_mcp_server_disabled_no_load(self):
|
||||
"""Does not load server when disabled."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'system': {'limitation': {'max_extensions': -1}}}
|
||||
ap.tool_mgr = None
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'new-uuid'})
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute with enable=False
|
||||
server_uuid = await service.create_mcp_server({'name': 'New Server', 'enable': False})
|
||||
|
||||
# Verify - no tool_mgr load attempt
|
||||
assert server_uuid is not None
|
||||
|
||||
|
||||
class TestMCPServiceGetMCPServerByName:
|
||||
"""Tests for get_mcp_server_by_name method."""
|
||||
|
||||
async def test_get_mcp_server_by_name_found(self):
|
||||
"""Returns MCP server when found by name."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
server = _create_mock_mcp_server(name='Found Server')
|
||||
mock_result = _create_mock_result(first_item=server)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'test-uuid',
|
||||
'name': 'Found Server',
|
||||
'runtime_info': None,
|
||||
}
|
||||
)
|
||||
ap.tool_mgr = None
|
||||
|
||||
service = MCPService(ap)
|
||||
service.get_runtime_info = AsyncMock(return_value=None)
|
||||
|
||||
# Execute
|
||||
result = await service.get_mcp_server_by_name('Found Server')
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result['name'] == 'Found Server'
|
||||
|
||||
async def test_get_mcp_server_by_name_not_found(self):
|
||||
"""Returns None when MCP server not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result(first_item=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_mcp_server_by_name('Nonexistent Server')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestMCPServiceUpdateMCPServer:
|
||||
"""Tests for update_mcp_server method."""
|
||||
|
||||
async def test_update_mcp_server_disable_enabled_server(self):
|
||||
"""Removes server when disabling previously enabled server."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.tool_mgr = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader.sessions = {'Old Server': Mock()}
|
||||
ap.tool_mgr.mcp_tool_loader.remove_mcp_server = AsyncMock()
|
||||
|
||||
old_server = _create_mock_mcp_server(name='Old Server', enable=True)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return _create_mock_result(first_item=old_server)
|
||||
return Mock() # Update
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute - disable server
|
||||
await service.update_mcp_server('test-uuid', {'enable': False})
|
||||
|
||||
# Verify - server was removed
|
||||
ap.tool_mgr.mcp_tool_loader.remove_mcp_server.assert_called_once()
|
||||
|
||||
async def test_update_mcp_server_enable_disabled_server(self):
|
||||
"""Loads server when enabling previously disabled server."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.tool_mgr = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader.sessions = {}
|
||||
ap.tool_mgr.mcp_tool_loader.host_mcp_server = AsyncMock()
|
||||
ap.tool_mgr.mcp_tool_loader._hosted_mcp_tasks = []
|
||||
|
||||
old_server = _create_mock_mcp_server(name='Old Server', enable=False)
|
||||
|
||||
updated_server = _create_mock_mcp_server(name='Old Server', enable=True)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return _create_mock_result(first_item=old_server)
|
||||
elif call_count == 2:
|
||||
return Mock() # Update
|
||||
return _create_mock_result(first_item=updated_server) # Select updated
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'test-uuid', 'name': 'Old Server', 'enable': True}
|
||||
)
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute - enable server
|
||||
await service.update_mcp_server('test-uuid', {'enable': True})
|
||||
|
||||
# Verify - server was loaded
|
||||
ap.tool_mgr.mcp_tool_loader.host_mcp_server.assert_called_once()
|
||||
|
||||
async def test_update_mcp_server_update_enabled_server(self):
|
||||
"""Removes and reloads server when updating enabled server."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.tool_mgr = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader.sessions = {'Old Server': Mock()}
|
||||
ap.tool_mgr.mcp_tool_loader.remove_mcp_server = AsyncMock()
|
||||
ap.tool_mgr.mcp_tool_loader.host_mcp_server = AsyncMock()
|
||||
ap.tool_mgr.mcp_tool_loader._hosted_mcp_tasks = []
|
||||
|
||||
old_server = _create_mock_mcp_server(name='Old Server', enable=True)
|
||||
|
||||
# Mock for: first select -> update -> second select (for updated server)
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
# All selects return the server
|
||||
return _create_mock_result(first_item=old_server)
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'test-uuid', 'name': 'Old Server', 'enable': True}
|
||||
)
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute - update enabled server (keep enabled, update extra_args)
|
||||
await service.update_mcp_server('test-uuid', {'enable': True, 'extra_args': {'new': 'args'}})
|
||||
|
||||
# Verify - remove and reload
|
||||
ap.tool_mgr.mcp_tool_loader.remove_mcp_server.assert_called_once_with('Old Server')
|
||||
ap.tool_mgr.mcp_tool_loader.host_mcp_server.assert_called_once()
|
||||
|
||||
async def test_update_mcp_server_no_tool_mgr(self):
|
||||
"""Updates persistence without tool_mgr operations."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
# Set mcp_tool_loader to None, not tool_mgr itself
|
||||
ap.tool_mgr = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader = None
|
||||
|
||||
old_server = _create_mock_mcp_server(name='Server', enable=True)
|
||||
|
||||
# Mock execute for select and update
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return _create_mock_result(first_item=old_server)
|
||||
return Mock() # Update
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute - should not raise
|
||||
await service.update_mcp_server('test-uuid', {'name': 'New Name'})
|
||||
|
||||
# Verify - persistence was called
|
||||
assert ap.persistence_mgr.execute_async.call_count >= 2
|
||||
|
||||
|
||||
class TestMCPServiceDeleteMCPServer:
|
||||
"""Tests for delete_mcp_server method."""
|
||||
|
||||
async def test_delete_mcp_server_calls_remove_and_delete(self):
|
||||
"""Calls both persistence delete and tool_mgr remove."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.tool_mgr = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader.sessions = {'Server to Delete': Mock()}
|
||||
ap.tool_mgr.mcp_tool_loader.remove_mcp_server = AsyncMock()
|
||||
|
||||
server = _create_mock_mcp_server(name='Server to Delete')
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return _create_mock_result(first_item=server)
|
||||
return Mock() # Delete
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute
|
||||
await service.delete_mcp_server('test-uuid')
|
||||
|
||||
# Verify
|
||||
ap.tool_mgr.mcp_tool_loader.remove_mcp_server.assert_called_once_with('Server to Delete')
|
||||
ap.persistence_mgr.execute_async.assert_called()
|
||||
|
||||
async def test_delete_mcp_server_not_in_sessions(self):
|
||||
"""Does not attempt remove if server not in sessions."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.tool_mgr = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader.sessions = {} # Server not in sessions
|
||||
ap.tool_mgr.mcp_tool_loader.remove_mcp_server = AsyncMock()
|
||||
|
||||
server = _create_mock_mcp_server(name='Not in Sessions')
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return _create_mock_result(first_item=server)
|
||||
return Mock()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute
|
||||
await service.delete_mcp_server('test-uuid')
|
||||
|
||||
# Verify - remove not called (server not in sessions)
|
||||
ap.tool_mgr.mcp_tool_loader.remove_mcp_server.assert_not_called()
|
||||
|
||||
async def test_delete_mcp_server_nonexistent_uuid(self):
|
||||
"""Delete operation completes even for nonexistent UUID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.tool_mgr = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader.sessions = {}
|
||||
ap.tool_mgr.mcp_tool_loader.remove_mcp_server = AsyncMock()
|
||||
|
||||
# No server found
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return _create_mock_result(first_item=None)
|
||||
return Mock()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute - should not raise
|
||||
await service.delete_mcp_server('nonexistent-uuid')
|
||||
|
||||
# Verify - delete was called regardless
|
||||
ap.persistence_mgr.execute_async.assert_called()
|
||||
|
||||
|
||||
class TestMCPServiceTestMCPServer:
|
||||
"""Tests for test_mcp_server method."""
|
||||
|
||||
async def test_test_mcp_server_existing_server(self):
|
||||
"""Tests existing MCP server connection."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.tool_mgr = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
|
||||
|
||||
from langbot.pkg.provider.tools.loaders.mcp import MCPSessionStatus
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.status = MCPSessionStatus.ERROR
|
||||
mock_session.start = AsyncMock()
|
||||
mock_session.refresh = AsyncMock()
|
||||
ap.tool_mgr.mcp_tool_loader.get_session = Mock(return_value=mock_session)
|
||||
|
||||
ap.task_mgr = SimpleNamespace()
|
||||
ap.task_mgr.create_user_task = Mock(
|
||||
return_value=SimpleNamespace(id=123)
|
||||
)
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute
|
||||
task_id = await service.test_mcp_server('existing-server', {})
|
||||
|
||||
# Verify - returns task ID
|
||||
assert task_id == 123
|
||||
|
||||
async def test_test_mcp_server_not_found_raises(self):
|
||||
"""Raises ValueError when server not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.tool_mgr = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader.get_session = Mock(return_value=None)
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='Server not found'):
|
||||
await service.test_mcp_server('nonexistent-server', {})
|
||||
|
||||
async def test_test_mcp_server_new_server(self):
|
||||
"""Tests new MCP server with underscore name."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.tool_mgr = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.start = AsyncMock()
|
||||
ap.tool_mgr.mcp_tool_loader.load_mcp_server = AsyncMock(return_value=mock_session)
|
||||
|
||||
ap.task_mgr = SimpleNamespace()
|
||||
ap.task_mgr.create_user_task = Mock(
|
||||
return_value=SimpleNamespace(id=456)
|
||||
)
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute with '_' name (new server)
|
||||
task_id = await service.test_mcp_server('_', {'name': 'New Server'})
|
||||
|
||||
# Verify - load_mcp_server called
|
||||
ap.tool_mgr.mcp_tool_loader.load_mcp_server.assert_called_once()
|
||||
assert task_id == 456
|
||||
@@ -1,964 +0,0 @@
|
||||
"""
|
||||
Unit tests for LLMModelsService, EmbeddingModelsService, and RerankModelsService.
|
||||
|
||||
Tests model management operations including:
|
||||
- Model CRUD operations
|
||||
- Model with provider info
|
||||
- Provider auto-creation on model create/update
|
||||
- Runtime model loading/unloading
|
||||
- Model deletion
|
||||
|
||||
Source: src/langbot/pkg/api/http/service/model.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
from types import SimpleNamespace
|
||||
|
||||
from langbot.pkg.api.http.service.model import (
|
||||
LLMModelsService,
|
||||
EmbeddingModelsService,
|
||||
RerankModelsService,
|
||||
_parse_provider_api_keys,
|
||||
_runtime_model_data,
|
||||
)
|
||||
from langbot.pkg.entity.persistence.model import LLMModel, EmbeddingModel, RerankModel, ModelProvider
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
def _create_mock_llm_model(
|
||||
model_uuid: str = 'llm-uuid',
|
||||
name: str = 'Test LLM',
|
||||
provider_uuid: str = 'provider-uuid',
|
||||
abilities: list = None,
|
||||
extra_args: dict = None,
|
||||
) -> Mock:
|
||||
"""Helper to create mock LLMModel entity."""
|
||||
model = Mock(spec=LLMModel)
|
||||
model.uuid = model_uuid
|
||||
model.name = name
|
||||
model.provider_uuid = provider_uuid
|
||||
model.abilities = abilities or []
|
||||
model.extra_args = extra_args or {}
|
||||
return model
|
||||
|
||||
|
||||
def _create_mock_embedding_model(
|
||||
model_uuid: str = 'embedding-uuid',
|
||||
name: str = 'Test Embedding',
|
||||
provider_uuid: str = 'provider-uuid',
|
||||
) -> Mock:
|
||||
"""Helper to create mock EmbeddingModel entity."""
|
||||
model = Mock(spec=EmbeddingModel)
|
||||
model.uuid = model_uuid
|
||||
model.name = name
|
||||
model.provider_uuid = provider_uuid
|
||||
model.extra_args = {}
|
||||
return model
|
||||
|
||||
|
||||
def _create_mock_rerank_model(
|
||||
model_uuid: str = 'rerank-uuid',
|
||||
name: str = 'Test Rerank',
|
||||
provider_uuid: str = 'provider-uuid',
|
||||
) -> Mock:
|
||||
"""Helper to create mock RerankModel entity."""
|
||||
model = Mock(spec=RerankModel)
|
||||
model.uuid = model_uuid
|
||||
model.name = name
|
||||
model.provider_uuid = provider_uuid
|
||||
model.extra_args = {}
|
||||
return model
|
||||
|
||||
|
||||
def _create_mock_provider(
|
||||
provider_uuid: str = 'provider-uuid',
|
||||
name: str = 'Test Provider',
|
||||
api_keys: list = None,
|
||||
) -> Mock:
|
||||
"""Helper to create mock ModelProvider entity."""
|
||||
provider = Mock(spec=ModelProvider)
|
||||
provider.uuid = provider_uuid
|
||||
provider.name = name
|
||||
provider.requester = 'openai'
|
||||
provider.base_url = 'https://api.openai.com'
|
||||
provider.api_keys = api_keys or ['key']
|
||||
return provider
|
||||
|
||||
|
||||
def _create_mock_result(items: list = None, first_item=None):
|
||||
"""Create mock result object for persistence queries."""
|
||||
result = Mock()
|
||||
result.all = Mock(return_value=items or [])
|
||||
result.first = Mock(return_value=first_item)
|
||||
return result
|
||||
|
||||
|
||||
class TestParseProviderApiKeys:
|
||||
"""Tests for _parse_provider_api_keys helper function."""
|
||||
|
||||
def test_parse_valid_json_string(self):
|
||||
"""Parses valid JSON string to list."""
|
||||
provider_dict = {'api_keys': '["key1", "key2"]'}
|
||||
result = _parse_provider_api_keys(provider_dict)
|
||||
assert result['api_keys'] == ['key1', 'key2']
|
||||
|
||||
def test_parse_invalid_json_returns_empty(self):
|
||||
"""Returns empty list for invalid JSON."""
|
||||
provider_dict = {'api_keys': 'invalid json'}
|
||||
result = _parse_provider_api_keys(provider_dict)
|
||||
assert result['api_keys'] == []
|
||||
|
||||
def test_parse_already_list(self):
|
||||
"""Returns unchanged if already a list."""
|
||||
provider_dict = {'api_keys': ['key1', 'key2']}
|
||||
result = _parse_provider_api_keys(provider_dict)
|
||||
assert result['api_keys'] == ['key1', 'key2']
|
||||
|
||||
def test_parse_missing_key(self):
|
||||
"""Handles missing api_keys key."""
|
||||
provider_dict = {'name': 'Provider'}
|
||||
result = _parse_provider_api_keys(provider_dict)
|
||||
assert 'api_keys' not in result
|
||||
|
||||
|
||||
class TestRuntimeModelData:
|
||||
"""Tests for _runtime_model_data helper function."""
|
||||
|
||||
def test_runtime_data_preserves_uuid(self):
|
||||
"""Preserves UUID in runtime data."""
|
||||
update_payload = {'name': 'Updated', 'provider_uuid': 'provider'}
|
||||
result = _runtime_model_data('model-uuid', update_payload)
|
||||
assert result['uuid'] == 'model-uuid'
|
||||
assert result['name'] == 'Updated'
|
||||
|
||||
def test_runtime_data_copies_all_fields(self):
|
||||
"""Copies all fields from payload."""
|
||||
update_payload = {
|
||||
'name': 'Model',
|
||||
'provider_uuid': 'provider',
|
||||
'abilities': ['vision'],
|
||||
'extra_args': {'temp': 0.7},
|
||||
}
|
||||
result = _runtime_model_data('uuid', update_payload)
|
||||
assert result['abilities'] == ['vision']
|
||||
assert result['extra_args'] == {'temp': 0.7}
|
||||
|
||||
|
||||
class TestLLMModelsServiceGetLLMModels:
|
||||
"""Tests for LLMModelsService.get_llm_models method."""
|
||||
|
||||
async def test_get_llm_models_empty_list(self):
|
||||
"""Returns empty list when no models exist."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
mock_provider_result = _create_mock_result([])
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
return mock_result if call_count == 0 else mock_provider_result
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'provider_uuid': entity.provider_uuid,
|
||||
}
|
||||
)
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_llm_models()
|
||||
|
||||
# Verify
|
||||
assert result == []
|
||||
|
||||
async def test_get_llm_models_with_provider_info(self):
|
||||
"""Returns models with provider info."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
model = _create_mock_llm_model()
|
||||
provider = _create_mock_provider()
|
||||
|
||||
mock_model_result = _create_mock_result([model])
|
||||
mock_provider_result = _create_mock_result([provider])
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return mock_model_result if call_count == 1 else mock_provider_result
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'provider_uuid': entity.provider_uuid if hasattr(entity, 'provider_uuid') else None,
|
||||
'api_keys': entity.api_keys if hasattr(entity, 'api_keys') else None,
|
||||
}
|
||||
)
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_llm_models()
|
||||
|
||||
# Verify
|
||||
assert len(result) == 1
|
||||
assert result[0]['name'] == 'Test LLM'
|
||||
|
||||
async def test_get_llm_models_hide_secret_keys(self):
|
||||
"""Hides secret API keys when include_secret=False."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
model = _create_mock_llm_model()
|
||||
provider = _create_mock_provider(api_keys=['secret-key-1', 'secret-key-2'])
|
||||
|
||||
mock_model_result = _create_mock_result([model])
|
||||
mock_provider_result = _create_mock_result([provider])
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return mock_model_result if call_count == 1 else mock_provider_result
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'provider_uuid': entity.provider_uuid if hasattr(entity, 'provider_uuid') else None,
|
||||
'api_keys': entity.api_keys if hasattr(entity, 'api_keys') else None,
|
||||
}
|
||||
)
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_llm_models(include_secret=False)
|
||||
|
||||
# Verify - keys should be masked
|
||||
assert result[0]['provider']['api_keys'] == ['***', '***']
|
||||
|
||||
|
||||
class TestLLMModelsServiceGetLLMModel:
|
||||
"""Tests for LLMModelsService.get_llm_model method."""
|
||||
|
||||
async def test_get_llm_model_found(self):
|
||||
"""Returns model when found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
model = _create_mock_llm_model(model_uuid='found-uuid')
|
||||
provider = _create_mock_provider()
|
||||
|
||||
mock_model_result = _create_mock_result([], first_item=model)
|
||||
mock_provider_result = _create_mock_result([], first_item=provider)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return mock_model_result if call_count == 1 else mock_provider_result
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'found-uuid',
|
||||
'name': 'Test LLM',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
'provider': {'uuid': 'provider-uuid', 'api_keys': ['key']},
|
||||
}
|
||||
)
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_llm_model('found-uuid')
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result['uuid'] == 'found-uuid'
|
||||
|
||||
async def test_get_llm_model_not_found(self):
|
||||
"""Returns None when model not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result([], first_item=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_llm_model('nonexistent-uuid')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestLLMModelsServiceGetLLMModelsByProvider:
|
||||
"""Tests for LLMModelsService.get_llm_models_by_provider method."""
|
||||
|
||||
async def test_get_models_by_provider_uuid(self):
|
||||
"""Returns models for specific provider."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
model1 = _create_mock_llm_model(model_uuid='model-1', provider_uuid='target-provider')
|
||||
model2 = _create_mock_llm_model(model_uuid='model-2', provider_uuid='target-provider')
|
||||
|
||||
mock_result = _create_mock_result([model1, model2])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'model-1', 'name': 'Model 1'}
|
||||
)
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_llm_models_by_provider('target-provider')
|
||||
|
||||
# Verify
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
class TestLLMModelsServiceCreateLLMModel:
|
||||
"""Tests for LLMModelsService.create_llm_model method."""
|
||||
|
||||
async def test_create_llm_model_generates_uuid(self):
|
||||
"""Creates LLM model with generated UUID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {'provider-uuid': Mock()}
|
||||
ap.model_mgr.llm_models = []
|
||||
ap.model_mgr.load_llm_model_with_provider = AsyncMock(return_value=Mock())
|
||||
ap.pipeline_service = SimpleNamespace()
|
||||
ap.pipeline_service.update_pipeline = AsyncMock()
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute
|
||||
model_uuid = await service.create_llm_model({
|
||||
'name': 'New LLM',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
'abilities': [],
|
||||
'extra_args': {},
|
||||
})
|
||||
|
||||
# Verify
|
||||
assert model_uuid is not None
|
||||
assert len(model_uuid) == 36 # UUID format
|
||||
|
||||
async def test_create_llm_model_preserve_uuid(self):
|
||||
"""Creates LLM model preserving provided UUID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {'provider-uuid': Mock()}
|
||||
ap.model_mgr.llm_models = []
|
||||
ap.model_mgr.load_llm_model_with_provider = AsyncMock(return_value=Mock())
|
||||
ap.pipeline_service = SimpleNamespace()
|
||||
ap.pipeline_service.update_pipeline = AsyncMock()
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute
|
||||
model_uuid = await service.create_llm_model({
|
||||
'uuid': 'preserved-uuid',
|
||||
'name': 'Preserved UUID Model',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
'abilities': [],
|
||||
'extra_args': {},
|
||||
}, preserve_uuid=True)
|
||||
|
||||
# Verify
|
||||
assert model_uuid == 'preserved-uuid'
|
||||
|
||||
async def test_create_llm_model_provider_not_found_raises_error(self):
|
||||
"""Raises Exception when provider not found in runtime."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {} # Empty - no provider
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(Exception, match='provider not found'):
|
||||
await service.create_llm_model({
|
||||
'name': 'No Provider Model',
|
||||
'provider_uuid': 'nonexistent-provider',
|
||||
'abilities': [],
|
||||
'extra_args': {},
|
||||
})
|
||||
|
||||
async def test_create_llm_model_with_provider_data(self):
|
||||
"""Creates provider when provider data provided."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {}
|
||||
ap.model_mgr.llm_models = []
|
||||
ap.model_mgr.load_llm_model_with_provider = AsyncMock(return_value=Mock())
|
||||
ap.provider_service = SimpleNamespace()
|
||||
ap.provider_service.find_or_create_provider = AsyncMock(return_value='new-provider-uuid')
|
||||
ap.pipeline_service = SimpleNamespace()
|
||||
ap.pipeline_service.update_pipeline = AsyncMock()
|
||||
|
||||
# Create runtime provider
|
||||
runtime_provider = Mock()
|
||||
ap.model_mgr.provider_dict['new-provider-uuid'] = runtime_provider
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute - with provider data (no UUID)
|
||||
result_uuid = await service.create_llm_model({
|
||||
'name': 'Model with New Provider',
|
||||
'provider': {
|
||||
'requester': 'openai',
|
||||
'base_url': 'https://api.openai.com',
|
||||
'api_keys': ['key'],
|
||||
},
|
||||
'abilities': [],
|
||||
'extra_args': {},
|
||||
})
|
||||
|
||||
# Verify - provider_service was called and UUID generated
|
||||
ap.provider_service.find_or_create_provider.assert_called_once()
|
||||
assert result_uuid is not None
|
||||
|
||||
|
||||
class TestLLMModelsServiceUpdateLLMModel:
|
||||
"""Tests for LLMModelsService.update_llm_model method."""
|
||||
|
||||
async def test_update_llm_model_removes_uuid_from_data(self):
|
||||
"""Removes uuid from update data before persisting."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {'provider-uuid': Mock()}
|
||||
ap.model_mgr.llm_models = []
|
||||
ap.model_mgr.remove_llm_model = AsyncMock()
|
||||
ap.model_mgr.load_llm_model_with_provider = AsyncMock(return_value=Mock())
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute
|
||||
await service.update_llm_model('existing-uuid', {
|
||||
'uuid': 'should-be-removed',
|
||||
'name': 'Updated Name',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
})
|
||||
|
||||
# Verify - remove and load called
|
||||
ap.model_mgr.remove_llm_model.assert_called_once_with('existing-uuid')
|
||||
|
||||
async def test_update_llm_model_provider_not_found_raises_error(self):
|
||||
"""Raises Exception when provider not found after update."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {} # Empty
|
||||
ap.model_mgr.remove_llm_model = AsyncMock()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(Exception, match='provider not found'):
|
||||
await service.update_llm_model('model-uuid', {
|
||||
'name': 'Update',
|
||||
'provider_uuid': 'nonexistent-provider',
|
||||
})
|
||||
|
||||
|
||||
class TestLLMModelsServiceDeleteLLMModel:
|
||||
"""Tests for LLMModelsService.delete_llm_model method."""
|
||||
|
||||
async def test_delete_llm_model_success(self):
|
||||
"""Deletes LLM model successfully."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.remove_llm_model = AsyncMock()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute
|
||||
await service.delete_llm_model('delete-uuid')
|
||||
|
||||
# Verify
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
ap.model_mgr.remove_llm_model.assert_called_once_with('delete-uuid')
|
||||
|
||||
|
||||
class TestEmbeddingModelsServiceGetEmbeddingModels:
|
||||
"""Tests for EmbeddingModelsService.get_embedding_models method."""
|
||||
|
||||
async def test_get_embedding_models_empty_list(self):
|
||||
"""Returns empty list when no models exist."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'embedding-uuid', 'name': 'Test'}
|
||||
)
|
||||
|
||||
service = EmbeddingModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_embedding_models()
|
||||
|
||||
# Verify
|
||||
assert result == []
|
||||
|
||||
async def test_get_embedding_models_with_provider(self):
|
||||
"""Returns embedding models with provider info."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
model = _create_mock_embedding_model()
|
||||
provider = _create_mock_provider()
|
||||
|
||||
mock_model_result = _create_mock_result([model])
|
||||
mock_provider_result = _create_mock_result([provider])
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return mock_model_result if call_count == 1 else mock_provider_result
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'provider_uuid': getattr(entity, 'provider_uuid', None),
|
||||
'api_keys': getattr(entity, 'api_keys', ['key']),
|
||||
}
|
||||
)
|
||||
|
||||
service = EmbeddingModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_embedding_models()
|
||||
|
||||
# Verify
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
class TestEmbeddingModelsServiceGetEmbeddingModel:
|
||||
"""Tests for EmbeddingModelsService.get_embedding_model method."""
|
||||
|
||||
async def test_get_embedding_model_found(self):
|
||||
"""Returns embedding model when found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
model = _create_mock_embedding_model(model_uuid='found-embedding')
|
||||
provider = _create_mock_provider()
|
||||
|
||||
mock_model_result = _create_mock_result([], first_item=model)
|
||||
mock_provider_result = _create_mock_result([], first_item=provider)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return mock_model_result if call_count == 1 else mock_provider_result
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'found-embedding',
|
||||
'name': 'Found Embedding',
|
||||
'provider': {'uuid': 'provider-uuid'},
|
||||
}
|
||||
)
|
||||
|
||||
service = EmbeddingModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_embedding_model('found-embedding')
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
|
||||
async def test_get_embedding_model_not_found(self):
|
||||
"""Returns None when model not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result([], first_item=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = EmbeddingModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_embedding_model('nonexistent-embedding')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestEmbeddingModelsServiceCreateEmbeddingModel:
|
||||
"""Tests for EmbeddingModelsService.create_embedding_model method."""
|
||||
|
||||
async def test_create_embedding_model_success(self):
|
||||
"""Creates embedding model successfully."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {'provider-uuid': Mock()}
|
||||
ap.model_mgr.embedding_models = []
|
||||
ap.model_mgr.load_embedding_model_with_provider = AsyncMock(return_value=Mock())
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = EmbeddingModelsService(ap)
|
||||
|
||||
# Execute
|
||||
model_uuid = await service.create_embedding_model({
|
||||
'name': 'New Embedding',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
'extra_args': {},
|
||||
})
|
||||
|
||||
# Verify
|
||||
assert model_uuid is not None
|
||||
assert len(model_uuid) == 36
|
||||
|
||||
async def test_create_embedding_model_provider_not_found_raises(self):
|
||||
"""Raises Exception when provider not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {} # Empty
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = EmbeddingModelsService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(Exception, match='provider not found'):
|
||||
await service.create_embedding_model({
|
||||
'name': 'No Provider Embedding',
|
||||
'provider_uuid': 'nonexistent',
|
||||
'extra_args': {},
|
||||
})
|
||||
|
||||
|
||||
class TestEmbeddingModelsServiceDeleteEmbeddingModel:
|
||||
"""Tests for EmbeddingModelsService.delete_embedding_model method."""
|
||||
|
||||
async def test_delete_embedding_model_success(self):
|
||||
"""Deletes embedding model successfully."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.remove_embedding_model = AsyncMock()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = EmbeddingModelsService(ap)
|
||||
|
||||
# Execute
|
||||
await service.delete_embedding_model('delete-embedding-uuid')
|
||||
|
||||
# Verify
|
||||
ap.model_mgr.remove_embedding_model.assert_called_once()
|
||||
|
||||
|
||||
class TestRerankModelsServiceGetRerankModels:
|
||||
"""Tests for RerankModelsService.get_rerank_models method."""
|
||||
|
||||
async def test_get_rerank_models_empty_list(self):
|
||||
"""Returns empty list when no models exist."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = RerankModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_rerank_models()
|
||||
|
||||
# Verify
|
||||
assert result == []
|
||||
|
||||
async def test_get_rerank_models_with_provider(self):
|
||||
"""Returns rerank models with provider info."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
model = _create_mock_rerank_model()
|
||||
provider = _create_mock_provider()
|
||||
|
||||
mock_model_result = _create_mock_result([model])
|
||||
mock_provider_result = _create_mock_result([provider])
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return mock_model_result if call_count == 1 else mock_provider_result
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'provider_uuid': getattr(entity, 'provider_uuid', None),
|
||||
'api_keys': getattr(entity, 'api_keys', ['key']),
|
||||
}
|
||||
)
|
||||
|
||||
service = RerankModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_rerank_models()
|
||||
|
||||
# Verify
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
class TestRerankModelsServiceGetRerankModel:
|
||||
"""Tests for RerankModelsService.get_rerank_model method."""
|
||||
|
||||
async def test_get_rerank_model_found(self):
|
||||
"""Returns rerank model when found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
model = _create_mock_rerank_model(model_uuid='found-rerank')
|
||||
provider = _create_mock_provider()
|
||||
|
||||
mock_model_result = _create_mock_result([], first_item=model)
|
||||
mock_provider_result = _create_mock_result([], first_item=provider)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return mock_model_result if call_count == 1 else mock_provider_result
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'found-rerank',
|
||||
'name': 'Found Rerank',
|
||||
'provider': {'uuid': 'provider-uuid'},
|
||||
}
|
||||
)
|
||||
|
||||
service = RerankModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_rerank_model('found-rerank')
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
|
||||
async def test_get_rerank_model_not_found(self):
|
||||
"""Returns None when model not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result([], first_item=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = RerankModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_rerank_model('nonexistent-rerank')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestRerankModelsServiceCreateRerankModel:
|
||||
"""Tests for RerankModelsService.create_rerank_model method."""
|
||||
|
||||
async def test_create_rerank_model_success(self):
|
||||
"""Creates rerank model successfully."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {'provider-uuid': Mock()}
|
||||
ap.model_mgr.rerank_models = []
|
||||
ap.model_mgr.load_rerank_model_with_provider = AsyncMock(return_value=Mock())
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = RerankModelsService(ap)
|
||||
|
||||
# Execute
|
||||
model_uuid = await service.create_rerank_model({
|
||||
'name': 'New Rerank',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
'extra_args': {},
|
||||
})
|
||||
|
||||
# Verify
|
||||
assert model_uuid is not None
|
||||
|
||||
async def test_create_rerank_model_provider_not_found_raises(self):
|
||||
"""Raises Exception when provider not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {}
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = RerankModelsService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(Exception, match='provider not found'):
|
||||
await service.create_rerank_model({
|
||||
'name': 'No Provider Rerank',
|
||||
'provider_uuid': 'nonexistent',
|
||||
'extra_args': {},
|
||||
})
|
||||
|
||||
|
||||
class TestRerankModelsServiceDeleteRerankModel:
|
||||
"""Tests for RerankModelsService.delete_rerank_model method."""
|
||||
|
||||
async def test_delete_rerank_model_success(self):
|
||||
"""Deletes rerank model successfully."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.remove_rerank_model = AsyncMock()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = RerankModelsService(ap)
|
||||
|
||||
# Execute
|
||||
await service.delete_rerank_model('delete-rerank-uuid')
|
||||
|
||||
# Verify
|
||||
ap.model_mgr.remove_rerank_model.assert_called_once()
|
||||
|
||||
|
||||
class TestEmbeddingModelsServiceGetEmbeddingModelsByProvider:
|
||||
"""Tests for EmbeddingModelsService.get_embedding_models_by_provider method."""
|
||||
|
||||
async def test_get_embedding_models_by_provider_uuid(self):
|
||||
"""Returns embedding models for specific provider."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
model1 = _create_mock_embedding_model(model_uuid='emb-1', provider_uuid='provider-uuid')
|
||||
model2 = _create_mock_embedding_model(model_uuid='emb-2', provider_uuid='provider-uuid')
|
||||
|
||||
mock_result = _create_mock_result([model1, model2])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'emb-1', 'name': 'Embedding 1'}
|
||||
)
|
||||
|
||||
service = EmbeddingModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_embedding_models_by_provider('provider-uuid')
|
||||
|
||||
# Verify
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
class TestRerankModelsServiceGetRerankModelsByProvider:
|
||||
"""Tests for RerankModelsService.get_rerank_models_by_provider method."""
|
||||
|
||||
async def test_get_rerank_models_by_provider_uuid(self):
|
||||
"""Returns rerank models for specific provider."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
model1 = _create_mock_rerank_model(model_uuid='rerank-1', provider_uuid='provider-uuid')
|
||||
model2 = _create_mock_rerank_model(model_uuid='rerank-2', provider_uuid='provider-uuid')
|
||||
|
||||
mock_result = _create_mock_result([model1, model2])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'rerank-1', 'name': 'Rerank 1'}
|
||||
)
|
||||
|
||||
service = RerankModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_rerank_models_by_provider('provider-uuid')
|
||||
|
||||
# Verify
|
||||
assert len(result) == 2
|
||||
@@ -1,831 +0,0 @@
|
||||
"""
|
||||
Unit tests for PipelineService.
|
||||
|
||||
Tests pipeline CRUD operations including:
|
||||
- Pipeline listing with sorting
|
||||
- Pipeline creation with default config
|
||||
- Pipeline update with bot sync
|
||||
- Pipeline copy functionality
|
||||
- Extensions preferences management
|
||||
|
||||
Source: src/langbot/pkg/api/http/service/pipeline.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock, patch, mock_open
|
||||
from types import SimpleNamespace
|
||||
import uuid
|
||||
import json
|
||||
|
||||
from langbot.pkg.api.http.service.pipeline import PipelineService, default_stage_order
|
||||
from langbot.pkg.entity.persistence.pipeline import LegacyPipeline
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
def _create_mock_pipeline(
|
||||
pipeline_uuid: str = None,
|
||||
name: str = 'Test Pipeline',
|
||||
description: str = 'Test Description',
|
||||
is_default: bool = False,
|
||||
stages: list = None,
|
||||
config: dict = None,
|
||||
extensions_preferences: dict = None,
|
||||
) -> Mock:
|
||||
"""Helper to create mock LegacyPipeline entity."""
|
||||
pipeline = Mock(spec=LegacyPipeline)
|
||||
pipeline.uuid = pipeline_uuid or str(uuid.uuid4())
|
||||
pipeline.name = name
|
||||
pipeline.description = description
|
||||
pipeline.emoji = '⚙️'
|
||||
pipeline.is_default = is_default
|
||||
pipeline.for_version = '1.0.0'
|
||||
pipeline.stages = stages or default_stage_order.copy()
|
||||
pipeline.config = config or {}
|
||||
pipeline.extensions_preferences = extensions_preferences or {
|
||||
'enable_all_plugins': True,
|
||||
'enable_all_mcp_servers': True,
|
||||
'plugins': [],
|
||||
'mcp_servers': [],
|
||||
}
|
||||
return pipeline
|
||||
|
||||
|
||||
def _create_mock_result(items: list = None, first_item=None):
|
||||
"""Create mock result object for persistence queries."""
|
||||
result = Mock()
|
||||
result.all = Mock(return_value=items or [])
|
||||
result.first = Mock(return_value=first_item)
|
||||
return result
|
||||
|
||||
|
||||
class TestPipelineServiceGetPipelineMetadata:
|
||||
"""Tests for get_pipeline_metadata method."""
|
||||
|
||||
async def test_get_pipeline_metadata_returns_list(self):
|
||||
"""Returns list of pipeline metadata configs."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.pipeline_config_meta_trigger = {'trigger': {}}
|
||||
ap.pipeline_config_meta_safety = {'safety': {}}
|
||||
ap.pipeline_config_meta_ai = {'ai': {}}
|
||||
ap.pipeline_config_meta_output = {'output': {}}
|
||||
|
||||
service = PipelineService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_pipeline_metadata()
|
||||
|
||||
# Verify
|
||||
assert len(result) == 4
|
||||
assert 'trigger' in result[0]
|
||||
assert 'safety' in result[1]
|
||||
assert 'ai' in result[2]
|
||||
assert 'output' in result[3]
|
||||
|
||||
|
||||
class TestPipelineServiceGetPipelines:
|
||||
"""Tests for get_pipelines method."""
|
||||
|
||||
async def test_get_pipelines_empty_list(self):
|
||||
"""Returns empty list when no pipelines exist."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
}
|
||||
)
|
||||
|
||||
service = PipelineService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_pipelines()
|
||||
|
||||
# Verify
|
||||
assert result == []
|
||||
|
||||
async def test_get_pipelines_returns_sorted_by_created_at_desc(self):
|
||||
"""Returns pipelines sorted by created_at descending by default."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
pipeline1 = _create_mock_pipeline(pipeline_uuid='uuid-1', name='Pipeline 1')
|
||||
pipeline2 = _create_mock_pipeline(pipeline_uuid='uuid-2', name='Pipeline 2')
|
||||
|
||||
mock_result = _create_mock_result([pipeline1, pipeline2])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
}
|
||||
)
|
||||
|
||||
service = PipelineService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_pipelines()
|
||||
|
||||
# Verify
|
||||
assert len(result) == 2
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
async def test_get_pipelines_sort_by_updated_at_asc(self):
|
||||
"""Returns pipelines sorted by updated_at ascending."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(return_value={})
|
||||
|
||||
service = PipelineService(ap)
|
||||
|
||||
# Execute
|
||||
await service.get_pipelines(sort_by='updated_at', sort_order='ASC')
|
||||
|
||||
# Verify - execute was called with sort parameters
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
|
||||
class TestPipelineServiceGetPipeline:
|
||||
"""Tests for get_pipeline method."""
|
||||
|
||||
async def test_get_pipeline_by_uuid_found(self):
|
||||
"""Returns pipeline when found by UUID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
pipeline = _create_mock_pipeline(pipeline_uuid='test-uuid', name='Found Pipeline')
|
||||
mock_result = _create_mock_result(first_item=pipeline)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'test-uuid',
|
||||
'name': 'Found Pipeline',
|
||||
'stages': default_stage_order,
|
||||
}
|
||||
)
|
||||
|
||||
service = PipelineService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_pipeline('test-uuid')
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result['uuid'] == 'test-uuid'
|
||||
assert result['name'] == 'Found Pipeline'
|
||||
|
||||
async def test_get_pipeline_by_uuid_not_found(self):
|
||||
"""Returns None when pipeline not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result(first_item=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = PipelineService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_pipeline('nonexistent-uuid')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestPipelineServiceCreatePipeline:
|
||||
"""Tests for create_pipeline method."""
|
||||
|
||||
async def test_create_pipeline_max_limit_reached_raises(self):
|
||||
"""Raises ValueError when max_pipelines limit reached."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'system': {
|
||||
'limitation': {
|
||||
'max_pipelines': 2
|
||||
}
|
||||
}
|
||||
}
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
ap.ver_mgr = SimpleNamespace()
|
||||
ap.ver_mgr.get_current_version = Mock(return_value='1.0.0')
|
||||
|
||||
mock_result = _create_mock_result([_create_mock_pipeline(), _create_mock_pipeline()])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'uuid-1', 'name': 'Pipeline 1'}
|
||||
)
|
||||
|
||||
service = PipelineService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='Maximum number of pipelines'):
|
||||
await service.create_pipeline({'name': 'New Pipeline'})
|
||||
|
||||
async def test_create_pipeline_no_limit(self):
|
||||
"""Creates pipeline without limit when max_pipelines=-1."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': -1}}}
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
ap.ver_mgr = SimpleNamespace()
|
||||
ap.ver_mgr.get_current_version = Mock(return_value='1.0.0')
|
||||
|
||||
service = PipelineService(ap)
|
||||
# Override get_pipelines to return empty list (no limit check issue)
|
||||
service.get_pipelines = AsyncMock(return_value=[])
|
||||
service.get_pipeline = AsyncMock(return_value={'uuid': 'new-uuid', 'name': 'New Pipeline'})
|
||||
|
||||
# Mock persistence for insert
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'new-uuid', 'name': 'New Pipeline'}
|
||||
)
|
||||
|
||||
# Mock the file read for default config - patch at the utils module level
|
||||
default_config = {'trigger': {}, 'safety': {}, 'ai': {}, 'output': {}}
|
||||
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
|
||||
with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'):
|
||||
bot_uuid = await service.create_pipeline({'name': 'New Pipeline'})
|
||||
|
||||
# Verify
|
||||
assert bot_uuid is not None
|
||||
assert len(bot_uuid) == 36 # UUID format
|
||||
|
||||
async def test_create_pipeline_as_default(self):
|
||||
"""Creates pipeline with is_default=True."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': -1}}}
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
ap.ver_mgr = SimpleNamespace()
|
||||
ap.ver_mgr.get_current_version = Mock(return_value='1.0.0')
|
||||
|
||||
service = PipelineService(ap)
|
||||
service.get_pipelines = AsyncMock(return_value=[])
|
||||
service.get_pipeline = AsyncMock(return_value={'uuid': 'new-uuid', 'name': 'Default Pipeline', 'is_default': True})
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'new-uuid', 'name': 'Default Pipeline', 'is_default': True}
|
||||
)
|
||||
|
||||
# Mock the file read
|
||||
default_config = {}
|
||||
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
|
||||
with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'):
|
||||
await service.create_pipeline({'name': 'Default Pipeline'}, default=True)
|
||||
|
||||
# Verify - execute was called
|
||||
ap.persistence_mgr.execute_async.assert_called()
|
||||
|
||||
async def test_create_pipeline_sets_default_extensions_preferences(self):
|
||||
"""Sets default extensions_preferences when not provided."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': -1}}}
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
ap.ver_mgr = SimpleNamespace()
|
||||
ap.ver_mgr.get_current_version = Mock(return_value='1.0.0')
|
||||
|
||||
service = PipelineService(ap)
|
||||
service.get_pipelines = AsyncMock(return_value=[])
|
||||
service.get_pipeline = AsyncMock(return_value={
|
||||
'uuid': 'new-uuid',
|
||||
'extensions_preferences': {},
|
||||
})
|
||||
|
||||
insert_params = []
|
||||
|
||||
async def mock_execute(query):
|
||||
params = query.compile().params
|
||||
if 'extensions_preferences' in params:
|
||||
insert_params.append(params)
|
||||
return Mock()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'new-uuid',
|
||||
'extensions_preferences': {},
|
||||
}
|
||||
)
|
||||
|
||||
default_config = {}
|
||||
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
|
||||
with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'):
|
||||
await service.create_pipeline({'name': 'New Pipeline'})
|
||||
|
||||
assert len(insert_params) == 1
|
||||
assert insert_params[0]['extensions_preferences'] == {
|
||||
'enable_all_plugins': True,
|
||||
'enable_all_mcp_servers': True,
|
||||
'plugins': [],
|
||||
'mcp_servers': [],
|
||||
}
|
||||
|
||||
|
||||
class _MockResultWithBots:
|
||||
"""Helper class to mock SQLAlchemy result with iterable .all() method."""
|
||||
def __init__(self, bots_list):
|
||||
self._bots_list = bots_list
|
||||
|
||||
def all(self):
|
||||
return self._bots_list
|
||||
|
||||
def first(self):
|
||||
return self._bots_list[0] if self._bots_list else None
|
||||
|
||||
|
||||
class TestPipelineServiceUpdatePipeline:
|
||||
"""Tests for update_pipeline method."""
|
||||
|
||||
async def test_update_pipeline_removes_protected_fields(self):
|
||||
"""Does not persist protected fields from update data."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.remove_pipeline = AsyncMock()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
ap.sess_mgr = SimpleNamespace()
|
||||
ap.sess_mgr.session_list = []
|
||||
ap.bot_service = None # No bot_service when not updating name
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = PipelineService(ap)
|
||||
service.get_pipeline = AsyncMock(return_value={'uuid': 'test-uuid', 'name': 'Updated'})
|
||||
|
||||
# Execute with protected fields - no name change, so no bot sync
|
||||
pipeline_data = {
|
||||
'uuid': 'should-be-removed',
|
||||
'for_version': 'should-be-removed',
|
||||
'stages': ['should-be-removed'],
|
||||
'is_default': True,
|
||||
'description': 'New description', # Not name change, so no bot_service needed
|
||||
}
|
||||
await service.update_pipeline('test-uuid', pipeline_data)
|
||||
|
||||
update_params = ap.persistence_mgr.execute_async.await_args_list[0].args[0].compile().params
|
||||
assert update_params['description'] == 'New description'
|
||||
assert 'should-be-removed' not in update_params.values()
|
||||
assert ['should-be-removed'] not in update_params.values()
|
||||
assert not any(value is True for value in update_params.values())
|
||||
|
||||
async def test_update_pipeline_syncs_bot_names(self):
|
||||
"""Updates bot use_pipeline_name when pipeline name changes."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.remove_pipeline = AsyncMock()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
ap.sess_mgr = SimpleNamespace()
|
||||
ap.sess_mgr.session_list = []
|
||||
ap.bot_service = SimpleNamespace()
|
||||
ap.bot_service.update_bot = AsyncMock()
|
||||
|
||||
# Create proper mock Bot entities with uuid attribute
|
||||
mock_bot1 = Mock()
|
||||
mock_bot1.uuid = 'bot-uuid-1'
|
||||
mock_bot2 = Mock()
|
||||
mock_bot2.uuid = 'bot-uuid-2'
|
||||
|
||||
# Create bot list
|
||||
bot_list = [mock_bot1, mock_bot2]
|
||||
|
||||
# Create mock result using helper class
|
||||
bot_result = _MockResultWithBots(bot_list)
|
||||
|
||||
# The order of calls in update_pipeline:
|
||||
# 1. UPDATE (line 125) - returns Mock (no result needed)
|
||||
# 2. SELECT bots (line 136) - returns bot_result with .all()
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
# First call is the UPDATE - just return a Mock
|
||||
return Mock()
|
||||
elif call_count == 2:
|
||||
# Second call is the SELECT bots - return proper result
|
||||
return bot_result
|
||||
return Mock() # Any additional calls
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(return_value={})
|
||||
|
||||
service = PipelineService(ap)
|
||||
service.get_pipeline = AsyncMock(return_value={'uuid': 'test-uuid', 'name': 'New Name'})
|
||||
|
||||
# Execute with name change
|
||||
await service.update_pipeline('test-uuid', {'name': 'New Name'})
|
||||
|
||||
# Verify - bot_service.update_bot was called for each bot
|
||||
assert ap.bot_service.update_bot.call_count == 2
|
||||
|
||||
async def test_update_pipeline_clears_conversations(self):
|
||||
"""Clears session conversations using this pipeline."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.remove_pipeline = AsyncMock()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
ap.sess_mgr = SimpleNamespace()
|
||||
|
||||
# Mock session with conversation using this pipeline
|
||||
session = SimpleNamespace()
|
||||
session.using_conversation = SimpleNamespace()
|
||||
session.using_conversation.pipeline_uuid = 'test-uuid'
|
||||
ap.sess_mgr.session_list = [session]
|
||||
ap.bot_service = SimpleNamespace()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = PipelineService(ap)
|
||||
service.get_pipeline = AsyncMock(return_value={'uuid': 'test-uuid'})
|
||||
|
||||
# Execute
|
||||
await service.update_pipeline('test-uuid', {'description': 'Updated'})
|
||||
|
||||
# Verify - conversation was cleared
|
||||
assert session.using_conversation is None
|
||||
|
||||
|
||||
class TestPipelineServiceDeletePipeline:
|
||||
"""Tests for delete_pipeline method."""
|
||||
|
||||
async def test_delete_pipeline_calls_remove_and_delete(self):
|
||||
"""Calls both pipeline_mgr.remove_pipeline and persistence delete."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.remove_pipeline = AsyncMock()
|
||||
|
||||
service = PipelineService(ap)
|
||||
|
||||
# Execute
|
||||
await service.delete_pipeline('test-uuid')
|
||||
|
||||
# Verify
|
||||
ap.pipeline_mgr.remove_pipeline.assert_called_once_with('test-uuid')
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
async def test_delete_pipeline_nonexistent_uuid(self):
|
||||
"""Delete operation completes even for nonexistent UUID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.remove_pipeline = AsyncMock()
|
||||
|
||||
service = PipelineService(ap)
|
||||
|
||||
# Execute - should not raise
|
||||
await service.delete_pipeline('nonexistent-uuid')
|
||||
|
||||
# Verify
|
||||
ap.pipeline_mgr.remove_pipeline.assert_called_once()
|
||||
|
||||
|
||||
class TestPipelineServiceCopyPipeline:
|
||||
"""Tests for copy_pipeline method."""
|
||||
|
||||
async def test_copy_pipeline_max_limit_reached_raises(self):
|
||||
"""Raises ValueError when max_pipelines limit reached."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'system': {
|
||||
'limitation': {
|
||||
'max_pipelines': 2
|
||||
}
|
||||
}
|
||||
}
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
ap.ver_mgr = SimpleNamespace()
|
||||
ap.ver_mgr.get_current_version = Mock(return_value='1.0.0')
|
||||
|
||||
service = PipelineService(ap)
|
||||
# Mock get_pipelines to return 2 pipelines
|
||||
service.get_pipelines = AsyncMock(return_value=[
|
||||
{'uuid': 'uuid-1', 'name': 'Pipeline 1'},
|
||||
{'uuid': 'uuid-2', 'name': 'Pipeline 2'},
|
||||
])
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='Maximum number of pipelines'):
|
||||
await service.copy_pipeline('original-uuid')
|
||||
|
||||
async def test_copy_pipeline_not_found_raises(self):
|
||||
"""Raises ValueError when original pipeline not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': -1}}}
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.ver_mgr = SimpleNamespace()
|
||||
ap.ver_mgr.get_current_version = Mock(return_value='1.0.0')
|
||||
|
||||
service = PipelineService(ap)
|
||||
service.get_pipelines = AsyncMock(return_value=[]) # No limit check issue
|
||||
ap.persistence_mgr.execute_async = AsyncMock(
|
||||
return_value=_create_mock_result(first_item=None) # Original not found
|
||||
)
|
||||
ap.persistence_mgr.serialize_model = Mock(return_value={})
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='Pipeline original-uuid not found'):
|
||||
await service.copy_pipeline('original-uuid')
|
||||
|
||||
async def test_copy_pipeline_creates_copy(self):
|
||||
"""Creates a copy with (Copy) suffix."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': -1}}}
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
ap.ver_mgr = SimpleNamespace()
|
||||
ap.ver_mgr.get_current_version = Mock(return_value='1.0.0')
|
||||
|
||||
original = _create_mock_pipeline(
|
||||
pipeline_uuid='original-uuid',
|
||||
name='Original Pipeline',
|
||||
description='Original description',
|
||||
stages=['Stage1', 'Stage2'],
|
||||
config={'key': 'value'},
|
||||
extensions_preferences={'enable_all_plugins': False, 'plugins': ['plugin1']},
|
||||
)
|
||||
|
||||
service = PipelineService(ap)
|
||||
service.get_pipelines = AsyncMock(return_value=[]) # No limit check issue
|
||||
|
||||
# Mock persistence - get original, then insert, then get new
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=_create_mock_result(first_item=original))
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'new-copy-uuid',
|
||||
'name': 'Original Pipeline (Copy)',
|
||||
}
|
||||
)
|
||||
|
||||
service.get_pipeline = AsyncMock(
|
||||
return_value={
|
||||
'uuid': 'new-copy-uuid',
|
||||
'name': 'Original Pipeline (Copy)',
|
||||
}
|
||||
)
|
||||
|
||||
# Execute
|
||||
new_uuid = await service.copy_pipeline('original-uuid')
|
||||
|
||||
# Verify
|
||||
assert new_uuid is not None
|
||||
assert len(new_uuid) == 36 # UUID format
|
||||
|
||||
async def test_copy_pipeline_is_not_default(self):
|
||||
"""Copy is never set as default."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': -1}}}
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
ap.ver_mgr = SimpleNamespace()
|
||||
ap.ver_mgr.get_current_version = Mock(return_value='1.0.0')
|
||||
|
||||
# Original is default
|
||||
original = _create_mock_pipeline(
|
||||
pipeline_uuid='original-uuid',
|
||||
name='Default Pipeline',
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
service = PipelineService(ap)
|
||||
service.get_pipelines = AsyncMock(return_value=[])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=_create_mock_result(first_item=original))
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'copy-uuid', 'is_default': False}
|
||||
)
|
||||
|
||||
service.get_pipeline = AsyncMock(return_value={'uuid': 'copy-uuid', 'is_default': False})
|
||||
|
||||
# Execute
|
||||
await service.copy_pipeline('original-uuid')
|
||||
|
||||
# Verify - pipeline_mgr.load_pipeline called (copy created)
|
||||
ap.pipeline_mgr.load_pipeline.assert_called_once()
|
||||
|
||||
|
||||
class TestPipelineServiceUpdatePipelineExtensions:
|
||||
"""Tests for update_pipeline_extensions method."""
|
||||
|
||||
async def test_update_extensions_pipeline_not_found_raises(self):
|
||||
"""Raises ValueError when pipeline not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result(first_item=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = PipelineService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='Pipeline nonexistent-uuid not found'):
|
||||
await service.update_pipeline_extensions('nonexistent-uuid', [])
|
||||
|
||||
async def test_update_extensions_sets_plugins(self):
|
||||
"""Updates plugins in extensions_preferences."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.remove_pipeline = AsyncMock()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
|
||||
original_pipeline = _create_mock_pipeline(
|
||||
extensions_preferences={'enable_all_plugins': True, 'plugins': []}
|
||||
)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return _create_mock_result(first_item=original_pipeline)
|
||||
return Mock()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'test-uuid',
|
||||
'extensions_preferences': {
|
||||
'enable_all_plugins': False,
|
||||
'plugins': [{'plugin_uuid': 'plugin-1'}],
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
service = PipelineService(ap)
|
||||
service.get_pipeline = AsyncMock(
|
||||
return_value={
|
||||
'uuid': 'test-uuid',
|
||||
'extensions_preferences': {
|
||||
'enable_all_plugins': False,
|
||||
'plugins': [{'plugin_uuid': 'plugin-1'}],
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Execute
|
||||
bound_plugins = [{'plugin_uuid': 'plugin-1'}]
|
||||
await service.update_pipeline_extensions(
|
||||
'test-uuid',
|
||||
bound_plugins=bound_plugins,
|
||||
enable_all_plugins=False,
|
||||
)
|
||||
|
||||
# Verify
|
||||
ap.persistence_mgr.execute_async.assert_called()
|
||||
|
||||
async def test_update_extensions_sets_mcp_servers(self):
|
||||
"""Updates MCP servers in extensions_preferences."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.remove_pipeline = AsyncMock()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
|
||||
original_pipeline = _create_mock_pipeline()
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return _create_mock_result(first_item=original_pipeline)
|
||||
return Mock()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'test-uuid',
|
||||
'extensions_preferences': {
|
||||
'enable_all_mcp_servers': False,
|
||||
'mcp_servers': ['mcp-server-1'],
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
service = PipelineService(ap)
|
||||
service.get_pipeline = AsyncMock(
|
||||
return_value={
|
||||
'uuid': 'test-uuid',
|
||||
'extensions_preferences': {'mcp_servers': ['mcp-server-1']},
|
||||
}
|
||||
)
|
||||
|
||||
# Execute
|
||||
await service.update_pipeline_extensions(
|
||||
'test-uuid',
|
||||
bound_plugins=[],
|
||||
bound_mcp_servers=['mcp-server-1'],
|
||||
enable_all_mcp_servers=False,
|
||||
)
|
||||
|
||||
# Verify
|
||||
ap.persistence_mgr.execute_async.assert_called()
|
||||
|
||||
async def test_update_extensions_none_mcp_servers_keeps_existing(self):
|
||||
"""Does not modify mcp_servers when bound_mcp_servers is None."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.remove_pipeline = AsyncMock()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
|
||||
original_pipeline = _create_mock_pipeline(
|
||||
extensions_preferences={
|
||||
'enable_all_plugins': True,
|
||||
'enable_all_mcp_servers': True,
|
||||
'plugins': [],
|
||||
'mcp_servers': ['existing-server'],
|
||||
}
|
||||
)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return _create_mock_result(first_item=original_pipeline)
|
||||
return Mock()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'test-uuid', 'extensions_preferences': {'mcp_servers': ['existing-server']}}
|
||||
)
|
||||
|
||||
service = PipelineService(ap)
|
||||
service.get_pipeline = AsyncMock(
|
||||
return_value={'uuid': 'test-uuid', 'extensions_preferences': {'mcp_servers': ['existing-server']}}
|
||||
)
|
||||
|
||||
# Execute - bound_mcp_servers is None (not provided)
|
||||
await service.update_pipeline_extensions('test-uuid', bound_plugins=[])
|
||||
|
||||
# Verify - persistence was called
|
||||
ap.persistence_mgr.execute_async.assert_called()
|
||||
|
||||
|
||||
class TestDefaultStageOrder:
|
||||
"""Tests for default_stage_order constant."""
|
||||
|
||||
def test_default_stage_order_not_empty(self):
|
||||
"""Default stage order is not empty."""
|
||||
assert len(default_stage_order) > 0
|
||||
|
||||
def test_default_stage_order_contains_key_stages(self):
|
||||
"""Default stage order contains key processing stages."""
|
||||
assert 'MessageProcessor' in default_stage_order
|
||||
assert 'SendResponseBackStage' in default_stage_order
|
||||
@@ -1,866 +0,0 @@
|
||||
"""
|
||||
Unit tests for ModelProviderService.
|
||||
|
||||
Tests model provider management operations including:
|
||||
- Provider CRUD operations
|
||||
- Provider model count checking
|
||||
- Find or create provider logic
|
||||
- Space model provider API key updates
|
||||
- Provider model scanning
|
||||
|
||||
Source: src/langbot/pkg/api/http/service/provider.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
from types import SimpleNamespace
|
||||
|
||||
from langbot.pkg.api.http.service.provider import ModelProviderService
|
||||
from langbot.pkg.entity.persistence.model import ModelProvider, LLMModel, EmbeddingModel, RerankModel
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
def _create_mock_provider(
|
||||
provider_uuid: str = 'test-provider-uuid',
|
||||
name: str = 'Test Provider',
|
||||
requester: str = 'openai',
|
||||
base_url: str = 'https://api.openai.com',
|
||||
api_keys: list = None,
|
||||
) -> Mock:
|
||||
"""Helper to create mock ModelProvider entity."""
|
||||
provider = Mock(spec=ModelProvider)
|
||||
provider.uuid = provider_uuid
|
||||
provider.name = name
|
||||
provider.requester = requester
|
||||
provider.base_url = base_url
|
||||
provider.api_keys = api_keys or ['test-key']
|
||||
return provider
|
||||
|
||||
|
||||
def _create_mock_llm_model(
|
||||
model_uuid: str = 'test-llm-uuid',
|
||||
name: str = 'Test LLM',
|
||||
provider_uuid: str = 'test-provider-uuid',
|
||||
) -> Mock:
|
||||
"""Helper to create mock LLMModel entity."""
|
||||
model = Mock(spec=LLMModel)
|
||||
model.uuid = model_uuid
|
||||
model.name = name
|
||||
model.provider_uuid = provider_uuid
|
||||
return model
|
||||
|
||||
|
||||
def _create_mock_result(items: list = None, first_item=None):
|
||||
"""Create mock result object for persistence queries."""
|
||||
result = Mock()
|
||||
result.all = Mock(return_value=items or [])
|
||||
result.first = Mock(return_value=first_item)
|
||||
result.scalar = Mock(return_value=len(items) if items else 0)
|
||||
return result
|
||||
|
||||
|
||||
class TestModelProviderServiceGetProviders:
|
||||
"""Tests for get_providers method."""
|
||||
|
||||
async def test_get_providers_empty_list(self):
|
||||
"""Returns empty list when no providers exist."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'requester': entity.requester,
|
||||
'base_url': entity.base_url,
|
||||
'api_keys': entity.api_keys,
|
||||
}
|
||||
)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_providers()
|
||||
|
||||
# Verify
|
||||
assert result == []
|
||||
|
||||
async def test_get_providers_returns_serialized_list(self):
|
||||
"""Returns serialized list of providers."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
provider1 = _create_mock_provider(provider_uuid='provider-1', name='Provider 1')
|
||||
provider2 = _create_mock_provider(provider_uuid='provider-2', name='Provider 2')
|
||||
|
||||
mock_result = _create_mock_result([provider1, provider2])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'requester': entity.requester,
|
||||
'base_url': entity.base_url,
|
||||
'api_keys': entity.api_keys,
|
||||
}
|
||||
)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_providers()
|
||||
|
||||
# Verify
|
||||
assert len(result) == 2
|
||||
assert result[0]['name'] == 'Provider 1'
|
||||
assert result[1]['name'] == 'Provider 2'
|
||||
|
||||
async def test_get_providers_parse_api_keys_json_string(self):
|
||||
"""Parses api_keys from JSON string if needed."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
provider = _create_mock_provider(provider_uuid='provider-1', api_keys='["key1", "key2"]')
|
||||
|
||||
mock_result = _create_mock_result([provider])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'api_keys': entity.api_keys, # Returns string
|
||||
}
|
||||
)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_providers()
|
||||
|
||||
# Verify - api_keys should be parsed from string
|
||||
assert result[0]['api_keys'] == ['key1', 'key2']
|
||||
|
||||
async def test_get_providers_invalid_json_api_keys_returns_empty(self):
|
||||
"""Returns empty list for invalid JSON api_keys."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
provider = _create_mock_provider(provider_uuid='provider-1', api_keys='invalid-json')
|
||||
|
||||
mock_result = _create_mock_result([provider])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'api_keys': entity.api_keys, # Returns invalid string
|
||||
}
|
||||
)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_providers()
|
||||
|
||||
# Verify - invalid JSON returns empty list
|
||||
assert result[0]['api_keys'] == []
|
||||
|
||||
|
||||
class TestModelProviderServiceGetProvider:
|
||||
"""Tests for get_provider method."""
|
||||
|
||||
async def test_get_provider_by_uuid_found(self):
|
||||
"""Returns provider when found by UUID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
provider = _create_mock_provider(provider_uuid='found-uuid', name='Found Provider')
|
||||
|
||||
mock_result = _create_mock_result([], first_item=provider)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'found-uuid',
|
||||
'name': 'Found Provider',
|
||||
'api_keys': ['key'],
|
||||
}
|
||||
)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_provider('found-uuid')
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result['uuid'] == 'found-uuid'
|
||||
|
||||
async def test_get_provider_by_uuid_not_found(self):
|
||||
"""Returns None when provider not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result([], first_item=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_provider('nonexistent-uuid')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestModelProviderServiceCreateProvider:
|
||||
"""Tests for create_provider method."""
|
||||
|
||||
async def test_create_provider_generates_uuid(self):
|
||||
"""Creates provider with generated UUID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {}
|
||||
|
||||
# Mock load_provider to return runtime provider
|
||||
runtime_provider = Mock()
|
||||
runtime_provider.provider_entity = Mock()
|
||||
runtime_provider.provider_entity.uuid = 'generated-uuid'
|
||||
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
provider_uuid = await service.create_provider({
|
||||
'name': 'New Provider',
|
||||
'requester': 'openai',
|
||||
'base_url': 'https://api.openai.com',
|
||||
'api_keys': ['key'],
|
||||
})
|
||||
|
||||
# Verify - UUID is generated
|
||||
assert provider_uuid is not None
|
||||
assert len(provider_uuid) == 36 # UUID format
|
||||
|
||||
async def test_create_provider_loads_to_runtime(self):
|
||||
"""Loads provider to runtime model_mgr."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {}
|
||||
|
||||
runtime_provider = Mock()
|
||||
runtime_provider.provider_entity = Mock()
|
||||
runtime_provider.provider_entity.uuid = 'runtime-uuid'
|
||||
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result_uuid = await service.create_provider({
|
||||
'name': 'Runtime Provider',
|
||||
'requester': 'openai',
|
||||
'base_url': 'https://api.openai.com',
|
||||
'api_keys': ['key'],
|
||||
})
|
||||
|
||||
# Verify - provider added to runtime dict and UUID generated
|
||||
ap.model_mgr.load_provider.assert_called_once()
|
||||
assert result_uuid is not None
|
||||
|
||||
|
||||
class TestModelProviderServiceUpdateProvider:
|
||||
"""Tests for update_provider method."""
|
||||
|
||||
async def test_update_provider_removes_uuid_from_data(self):
|
||||
"""Removes uuid from update data before persisting."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.reload_provider = AsyncMock()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
await service.update_provider('existing-uuid', {
|
||||
'uuid': 'should-be-removed', # Will be removed
|
||||
'name': 'Updated Name',
|
||||
})
|
||||
|
||||
# Verify - reload called
|
||||
ap.model_mgr.reload_provider.assert_called_once_with('existing-uuid')
|
||||
|
||||
async def test_update_provider_reloads_runtime(self):
|
||||
"""Reloads provider in runtime after update."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.reload_provider = AsyncMock()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
await service.update_provider('update-uuid', {'name': 'New Name'})
|
||||
|
||||
# Verify
|
||||
ap.model_mgr.reload_provider.assert_called_once()
|
||||
|
||||
|
||||
class TestModelProviderServiceDeleteProvider:
|
||||
"""Tests for delete_provider method."""
|
||||
|
||||
async def test_delete_provider_with_llm_models_raises_error(self):
|
||||
"""Raises ValueError when LLM models reference provider."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
# Mock LLM model exists - only return LLM result since that's first check
|
||||
llm_result = _create_mock_result([], first_item=_create_mock_llm_model())
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=llm_result)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='Cannot delete provider: LLM models'):
|
||||
await service.delete_provider('provider-with-llm')
|
||||
|
||||
async def test_delete_provider_with_embedding_models_raises_error(self):
|
||||
"""Raises ValueError when Embedding models reference provider."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
# Create results for each check type
|
||||
llm_result = Mock()
|
||||
llm_result.first = Mock(return_value=None) # No LLM models
|
||||
embedding_result = Mock()
|
||||
embedding_result.first = Mock(return_value=Mock(spec=EmbeddingModel)) # Has embedding model
|
||||
rerank_result = Mock()
|
||||
rerank_result.first = Mock(return_value=None)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return llm_result
|
||||
elif call_count == 2:
|
||||
return embedding_result
|
||||
return rerank_result
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute & Verify - should raise embedding error (LLM check passes, embedding check fails)
|
||||
with pytest.raises(ValueError, match='Cannot delete provider: Embedding models'):
|
||||
await service.delete_provider('provider-with-embedding')
|
||||
|
||||
async def test_delete_provider_with_rerank_models_raises_error(self):
|
||||
"""Raises ValueError when Rerank models reference provider."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
# Create results for each check type
|
||||
llm_result = Mock()
|
||||
llm_result.first = Mock(return_value=None) # No LLM models
|
||||
embedding_result = Mock()
|
||||
embedding_result.first = Mock(return_value=None) # No embedding models
|
||||
rerank_result = Mock()
|
||||
rerank_result.first = Mock(return_value=Mock(spec=RerankModel)) # Has rerank model
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return llm_result
|
||||
elif call_count == 2:
|
||||
return embedding_result
|
||||
return rerank_result
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute & Verify - should raise rerank error (LLM and embedding checks pass, rerank check fails)
|
||||
with pytest.raises(ValueError, match='Cannot delete provider: Rerank models'):
|
||||
await service.delete_provider('provider-with-rerank')
|
||||
|
||||
async def test_delete_provider_no_models_success(self):
|
||||
"""Deletes provider when no models reference it."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.remove_provider = AsyncMock()
|
||||
|
||||
# Mock no models reference provider
|
||||
empty_result = Mock()
|
||||
empty_result.first = Mock(return_value=None)
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=empty_result)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
await service.delete_provider('provider-no-models')
|
||||
|
||||
# Verify - delete and remove called
|
||||
ap.model_mgr.remove_provider.assert_called_once_with('provider-no-models')
|
||||
|
||||
|
||||
class TestModelProviderServiceGetProviderModelCounts:
|
||||
"""Tests for get_provider_model_counts method."""
|
||||
|
||||
async def test_get_model_counts_returns_correct_counts(self):
|
||||
"""Returns correct counts for each model type."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
# Mock scalar results for counts
|
||||
llm_result = Mock()
|
||||
llm_result.scalar = Mock(return_value=3)
|
||||
embedding_result = Mock()
|
||||
embedding_result.scalar = Mock(return_value=2)
|
||||
rerank_result = Mock()
|
||||
rerank_result.scalar = Mock(return_value=1)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return llm_result
|
||||
elif call_count == 2:
|
||||
return embedding_result
|
||||
return rerank_result
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_provider_model_counts('provider-uuid')
|
||||
|
||||
# Verify
|
||||
assert result['llm_count'] == 3
|
||||
assert result['embedding_count'] == 2
|
||||
assert result['rerank_count'] == 1
|
||||
|
||||
async def test_get_model_counts_zero_counts(self):
|
||||
"""Returns zero counts when no models."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
zero_result = Mock()
|
||||
zero_result.scalar = Mock(return_value=0)
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=zero_result)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_provider_model_counts('empty-provider')
|
||||
|
||||
# Verify
|
||||
assert result['llm_count'] == 0
|
||||
assert result['embedding_count'] == 0
|
||||
assert result['rerank_count'] == 0
|
||||
|
||||
|
||||
class TestModelProviderServiceFindOrCreateProvider:
|
||||
"""Tests for find_or_create_provider method."""
|
||||
|
||||
async def test_find_existing_provider_matching_config(self):
|
||||
"""Returns existing provider UUID when config matches."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
existing_provider = _create_mock_provider(
|
||||
provider_uuid='existing-uuid',
|
||||
requester='openai',
|
||||
base_url='https://api.openai.com',
|
||||
api_keys=['key1', 'key2'],
|
||||
)
|
||||
|
||||
mock_result = _create_mock_result([existing_provider])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.find_or_create_provider(
|
||||
requester='openai',
|
||||
base_url='https://api.openai.com',
|
||||
api_keys=['key1', 'key2'], # Same keys (sorted)
|
||||
)
|
||||
|
||||
# Verify - returns existing UUID
|
||||
assert result == 'existing-uuid'
|
||||
|
||||
async def test_find_existing_provider_keys_order_mismatch(self):
|
||||
"""Returns existing provider when keys match but order differs."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
existing_provider = _create_mock_provider(
|
||||
provider_uuid='existing-uuid',
|
||||
requester='openai',
|
||||
base_url='https://api.openai.com',
|
||||
api_keys=['key1', 'key2'],
|
||||
)
|
||||
|
||||
mock_result = _create_mock_result([existing_provider])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute with reversed key order
|
||||
result = await service.find_or_create_provider(
|
||||
requester='openai',
|
||||
base_url='https://api.openai.com',
|
||||
api_keys=['key2', 'key1'], # Different order, should still match
|
||||
)
|
||||
|
||||
# Verify - returns existing UUID (keys are sorted in comparison)
|
||||
assert result == 'existing-uuid'
|
||||
|
||||
async def test_create_new_provider_no_match(self):
|
||||
"""Creates new provider when no existing match."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {}
|
||||
|
||||
runtime_provider = Mock()
|
||||
runtime_provider.provider_entity = Mock()
|
||||
runtime_provider.provider_entity.uuid = None # Will be set by uuid.uuid4()
|
||||
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
|
||||
|
||||
# Mock no existing providers
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.find_or_create_provider(
|
||||
requester='new-requester',
|
||||
base_url='https://new.api.com',
|
||||
api_keys=['new-key'],
|
||||
)
|
||||
|
||||
# Verify - creates new provider with valid UUID format
|
||||
assert result is not None
|
||||
assert len(result) == 36 # UUID format
|
||||
# Verify provider was loaded to runtime
|
||||
ap.model_mgr.load_provider.assert_called_once()
|
||||
|
||||
async def test_create_provider_name_from_url_parse(self):
|
||||
"""Creates provider with name parsed from URL."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {}
|
||||
|
||||
runtime_provider = Mock()
|
||||
runtime_provider.provider_entity = Mock()
|
||||
runtime_provider.provider_entity.uuid = 'parsed-url-uuid'
|
||||
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result_uuid = await service.find_or_create_provider(
|
||||
requester='custom',
|
||||
base_url='https://api.example.com/v1',
|
||||
api_keys=['key'],
|
||||
)
|
||||
|
||||
# Verify - name should be parsed from URL (api.example.com)
|
||||
ap.model_mgr.load_provider.assert_called_once()
|
||||
assert result_uuid is not None
|
||||
|
||||
|
||||
class TestModelProviderServiceUpdateSpaceModelProviderApiKeys:
|
||||
"""Tests for update_space_model_provider_api_keys method."""
|
||||
|
||||
async def test_update_space_provider_api_keys(self):
|
||||
"""Updates Space provider API keys."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.reload_provider = AsyncMock()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
await service.update_space_model_provider_api_keys('space-api-key')
|
||||
|
||||
# Verify - update and reload called for Space provider UUID
|
||||
ap.model_mgr.reload_provider.assert_called_once_with(
|
||||
'00000000-0000-0000-0000-000000000000'
|
||||
)
|
||||
|
||||
|
||||
class TestModelProviderServiceScanProviderModels:
|
||||
"""Tests for scan_provider_models method."""
|
||||
|
||||
async def test_scan_provider_not_found_raises_error(self):
|
||||
"""Raises ValueError when provider not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result([], first_item=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='provider not found'):
|
||||
await service.scan_provider_models('nonexistent-uuid')
|
||||
|
||||
async def test_scan_provider_returns_models_list(self):
|
||||
"""Returns scanned models list."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.llm_model_service = SimpleNamespace()
|
||||
ap.embedding_models_service = SimpleNamespace()
|
||||
|
||||
provider = _create_mock_provider(provider_uuid='scan-uuid')
|
||||
|
||||
mock_result = _create_mock_result([], first_item=provider)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'scan-uuid',
|
||||
'name': 'Scan Provider',
|
||||
'requester': 'openai',
|
||||
'base_url': 'https://api.openai.com',
|
||||
'api_keys': ['key'],
|
||||
}
|
||||
)
|
||||
|
||||
# Mock runtime provider with scan capability
|
||||
runtime_provider = Mock()
|
||||
runtime_provider.requester = Mock()
|
||||
runtime_provider.token_mgr = Mock()
|
||||
runtime_provider.token_mgr.get_token = Mock(return_value='token')
|
||||
runtime_provider.token_mgr.tokens = ['token']
|
||||
|
||||
# Mock scan_models to return models
|
||||
async def mock_scan_models(token):
|
||||
return {
|
||||
'models': [
|
||||
{'id': 'gpt-4', 'name': 'GPT-4', 'type': 'llm'},
|
||||
{'id': 'text-embedding', 'name': 'Text Embedding', 'type': 'embedding'},
|
||||
],
|
||||
'debug': None,
|
||||
}
|
||||
|
||||
runtime_provider.requester.scan_models = AsyncMock(side_effect=mock_scan_models)
|
||||
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
|
||||
|
||||
# Mock existing model services
|
||||
ap.llm_model_service.get_llm_models_by_provider = AsyncMock(return_value=[])
|
||||
ap.embedding_models_service.get_embedding_models_by_provider = AsyncMock(return_value=[])
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.scan_provider_models('scan-uuid')
|
||||
|
||||
# Verify
|
||||
assert 'models' in result
|
||||
assert len(result['models']) == 2
|
||||
|
||||
async def test_scan_provider_filter_by_model_type(self):
|
||||
"""Returns filtered models by type."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.llm_model_service = SimpleNamespace()
|
||||
ap.embedding_models_service = SimpleNamespace()
|
||||
|
||||
provider = _create_mock_provider(provider_uuid='filter-uuid')
|
||||
|
||||
mock_result = _create_mock_result([], first_item=provider)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'filter-uuid',
|
||||
'name': 'Filter Provider',
|
||||
'requester': 'openai',
|
||||
'base_url': 'https://api.openai.com',
|
||||
'api_keys': ['key'],
|
||||
}
|
||||
)
|
||||
|
||||
runtime_provider = Mock()
|
||||
runtime_provider.requester = Mock()
|
||||
runtime_provider.token_mgr = Mock()
|
||||
runtime_provider.token_mgr.get_token = Mock(return_value='token')
|
||||
runtime_provider.token_mgr.tokens = ['token']
|
||||
|
||||
async def mock_scan_models(token):
|
||||
return {
|
||||
'models': [
|
||||
{'id': 'gpt-4', 'name': 'GPT-4', 'type': 'llm'},
|
||||
{'id': 'text-embedding', 'name': 'Text Embedding', 'type': 'embedding'},
|
||||
],
|
||||
'debug': None,
|
||||
}
|
||||
|
||||
runtime_provider.requester.scan_models = AsyncMock(side_effect=mock_scan_models)
|
||||
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
|
||||
|
||||
ap.llm_model_service.get_llm_models_by_provider = AsyncMock(return_value=[])
|
||||
ap.embedding_models_service.get_embedding_models_by_provider = AsyncMock(return_value=[])
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute - filter for LLM only
|
||||
result = await service.scan_provider_models('filter-uuid', model_type='llm')
|
||||
|
||||
# Verify - only LLM models returned
|
||||
assert len(result['models']) == 1
|
||||
assert result['models'][0]['type'] == 'llm'
|
||||
|
||||
async def test_scan_provider_not_implemented_raises_error(self):
|
||||
"""Raises ValueError when scan not implemented."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
|
||||
provider = _create_mock_provider(provider_uuid='no-scan-uuid')
|
||||
|
||||
mock_result = _create_mock_result([], first_item=provider)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'no-scan-uuid',
|
||||
'name': 'No Scan Provider',
|
||||
'requester': 'custom',
|
||||
'base_url': 'https://custom.api.com',
|
||||
'api_keys': ['key'],
|
||||
}
|
||||
)
|
||||
|
||||
runtime_provider = Mock()
|
||||
runtime_provider.requester = Mock()
|
||||
runtime_provider.token_mgr = Mock()
|
||||
runtime_provider.token_mgr.get_token = Mock(return_value='token')
|
||||
runtime_provider.token_mgr.tokens = ['token']
|
||||
runtime_provider.requester.scan_models = AsyncMock(
|
||||
side_effect=NotImplementedError('scan not supported')
|
||||
)
|
||||
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='current provider does not support model scanning'):
|
||||
await service.scan_provider_models('no-scan-uuid')
|
||||
|
||||
async def test_scan_provider_marks_already_added_models(self):
|
||||
"""Marks models that are already added."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.llm_model_service = SimpleNamespace()
|
||||
ap.embedding_models_service = SimpleNamespace()
|
||||
|
||||
provider = _create_mock_provider(provider_uuid='already-added-uuid')
|
||||
|
||||
mock_result = _create_mock_result([], first_item=provider)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'already-added-uuid',
|
||||
'name': 'Already Added Provider',
|
||||
'requester': 'openai',
|
||||
'base_url': 'https://api.openai.com',
|
||||
'api_keys': ['key'],
|
||||
}
|
||||
)
|
||||
|
||||
runtime_provider = Mock()
|
||||
runtime_provider.requester = Mock()
|
||||
runtime_provider.token_mgr = Mock()
|
||||
runtime_provider.token_mgr.get_token = Mock(return_value='token')
|
||||
runtime_provider.token_mgr.tokens = ['token']
|
||||
|
||||
async def mock_scan_models(token):
|
||||
return {
|
||||
'models': [
|
||||
{'id': 'existing-model', 'name': 'Existing Model', 'type': 'llm'},
|
||||
{'id': 'new-model', 'name': 'New Model', 'type': 'llm'},
|
||||
],
|
||||
'debug': None,
|
||||
}
|
||||
|
||||
runtime_provider.requester.scan_models = AsyncMock(side_effect=mock_scan_models)
|
||||
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
|
||||
|
||||
# Mock existing LLM model
|
||||
ap.llm_model_service.get_llm_models_by_provider = AsyncMock(
|
||||
return_value=[{'name': 'Existing Model'}]
|
||||
)
|
||||
ap.embedding_models_service.get_embedding_models_by_provider = AsyncMock(return_value=[])
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.scan_provider_models('already-added-uuid')
|
||||
|
||||
# Verify - existing model marked as already_added
|
||||
existing_model = next(m for m in result['models'] if m['name'] == 'Existing Model')
|
||||
assert existing_model['already_added'] is True
|
||||
|
||||
new_model = next(m for m in result['models'] if m['name'] == 'New Model')
|
||||
assert new_model['already_added'] is False
|
||||
@@ -1,778 +0,0 @@
|
||||
"""
|
||||
Unit tests for SpaceService.
|
||||
|
||||
Tests LangBot Space API interactions including:
|
||||
- OAuth URL generation
|
||||
- Token exchange and refresh
|
||||
- User info retrieval
|
||||
- Credits caching
|
||||
- Model listing
|
||||
|
||||
Source: src/langbot/pkg/api/http/service/space.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock, patch, MagicMock
|
||||
from types import SimpleNamespace
|
||||
import datetime
|
||||
import time
|
||||
|
||||
from langbot.pkg.api.http.service.space import SpaceService
|
||||
from langbot.pkg.entity.persistence.user import User
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
def _create_mock_user(
|
||||
email: str = 'test@example.com',
|
||||
account_type: str = 'space',
|
||||
space_account_uuid: str = 'space-uuid-123',
|
||||
space_access_token: str = 'access_token_123',
|
||||
space_refresh_token: str = 'refresh_token_123',
|
||||
space_access_token_expires_at: datetime.datetime = None,
|
||||
) -> Mock:
|
||||
"""Helper to create mock User entity."""
|
||||
user = Mock(spec=User)
|
||||
user.user = email
|
||||
user.account_type = account_type
|
||||
user.space_account_uuid = space_account_uuid
|
||||
user.space_access_token = space_access_token
|
||||
user.space_refresh_token = space_refresh_token
|
||||
user.space_access_token_expires_at = space_access_token_expires_at
|
||||
return user
|
||||
|
||||
|
||||
def _create_mock_result(items: list = None, first_item=None):
|
||||
"""Create mock result object for persistence queries."""
|
||||
result = Mock()
|
||||
result.all = Mock(return_value=items or [])
|
||||
result.first = Mock(return_value=first_item)
|
||||
return result
|
||||
|
||||
|
||||
class TestSpaceServiceGetOAuthAuthorizeUrl:
|
||||
"""Tests for get_oauth_authorize_url method."""
|
||||
|
||||
def test_get_oauth_authorize_url_basic(self):
|
||||
"""Returns OAuth URL with redirect_uri."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'space': {
|
||||
'oauth_authorize_url': 'https://space.langbot.app/auth/authorize',
|
||||
}
|
||||
}
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Execute
|
||||
result = service.get_oauth_authorize_url('http://localhost/callback')
|
||||
|
||||
# Verify
|
||||
assert 'redirect_uri=http://localhost/callback' in result
|
||||
assert 'https://space.langbot.app/auth/authorize' in result
|
||||
|
||||
def test_get_oauth_authorize_url_with_state(self):
|
||||
"""Returns OAuth URL with redirect_uri and state."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'space': {
|
||||
'oauth_authorize_url': 'https://space.langbot.app/auth/authorize',
|
||||
}
|
||||
}
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Execute
|
||||
result = service.get_oauth_authorize_url('http://localhost/callback', state='random_state')
|
||||
|
||||
# Verify
|
||||
assert 'redirect_uri=http://localhost/callback' in result
|
||||
assert 'state=random_state' in result
|
||||
|
||||
def test_get_oauth_authorize_url_default_config(self):
|
||||
"""Uses default OAuth URL when config not set."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Execute
|
||||
result = service.get_oauth_authorize_url('http://localhost/callback')
|
||||
|
||||
# Verify - uses default URL
|
||||
assert 'https://space.langbot.app/auth/authorize' in result
|
||||
|
||||
|
||||
class TestSpaceServiceGetUserByEmail:
|
||||
"""Tests for _get_user_by_email internal method."""
|
||||
|
||||
async def test_get_user_by_email_found(self):
|
||||
"""Returns user when found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_user = _create_mock_user(email='found@example.com')
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service._get_user_by_email('found@example.com')
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result.user == 'found@example.com'
|
||||
|
||||
async def test_get_user_by_email_not_found(self):
|
||||
"""Returns None when user not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service._get_user_by_email('notfound@example.com')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestSpaceServiceEnsureValidToken:
|
||||
"""Tests for _ensure_valid_token internal method."""
|
||||
|
||||
async def test_ensure_valid_token_user_not_found(self):
|
||||
"""Returns None when user not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service._ensure_valid_token('notfound@example.com')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
async def test_ensure_valid_token_not_space_account(self):
|
||||
"""Returns None when user is not a space account."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_user = _create_mock_user(email='local@example.com', account_type='local')
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service._ensure_valid_token('local@example.com')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
async def test_ensure_valid_token_no_access_token(self):
|
||||
"""Returns None when user has no access token."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_user = _create_mock_user(space_access_token=None)
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service._ensure_valid_token('test@example.com')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
async def test_ensure_valid_token_valid_token(self):
|
||||
"""Returns valid access token when not expired."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
# Token expires in 1 hour (valid)
|
||||
mock_user = _create_mock_user(
|
||||
space_access_token='valid_token',
|
||||
space_access_token_expires_at=datetime.datetime.now() + datetime.timedelta(hours=1),
|
||||
)
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service._ensure_valid_token('test@example.com')
|
||||
|
||||
# Verify
|
||||
assert result == 'valid_token'
|
||||
|
||||
async def test_ensure_valid_token_expired_no_refresh(self):
|
||||
"""Returns None when token expired and no refresh token."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
# Token expired 1 hour ago
|
||||
mock_user = _create_mock_user(
|
||||
space_access_token='expired_token',
|
||||
space_refresh_token=None,
|
||||
space_access_token_expires_at=datetime.datetime.now() - datetime.timedelta(hours=1),
|
||||
)
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service._ensure_valid_token('test@example.com')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestSpaceServiceGetCredits:
|
||||
"""Tests for get_credits method."""
|
||||
|
||||
async def test_get_credits_no_user(self):
|
||||
"""Returns None when user not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_credits('notfound@example.com')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
async def test_get_credits_returns_cached_value(self):
|
||||
"""Returns cached credits without API call."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Pre-populate cache
|
||||
service._credits_cache = {'cached@example.com': (100, time.time())}
|
||||
|
||||
# Execute
|
||||
result = await service.get_credits('cached@example.com')
|
||||
|
||||
# Verify - returns cached value without API call
|
||||
assert result == 100
|
||||
|
||||
async def test_get_credits_cache_expired_refreshes(self):
|
||||
"""Refreshes expired cache."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_user = _create_mock_user(
|
||||
space_access_token='valid_token',
|
||||
space_access_token_expires_at=datetime.datetime.now() + datetime.timedelta(hours=1),
|
||||
)
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Pre-populate expired cache (70 seconds ago, past 60s TTL)
|
||||
service._credits_cache = {'test@example.com': (50, time.time() - 70)}
|
||||
|
||||
# Mock get_user_info to return new credits
|
||||
service.get_user_info = AsyncMock(return_value={'credits': 200})
|
||||
|
||||
# Execute
|
||||
result = await service.get_credits('test@example.com')
|
||||
|
||||
# Verify - cache was refreshed
|
||||
assert result == 200
|
||||
assert service._credits_cache['test@example.com'][0] == 200
|
||||
|
||||
async def test_get_credits_force_refresh(self):
|
||||
"""Force refresh ignores cache."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_user = _create_mock_user(
|
||||
space_access_token='valid_token',
|
||||
space_access_token_expires_at=datetime.datetime.now() + datetime.timedelta(hours=1),
|
||||
)
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Pre-populate cache
|
||||
service._credits_cache = {'test@example.com': (100, time.time())}
|
||||
|
||||
# Mock get_user_info to return new credits
|
||||
service.get_user_info = AsyncMock(return_value={'credits': 300})
|
||||
|
||||
# Execute with force_refresh=True
|
||||
result = await service.get_credits('test@example.com', force_refresh=True)
|
||||
|
||||
# Verify - fresh value returned
|
||||
assert result == 300
|
||||
|
||||
async def test_get_credits_returns_cached_on_exception(self):
|
||||
"""Returns cached fallback value when API fails."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_user = _create_mock_user(
|
||||
space_access_token='valid_token',
|
||||
space_access_token_expires_at=datetime.datetime.now() + datetime.timedelta(hours=1),
|
||||
)
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Pre-populate expired cache - will try to refresh and fail
|
||||
service._credits_cache = {'test@example.com': (150, time.time() - 70)}
|
||||
|
||||
# Mock get_user_info to raise exception
|
||||
service.get_user_info = AsyncMock(side_effect=Exception('API Error'))
|
||||
|
||||
# Execute - should return cached fallback value (even though expired)
|
||||
result = await service.get_credits('test@example.com')
|
||||
|
||||
# Verify - returns cached fallback value (150) because API failed
|
||||
assert result == 150
|
||||
|
||||
|
||||
class TestSpaceServiceRefreshToken:
|
||||
"""Tests for refresh_token method."""
|
||||
|
||||
async def test_refresh_token_success(self):
|
||||
"""Refreshes token successfully."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Mock HTTP response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'code': 0,
|
||||
'data': {
|
||||
'access_token': 'new_access_token',
|
||||
'refresh_token': 'new_refresh_token',
|
||||
'expires_in': 3600,
|
||||
}
|
||||
})
|
||||
|
||||
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.post = MagicMock(return_value=mock_response)
|
||||
mock_session.return_value = mock_session_obj
|
||||
|
||||
# Use async context manager mock
|
||||
mock_session_obj.post.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session_obj.post.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Execute
|
||||
result = await service.refresh_token('old_refresh_token')
|
||||
|
||||
# Verify
|
||||
assert result['access_token'] == 'new_access_token'
|
||||
|
||||
async def test_refresh_token_api_error(self):
|
||||
"""Raises ValueError on API error."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Mock HTTP response with error
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'code': 1,
|
||||
'msg': 'Invalid refresh token',
|
||||
})
|
||||
mock_response.text = AsyncMock(return_value='{"code":1,"msg":"Invalid refresh token"}')
|
||||
|
||||
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.post = MagicMock(return_value=mock_response)
|
||||
mock_session.return_value = mock_session_obj
|
||||
|
||||
mock_session_obj.post.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session_obj.post.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='Failed to refresh token'):
|
||||
await service.refresh_token('invalid_refresh_token')
|
||||
|
||||
async def test_refresh_token_http_error(self):
|
||||
"""Raises ValueError on HTTP error."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Mock HTTP response with error status
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 500
|
||||
mock_response.text = AsyncMock(return_value='Internal Server Error')
|
||||
|
||||
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.post = MagicMock(return_value=mock_response)
|
||||
mock_session.return_value = mock_session_obj
|
||||
|
||||
mock_session_obj.post.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session_obj.post.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='Failed to refresh token'):
|
||||
await service.refresh_token('refresh_token')
|
||||
|
||||
|
||||
class TestSpaceServiceExchangeOAuthCode:
|
||||
"""Tests for exchange_oauth_code method."""
|
||||
|
||||
async def test_exchange_oauth_code_success(self):
|
||||
"""Exchanges OAuth code successfully."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Mock HTTP response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'code': 0,
|
||||
'data': {
|
||||
'access_token': 'new_access_token',
|
||||
'refresh_token': 'new_refresh_token',
|
||||
'expires_in': 3600,
|
||||
}
|
||||
})
|
||||
|
||||
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.post = MagicMock(return_value=mock_response)
|
||||
mock_session.return_value = mock_session_obj
|
||||
|
||||
mock_session_obj.post.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session_obj.post.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Execute
|
||||
result = await service.exchange_oauth_code('auth_code')
|
||||
|
||||
# Verify
|
||||
assert result['access_token'] == 'new_access_token'
|
||||
|
||||
async def test_exchange_oauth_code_api_error(self):
|
||||
"""Raises ValueError on API error."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Mock HTTP response with error
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={'code': 1, 'msg': 'Invalid code'})
|
||||
mock_response.text = AsyncMock(return_value='{"code":1,"msg":"Invalid code"}')
|
||||
|
||||
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.post = MagicMock(return_value=mock_response)
|
||||
mock_session.return_value = mock_session_obj
|
||||
|
||||
mock_session_obj.post.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session_obj.post.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='Failed to exchange OAuth code'):
|
||||
await service.exchange_oauth_code('invalid_code')
|
||||
|
||||
|
||||
class TestSpaceServiceGetUserInfoRaw:
|
||||
"""Tests for get_user_info_raw method."""
|
||||
|
||||
async def test_get_user_info_raw_success(self):
|
||||
"""Gets user info successfully."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Mock HTTP response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'code': 0,
|
||||
'data': {
|
||||
'email': 'test@example.com',
|
||||
'credits': 100,
|
||||
}
|
||||
})
|
||||
|
||||
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.get = MagicMock(return_value=mock_response)
|
||||
mock_session.return_value = mock_session_obj
|
||||
|
||||
mock_session_obj.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session_obj.get.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Execute
|
||||
result = await service.get_user_info_raw('access_token')
|
||||
|
||||
# Verify
|
||||
assert result['email'] == 'test@example.com'
|
||||
assert result['credits'] == 100
|
||||
|
||||
async def test_get_user_info_raw_api_error(self):
|
||||
"""Raises ValueError on API error."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Mock HTTP response with error
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={'code': 1, 'msg': 'Unauthorized'})
|
||||
mock_response.text = AsyncMock(return_value='{"code":1,"msg":"Unauthorized"}')
|
||||
|
||||
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.get = MagicMock(return_value=mock_response)
|
||||
mock_session.return_value = mock_session_obj
|
||||
|
||||
mock_session_obj.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session_obj.get.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='Failed to get user info'):
|
||||
await service.get_user_info_raw('invalid_token')
|
||||
|
||||
|
||||
class TestSpaceServiceGetUserInfo:
|
||||
"""Tests for get_user_info method (with token validation)."""
|
||||
|
||||
async def test_get_user_info_no_token(self):
|
||||
"""Returns None when no valid token."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_user_info('notfound@example.com')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
async def test_get_user_info_with_valid_token(self):
|
||||
"""Returns user info with valid token."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_user = _create_mock_user(
|
||||
space_access_token='valid_token',
|
||||
space_access_token_expires_at=datetime.datetime.now() + datetime.timedelta(hours=1),
|
||||
)
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Mock get_user_info_raw
|
||||
service.get_user_info_raw = AsyncMock(return_value={'email': 'test@example.com', 'credits': 100})
|
||||
|
||||
# Execute
|
||||
result = await service.get_user_info('test@example.com')
|
||||
|
||||
# Verify
|
||||
assert result['email'] == 'test@example.com'
|
||||
|
||||
|
||||
class TestSpaceServiceGetModels:
|
||||
"""Tests for get_models method."""
|
||||
|
||||
async def test_get_models_success(self):
|
||||
"""Gets models successfully."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Mock HTTP response with proper model data matching SpaceModel schema
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'code': 0,
|
||||
'data': {
|
||||
'models': [
|
||||
{
|
||||
'uuid': 'uuid-1',
|
||||
'model_id': 'model-1',
|
||||
'provider': 'provider-1',
|
||||
'category': 'chat',
|
||||
'status': 'active',
|
||||
},
|
||||
{
|
||||
'uuid': 'uuid-2',
|
||||
'model_id': 'model-2',
|
||||
'provider': 'provider-2',
|
||||
'category': 'chat',
|
||||
'status': 'active',
|
||||
},
|
||||
]
|
||||
}
|
||||
})
|
||||
|
||||
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.get = MagicMock(return_value=mock_response)
|
||||
mock_session.return_value = mock_session_obj
|
||||
|
||||
mock_session_obj.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session_obj.get.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Execute
|
||||
result = await service.get_models()
|
||||
|
||||
# Verify
|
||||
assert len(result) == 2
|
||||
|
||||
async def test_get_models_api_error(self):
|
||||
"""Raises ValueError on API error."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Mock HTTP response with error
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={'code': 1, 'msg': 'Unauthorized'})
|
||||
mock_response.text = AsyncMock(return_value='{"code":1,"msg":"Unauthorized"}')
|
||||
|
||||
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.get = MagicMock(return_value=mock_response)
|
||||
mock_session.return_value = mock_session_obj
|
||||
|
||||
mock_session_obj.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session_obj.get.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='Failed to get models'):
|
||||
await service.get_models()
|
||||
|
||||
|
||||
class TestSpaceServiceCreditsCache:
|
||||
"""Tests for credits cache behavior."""
|
||||
|
||||
def test_credits_cache_initialized(self):
|
||||
"""Verify _credits_cache is initialized as empty dict."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Verify
|
||||
assert hasattr(service, '_credits_cache')
|
||||
assert service._credits_cache == {}
|
||||
|
||||
async def test_credits_cache_updates_on_success(self):
|
||||
"""Cache updates when get_credits succeeds."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_user = _create_mock_user(
|
||||
space_access_token='valid_token',
|
||||
space_access_token_expires_at=datetime.datetime.now() + datetime.timedelta(hours=1),
|
||||
)
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Mock get_user_info
|
||||
service.get_user_info = AsyncMock(return_value={'credits': 500})
|
||||
|
||||
# Execute
|
||||
result = await service.get_credits('test@example.com')
|
||||
|
||||
# Verify - cache updated
|
||||
assert result == 500
|
||||
assert 'test@example.com' in service._credits_cache
|
||||
assert service._credits_cache['test@example.com'][0] == 500
|
||||
@@ -1,608 +0,0 @@
|
||||
"""
|
||||
Unit tests for UserService.
|
||||
|
||||
Tests user management operations including:
|
||||
- User initialization check
|
||||
- Local user creation and authentication
|
||||
- JWT token generation and verification
|
||||
- Password management (reset, change, set)
|
||||
- Space account management
|
||||
|
||||
Source: src/langbot/pkg/api/http/service/user.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
from types import SimpleNamespace
|
||||
|
||||
from langbot.pkg.api.http.service.user import UserService
|
||||
from langbot.pkg.entity.persistence.user import User
|
||||
from langbot.pkg.entity.errors.account import AccountEmailMismatchError
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
def _create_mock_user(
|
||||
email: str = 'test@example.com',
|
||||
password: str = 'hashed_password',
|
||||
account_type: str = 'local',
|
||||
space_account_uuid: str = None,
|
||||
) -> Mock:
|
||||
"""Helper to create mock User entity."""
|
||||
user = Mock(spec=User)
|
||||
user.user = email
|
||||
user.password = password
|
||||
user.account_type = account_type
|
||||
user.space_account_uuid = space_account_uuid
|
||||
return user
|
||||
|
||||
|
||||
def _create_mock_result(items: list = None, first_item=None):
|
||||
"""Create mock result object for persistence queries."""
|
||||
result = Mock()
|
||||
result.all = Mock(return_value=items or [])
|
||||
result.first = Mock(return_value=first_item)
|
||||
return result
|
||||
|
||||
|
||||
class TestUserServiceIsInitialized:
|
||||
"""Tests for is_initialized method."""
|
||||
|
||||
async def test_is_initialized_returns_true_when_users_exist(self):
|
||||
"""Returns True when at least one user exists."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_user = _create_mock_user()
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.is_initialized()
|
||||
|
||||
# Verify
|
||||
assert result is True
|
||||
|
||||
async def test_is_initialized_returns_false_when_no_users(self):
|
||||
"""Returns False when no users exist."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.is_initialized()
|
||||
|
||||
# Verify
|
||||
assert result is False
|
||||
|
||||
async def test_is_initialized_returns_false_on_none_result(self):
|
||||
"""Returns False when result is None."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = Mock()
|
||||
mock_result.all = Mock(return_value=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.is_initialized()
|
||||
|
||||
# Verify
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestUserServiceGetUserByEmail:
|
||||
"""Tests for get_user_by_email method."""
|
||||
|
||||
async def test_get_user_by_email_found(self):
|
||||
"""Returns user when found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_user = _create_mock_user(email='found@example.com')
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_user_by_email('found@example.com')
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result.user == 'found@example.com'
|
||||
|
||||
async def test_get_user_by_email_not_found(self):
|
||||
"""Returns None when user not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_user_by_email('notfound@example.com')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
async def test_get_user_by_email_empty_string(self):
|
||||
"""Handles empty email string."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_user_by_email('')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestUserServiceGetUserBySpaceAccountUuid:
|
||||
"""Tests for get_user_by_space_account_uuid method."""
|
||||
|
||||
async def test_get_user_by_space_uuid_found(self):
|
||||
"""Returns user when Space UUID found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_user = _create_mock_user(
|
||||
email='space@example.com',
|
||||
account_type='space',
|
||||
space_account_uuid='space-uuid-123',
|
||||
)
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_user_by_space_account_uuid('space-uuid-123')
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result.space_account_uuid == 'space-uuid-123'
|
||||
|
||||
async def test_get_user_by_space_uuid_not_found(self):
|
||||
"""Returns None when Space UUID not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_user_by_space_account_uuid('nonexistent-uuid')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestUserServiceAuthenticate:
|
||||
"""Tests for authenticate method."""
|
||||
|
||||
async def test_authenticate_user_not_found_raises_error(self):
|
||||
"""Raises ValueError when user not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'system': {'jwt': {'secret': 'test_secret', 'expire': 3600}}}
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='用户不存在'):
|
||||
await service.authenticate('nonexistent@example.com', 'password')
|
||||
|
||||
async def test_authenticate_space_user_without_password_raises_error(self):
|
||||
"""Raises ValueError for Space user without local password."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
# Space user has empty password
|
||||
mock_user = _create_mock_user(
|
||||
email='space@example.com',
|
||||
password='', # Empty password for Space user
|
||||
account_type='space',
|
||||
)
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='请使用 Space 账户登录'):
|
||||
await service.authenticate('space@example.com', 'password')
|
||||
|
||||
|
||||
class TestUserServiceGenerateJwtToken:
|
||||
"""Tests for generate_jwt_token method."""
|
||||
|
||||
async def test_generate_jwt_token_returns_valid_token(self):
|
||||
"""Generates valid JWT token."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'system': {'jwt': {'secret': 'test_secret', 'expire': 3600}}}
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute
|
||||
token = await service.generate_jwt_token('test@example.com')
|
||||
|
||||
# Verify - JWT format (base64 encoded parts)
|
||||
assert token is not None
|
||||
assert len(token) > 0
|
||||
parts = token.split('.')
|
||||
assert len(parts) == 3 # JWT has 3 parts
|
||||
|
||||
async def test_generate_jwt_token_custom_expire(self):
|
||||
"""Generates token with custom expiry."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'system': {'jwt': {'secret': 'test_secret', 'expire': 7200}}}
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute
|
||||
token = await service.generate_jwt_token('test@example.com')
|
||||
|
||||
# Verify
|
||||
assert token is not None
|
||||
|
||||
|
||||
class TestUserServiceVerifyJwtToken:
|
||||
"""Tests for verify_jwt_token method."""
|
||||
|
||||
async def test_verify_jwt_token_valid(self):
|
||||
"""Verifies valid JWT token and returns user email."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'system': {'jwt': {'secret': 'test_secret', 'expire': 3600}}}
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# First generate a valid token
|
||||
token = await service.generate_jwt_token('verify@example.com')
|
||||
|
||||
# Execute
|
||||
user_email = await service.verify_jwt_token(token)
|
||||
|
||||
# Verify
|
||||
assert user_email == 'verify@example.com'
|
||||
|
||||
async def test_verify_jwt_token_invalid_raises_error(self):
|
||||
"""Raises error for invalid JWT token."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'system': {'jwt': {'secret': 'test_secret', 'expire': 3600}}}
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute & Verify - invalid token should raise JWT error
|
||||
with pytest.raises(Exception): # jwt.DecodeError or similar
|
||||
await service.verify_jwt_token('invalid.token.here')
|
||||
|
||||
|
||||
class TestUserServiceResetPassword:
|
||||
"""Tests for reset_password method."""
|
||||
|
||||
async def test_reset_password_updates_password(self):
|
||||
"""Updates user password."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute
|
||||
await service.reset_password('test@example.com', 'new_password')
|
||||
|
||||
# Verify - execute_async was called with update
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
|
||||
class TestUserServiceChangePassword:
|
||||
"""Tests for change_password method."""
|
||||
|
||||
async def test_change_password_user_not_found_raises_error(self):
|
||||
"""Raises ValueError when user not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Mock get_user_by_email to return None
|
||||
service.get_user_by_email = AsyncMock(return_value=None)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='User not found'):
|
||||
await service.change_password('nonexistent@example.com', 'current', 'new')
|
||||
|
||||
async def test_change_password_no_local_password_raises_error(self):
|
||||
"""Raises ValueError when user has no local password set."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Mock user without password
|
||||
mock_user = _create_mock_user(email='nopass@example.com', password=None)
|
||||
service.get_user_by_email = AsyncMock(return_value=mock_user)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='No local password set'):
|
||||
await service.change_password('nopass@example.com', 'current', 'new')
|
||||
|
||||
|
||||
class TestUserServiceGetFirstUser:
|
||||
"""Tests for get_first_user method."""
|
||||
|
||||
async def test_get_first_user_found(self):
|
||||
"""Returns first user when exists."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_user = _create_mock_user(email='first@example.com')
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_first_user()
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result.user == 'first@example.com'
|
||||
|
||||
async def test_get_first_user_not_found(self):
|
||||
"""Returns None when no users exist."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_first_user()
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestUserServiceSetPassword:
|
||||
"""Tests for set_password method."""
|
||||
|
||||
async def test_set_password_user_not_found_raises_error(self):
|
||||
"""Raises ValueError when user not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Mock get_user_by_email to return None
|
||||
service.get_user_by_email = AsyncMock(return_value=None)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='User not found'):
|
||||
await service.set_password('nonexistent@example.com', 'new_password')
|
||||
|
||||
async def test_set_password_with_existing_password_requires_current(self):
|
||||
"""Requires current password when user has existing password."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Mock user with existing password
|
||||
mock_user = _create_mock_user(email='haspass@example.com', password='hashed_old_password')
|
||||
service.get_user_by_email = AsyncMock(return_value=mock_user)
|
||||
|
||||
# Execute & Verify - should raise when no current_password provided
|
||||
with pytest.raises(ValueError, match='Current password is required'):
|
||||
await service.set_password('haspass@example.com', 'new_password')
|
||||
|
||||
|
||||
class TestUserServiceCreateOrUpdateSpaceUser:
|
||||
"""Tests for create_or_update_space_user method."""
|
||||
|
||||
async def test_create_or_update_existing_space_user(self):
|
||||
"""Updates existing Space user tokens."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.provider_service = SimpleNamespace()
|
||||
ap.provider_service.update_space_model_provider_api_keys = AsyncMock()
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Mock existing Space user
|
||||
existing_user = _create_mock_user(
|
||||
email='space@example.com',
|
||||
account_type='space',
|
||||
space_account_uuid='existing-space-uuid',
|
||||
)
|
||||
service.get_user_by_space_account_uuid = AsyncMock(return_value=existing_user)
|
||||
service.get_user_by_email = AsyncMock(return_value=None)
|
||||
service.is_initialized = AsyncMock(return_value=True)
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
# Execute
|
||||
updated_user = await service.create_or_update_space_user(
|
||||
space_account_uuid='existing-space-uuid',
|
||||
email='space@example.com',
|
||||
access_token='new_access_token',
|
||||
refresh_token='new_refresh_token',
|
||||
api_key='new_api_key',
|
||||
expires_in=3600,
|
||||
)
|
||||
|
||||
# Verify - update was called and user returned
|
||||
ap.persistence_mgr.execute_async.assert_called()
|
||||
assert updated_user.space_account_uuid == 'existing-space-uuid'
|
||||
|
||||
async def test_create_or_update_new_space_user_first_init(self):
|
||||
"""Creates new Space user on first initialization."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.provider_service = SimpleNamespace()
|
||||
ap.provider_service.update_space_model_provider_api_keys = AsyncMock()
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Mock new user to be returned after creation
|
||||
new_user = _create_mock_user(
|
||||
email='newspace@example.com',
|
||||
account_type='space',
|
||||
space_account_uuid='new-space-uuid',
|
||||
)
|
||||
|
||||
# First call (line 138) returns None, second call (line 194) returns new_user
|
||||
call_count = 0
|
||||
async def mock_get_by_space_uuid(uuid):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1: # First check for existing user
|
||||
return None
|
||||
return new_user # After insert, return the new user
|
||||
|
||||
service.get_user_by_space_account_uuid = AsyncMock(side_effect=mock_get_by_space_uuid)
|
||||
service.get_user_by_email = AsyncMock(return_value=None)
|
||||
service.is_initialized = AsyncMock(return_value=False) # Not initialized
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
# Execute
|
||||
result = await service.create_or_update_space_user(
|
||||
space_account_uuid='new-space-uuid',
|
||||
email='newspace@example.com',
|
||||
access_token='access_token',
|
||||
refresh_token='refresh_token',
|
||||
api_key='api_key',
|
||||
expires_in=3600,
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert result.space_account_uuid == 'new-space-uuid'
|
||||
|
||||
async def test_create_or_update_space_user_already_initialized_raises_error(self):
|
||||
"""Raises AccountEmailMismatchError when system already initialized and user not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.provider_service = SimpleNamespace()
|
||||
ap.provider_service.update_space_model_provider_api_keys = AsyncMock()
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Mock system already initialized, no matching users
|
||||
service.get_user_by_space_account_uuid = AsyncMock(return_value=None)
|
||||
service.get_user_by_email = AsyncMock(return_value=None)
|
||||
service.is_initialized = AsyncMock(return_value=True) # Already initialized
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(AccountEmailMismatchError):
|
||||
await service.create_or_update_space_user(
|
||||
space_account_uuid='unknown-space-uuid',
|
||||
email='unknown@example.com',
|
||||
access_token='token',
|
||||
refresh_token='refresh',
|
||||
api_key='key',
|
||||
expires_in=3600,
|
||||
)
|
||||
|
||||
async def test_create_or_update_space_user_no_expiry(self):
|
||||
"""Creates Space user without token expiry."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.provider_service = SimpleNamespace()
|
||||
ap.provider_service.update_space_model_provider_api_keys = AsyncMock()
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
new_user = _create_mock_user(
|
||||
email='noexpiry@example.com',
|
||||
account_type='space',
|
||||
space_account_uuid='noexpiry-uuid',
|
||||
)
|
||||
|
||||
# First call (line 138) returns None, second call (line 194) returns new_user
|
||||
call_count = 0
|
||||
async def mock_get_by_space_uuid(uuid):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1: # First check for existing user
|
||||
return None
|
||||
return new_user # After insert, return the new user
|
||||
|
||||
service.get_user_by_space_account_uuid = AsyncMock(side_effect=mock_get_by_space_uuid)
|
||||
service.get_user_by_email = AsyncMock(return_value=None)
|
||||
service.is_initialized = AsyncMock(return_value=False)
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
# Execute with expires_in=0 (no expiry)
|
||||
result = await service.create_or_update_space_user(
|
||||
space_account_uuid='noexpiry-uuid',
|
||||
email='noexpiry@example.com',
|
||||
access_token='token',
|
||||
refresh_token='refresh',
|
||||
api_key='key',
|
||||
expires_in=0, # No expiry
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result.space_account_uuid == 'noexpiry-uuid'
|
||||
|
||||
|
||||
class TestUserServiceCreateUserLock:
|
||||
"""Tests for create_user_lock attribute."""
|
||||
|
||||
def test_create_user_lock_initialized(self):
|
||||
"""Verify create_user_lock is initialized as asyncio.Lock."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Verify lock exists
|
||||
assert hasattr(service, '_create_user_lock')
|
||||
assert service._create_user_lock is not None
|
||||
@@ -1,506 +0,0 @@
|
||||
"""
|
||||
Unit tests for WebhookService.
|
||||
|
||||
Tests webhook CRUD operations including:
|
||||
- Webhook listing
|
||||
- Webhook creation
|
||||
- Webhook retrieval by ID
|
||||
- Webhook updates
|
||||
- Webhook deletion
|
||||
- Enabled webhooks filtering
|
||||
|
||||
Source: src/langbot/pkg/api/http/service/webhook.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
from types import SimpleNamespace
|
||||
|
||||
from langbot.pkg.api.http.service.webhook import WebhookService
|
||||
from langbot.pkg.entity.persistence.webhook import Webhook
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
def _create_mock_webhook(
|
||||
webhook_id: int = 1,
|
||||
name: str = 'Test Webhook',
|
||||
url: str = 'http://example.com/webhook',
|
||||
description: str = 'Test Description',
|
||||
enabled: bool = True,
|
||||
) -> Mock:
|
||||
"""Helper to create mock Webhook entity."""
|
||||
webhook = Mock(spec=Webhook)
|
||||
webhook.id = webhook_id
|
||||
webhook.name = name
|
||||
webhook.url = url
|
||||
webhook.description = description
|
||||
webhook.enabled = enabled
|
||||
return webhook
|
||||
|
||||
|
||||
def _create_mock_result(items: list = None, first_item=None):
|
||||
"""Create mock result object for persistence queries."""
|
||||
result = Mock()
|
||||
result.all = Mock(return_value=items or [])
|
||||
result.first = Mock(return_value=first_item)
|
||||
return result
|
||||
|
||||
|
||||
class TestWebhookServiceGetWebhooks:
|
||||
"""Tests for get_webhooks method."""
|
||||
|
||||
async def test_get_webhooks_empty_list(self):
|
||||
"""Returns empty list when no webhooks exist."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'id': entity.id,
|
||||
'name': entity.name,
|
||||
'url': entity.url,
|
||||
}
|
||||
)
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_webhooks()
|
||||
|
||||
# Verify
|
||||
assert result == []
|
||||
|
||||
async def test_get_webhooks_returns_serialized_list(self):
|
||||
"""Returns serialized list of webhooks."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
webhook1 = _create_mock_webhook(webhook_id=1, name='Webhook 1')
|
||||
webhook2 = _create_mock_webhook(webhook_id=2, name='Webhook 2')
|
||||
|
||||
mock_result = _create_mock_result([webhook1, webhook2])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'id': entity.id,
|
||||
'name': entity.name,
|
||||
'url': entity.url,
|
||||
'description': entity.description,
|
||||
'enabled': entity.enabled,
|
||||
}
|
||||
)
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_webhooks()
|
||||
|
||||
# Verify
|
||||
assert len(result) == 2
|
||||
assert result[0]['name'] == 'Webhook 1'
|
||||
assert result[1]['name'] == 'Webhook 2'
|
||||
|
||||
|
||||
class TestWebhookServiceCreateWebhook:
|
||||
"""Tests for create_webhook method."""
|
||||
|
||||
async def test_create_webhook_full_params(self):
|
||||
"""Creates webhook with all parameters."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
# Mock insert result
|
||||
insert_result = Mock()
|
||||
|
||||
# Mock select result for retrieving created webhook
|
||||
created_webhook = _create_mock_webhook(
|
||||
webhook_id=1,
|
||||
name='New Webhook',
|
||||
url='http://new.example.com/webhook',
|
||||
description='New Description',
|
||||
enabled=True,
|
||||
)
|
||||
select_result = _create_mock_result(first_item=created_webhook)
|
||||
|
||||
# execute_async returns different results
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return insert_result # Insert
|
||||
return select_result # Select
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'id': 1,
|
||||
'name': 'New Webhook',
|
||||
'url': 'http://new.example.com/webhook',
|
||||
'description': 'New Description',
|
||||
'enabled': True,
|
||||
}
|
||||
)
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.create_webhook(
|
||||
name='New Webhook',
|
||||
url='http://new.example.com/webhook',
|
||||
description='New Description',
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert result['name'] == 'New Webhook'
|
||||
assert result['url'] == 'http://new.example.com/webhook'
|
||||
assert result['description'] == 'New Description'
|
||||
assert result['enabled'] is True
|
||||
|
||||
async def test_create_webhook_defaults(self):
|
||||
"""Creates webhook with default description and enabled."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
created_webhook = _create_mock_webhook(
|
||||
webhook_id=1,
|
||||
name='Minimal Webhook',
|
||||
url='http://minimal.example.com',
|
||||
description='', # Default
|
||||
enabled=True, # Default
|
||||
)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return Mock() # Insert
|
||||
return _create_mock_result(first_item=created_webhook)
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'id': 1,
|
||||
'name': 'Minimal Webhook',
|
||||
'url': 'http://minimal.example.com',
|
||||
'description': '',
|
||||
'enabled': True,
|
||||
}
|
||||
)
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute - only name and url required
|
||||
result = await service.create_webhook(name='Minimal Webhook', url='http://minimal.example.com')
|
||||
|
||||
# Verify defaults
|
||||
assert result['description'] == ''
|
||||
assert result['enabled'] is True
|
||||
|
||||
async def test_create_webhook_disabled(self):
|
||||
"""Creates webhook with enabled=False."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
created_webhook = _create_mock_webhook(webhook_id=1, enabled=False)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return Mock()
|
||||
return _create_mock_result(first_item=created_webhook)
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'id': 1, 'enabled': False}
|
||||
)
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.create_webhook(name='Disabled', url='http://disabled.com', enabled=False)
|
||||
|
||||
# Verify
|
||||
assert result['enabled'] is False
|
||||
|
||||
|
||||
class TestWebhookServiceGetWebhook:
|
||||
"""Tests for get_webhook method."""
|
||||
|
||||
async def test_get_webhook_by_id_found(self):
|
||||
"""Returns webhook when found by ID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
webhook = _create_mock_webhook(webhook_id=1, name='Found Webhook')
|
||||
mock_result = _create_mock_result(first_item=webhook)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'id': 1,
|
||||
'name': 'Found Webhook',
|
||||
'url': 'http://example.com/webhook',
|
||||
}
|
||||
)
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_webhook(1)
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result['id'] == 1
|
||||
assert result['name'] == 'Found Webhook'
|
||||
|
||||
async def test_get_webhook_by_id_not_found(self):
|
||||
"""Returns None when webhook not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result(first_item=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_webhook(999)
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
async def test_get_webhook_by_id_zero(self):
|
||||
"""Handles ID=0 (edge case) correctly."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result(first_item=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_webhook(0)
|
||||
|
||||
# Verify - should return None (no webhook with ID 0)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestWebhookServiceUpdateWebhook:
|
||||
"""Tests for update_webhook method."""
|
||||
|
||||
async def test_update_webhook_name_only(self):
|
||||
"""Updates only the name field."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
await service.update_webhook(1, name='Updated Name')
|
||||
|
||||
# Verify
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
async def test_update_webhook_url_only(self):
|
||||
"""Updates only the url field."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
await service.update_webhook(1, url='http://updated.example.com')
|
||||
|
||||
# Verify
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
async def test_update_webhook_description_only(self):
|
||||
"""Updates only the description field."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
await service.update_webhook(1, description='Updated description')
|
||||
|
||||
# Verify
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
async def test_update_webhook_enabled_only(self):
|
||||
"""Updates only the enabled field."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
await service.update_webhook(1, enabled=False)
|
||||
|
||||
# Verify
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
async def test_update_webhook_all_fields(self):
|
||||
"""Updates all fields at once."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
await service.update_webhook(
|
||||
1,
|
||||
name='All Updated',
|
||||
url='http://all.updated.com',
|
||||
description='All updated description',
|
||||
enabled=False,
|
||||
)
|
||||
|
||||
# Verify
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
async def test_update_webhook_no_fields(self):
|
||||
"""Does nothing when no fields provided."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute - no update parameters
|
||||
await service.update_webhook(1)
|
||||
|
||||
# Verify - no execute call since no update_data
|
||||
ap.persistence_mgr.execute_async.assert_not_called()
|
||||
|
||||
|
||||
class TestWebhookServiceDeleteWebhook:
|
||||
"""Tests for delete_webhook method."""
|
||||
|
||||
async def test_delete_webhook_by_id(self):
|
||||
"""Deletes webhook by ID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
await service.delete_webhook(1)
|
||||
|
||||
# Verify
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
async def test_delete_webhook_nonexistent_id(self):
|
||||
"""Delete operation completes even for nonexistent ID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute - should not raise
|
||||
await service.delete_webhook(999)
|
||||
|
||||
# Verify - still called
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
|
||||
class TestWebhookServiceGetEnabledWebhooks:
|
||||
"""Tests for get_enabled_webhooks method."""
|
||||
|
||||
async def test_get_enabled_webhooks_empty(self):
|
||||
"""Returns empty list when no enabled webhooks."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(return_value={})
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_enabled_webhooks()
|
||||
|
||||
# Verify
|
||||
assert result == []
|
||||
|
||||
async def test_get_enabled_webhooks_filters_enabled(self):
|
||||
"""Returns only enabled webhooks."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
# All returned webhooks should be enabled (SQL filter)
|
||||
webhook1 = _create_mock_webhook(webhook_id=1, name='Enabled 1', enabled=True)
|
||||
webhook2 = _create_mock_webhook(webhook_id=2, name='Enabled 2', enabled=True)
|
||||
|
||||
mock_result = _create_mock_result([webhook1, webhook2])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'id': entity.id,
|
||||
'name': entity.name,
|
||||
'enabled': entity.enabled,
|
||||
}
|
||||
)
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_enabled_webhooks()
|
||||
|
||||
# Verify
|
||||
assert len(result) == 2
|
||||
assert all(w['enabled'] for w in result)
|
||||
|
||||
async def test_get_enabled_webhooks_filters_disabled(self):
|
||||
"""Does not return disabled webhooks."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
# Empty result because query filters on enabled=True
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(return_value={})
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_enabled_webhooks()
|
||||
|
||||
# Verify - should be empty (SQL would filter disabled)
|
||||
assert result == []
|
||||
@@ -1,40 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from langbot.pkg.api.http.service.apikey import ApiKeyService
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize('api_key', [None, 123, b'lbk_bytes', '', 'plain_key', ' LBK_bad', 'sk-lbk_fake'])
|
||||
async def test_verify_api_key_rejects_non_lbk_keys_without_db_query(api_key):
|
||||
persistence_mgr = SimpleNamespace(execute_async=AsyncMock())
|
||||
service = ApiKeyService(SimpleNamespace(persistence_mgr=persistence_mgr))
|
||||
|
||||
result = await service.verify_api_key(api_key)
|
||||
|
||||
assert result is False
|
||||
persistence_mgr.execute_async.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
('db_row', 'expected'),
|
||||
[
|
||||
(object(), True),
|
||||
(None, False),
|
||||
],
|
||||
)
|
||||
async def test_verify_api_key_keeps_db_validation_for_lbk_keys(db_row, expected):
|
||||
query_result = Mock()
|
||||
query_result.first.return_value = db_row
|
||||
persistence_mgr = SimpleNamespace(execute_async=AsyncMock(return_value=query_result))
|
||||
service = ApiKeyService(SimpleNamespace(persistence_mgr=persistence_mgr))
|
||||
|
||||
result = await service.verify_api_key('lbk_valid_format')
|
||||
|
||||
assert result is expected
|
||||
persistence_mgr.execute_async.assert_awaited_once()
|
||||
@@ -1 +0,0 @@
|
||||
# Unit tests for command module
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user