Compare commits

..

2 Commits

Author SHA1 Message Date
Junyan Qin
2912eec7f5 Add OSS and commercial workspace boundaries 2026-05-08 17:29:22 +08:00
Junyan Qin
158503880c Document multi-tenant workspace architecture 2026-05-08 17:04:53 +08:00
224 changed files with 5604 additions and 43905 deletions

View File

@@ -4,29 +4,25 @@ on:
pull_request:
types: [opened, ready_for_review, synchronize]
paths:
- 'src/langbot/**'
- 'pkg/**'
- 'tests/**'
- '.github/workflows/run-tests.yml'
- 'pyproject.toml'
- 'uv.lock'
- 'run_tests.sh'
- 'scripts/test-*.sh'
push:
branches:
- master
- develop
paths:
- 'src/langbot/**'
- 'pkg/**'
- 'tests/**'
- '.github/workflows/run-tests.yml'
- 'pyproject.toml'
- 'uv.lock'
- 'run_tests.sh'
- 'scripts/test-*.sh'
jobs:
test:
name: Unit Tests
name: Run Unit Tests
runs-on: ubuntu-latest
strategy:
matrix:
@@ -43,13 +39,28 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install uv
uses: astral-sh/setup-uv@v4
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
- name: Install dependencies
run: uv sync --dev
run: |
uv sync --dev
- name: Run unit + smoke tests
run: uv run pytest tests/unit_tests/ tests/smoke/ -q --tb=short
- name: Run unit tests
run: |
bash run_tests.sh
- name: Upload coverage to Codecov
if: matrix.python-version == '3.12'
uses: codecov/codecov-action@v5
with:
files: ./coverage.xml
flags: unit-tests
name: unit-tests-coverage
fail_ci_if_error: false
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
- name: Test Summary
if: always()
@@ -58,79 +69,3 @@ jobs:
echo "" >> $GITHUB_STEP_SUMMARY
echo "Python Version: ${{ matrix.python-version }}" >> $GITHUB_STEP_SUMMARY
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
integration:
name: Fast Integration Tests
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install dependencies
run: uv sync --dev
- name: Run fast integration tests
run: uv run pytest tests/integration/ -m "not slow" -q --tb=short
- name: Integration Test Summary
if: always()
run: |
echo "## Integration Tests Results" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
coverage:
name: Coverage Gate
runs-on: ubuntu-latest
needs: [test, integration]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install dependencies
run: uv sync --dev
- name: Run coverage (unit + smoke)
run: |
uv run pytest tests/unit_tests/ tests/smoke/ \
--cov=langbot \
--cov-report=xml \
--cov-report=term-missing \
--cov-fail-under=18 \
-q --tb=short
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
files: ./coverage.xml
flags: unit-tests
name: coverage-report
fail_ci_if_error: false
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
- name: Coverage Summary
if: always()
run: |
echo "## Coverage Results" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "Threshold: 18%" >> $GITHUB_STEP_SUMMARY
echo "Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY

View File

@@ -9,13 +9,11 @@ on:
paths:
- 'src/langbot/pkg/persistence/**'
- 'src/langbot/pkg/entity/persistence/**'
- 'tests/integration/persistence/**'
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- 'src/langbot/pkg/persistence/**'
- 'src/langbot/pkg/entity/persistence/**'
- 'tests/integration/persistence/**'
jobs:
test-migrations-sqlite:
@@ -36,8 +34,52 @@ jobs:
- name: Install dependencies
run: uv sync --dev
- name: Run SQLite migration tests
run: uv run pytest tests/integration/persistence/test_migrations.py -q --tb=short
- name: Test Alembic upgrade (SQLite)
run: |
uv run python -c "
import asyncio
from sqlalchemy.ext.asyncio import create_async_engine
from langbot.pkg.entity.persistence.base import Base
from langbot.pkg.persistence.alembic_runner import run_alembic_upgrade, run_alembic_stamp, get_alembic_current
async def main():
engine = create_async_engine('sqlite+aiosqlite:///test_migrations.db')
# Create all tables (simulates existing DB)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Stamp baseline
await run_alembic_stamp(engine, '0001_baseline')
rev = await get_alembic_current(engine)
assert rev == '0001_baseline', f'Expected 0001_baseline, got {rev}'
print(f'Stamped: {rev}')
# Upgrade to head
await run_alembic_upgrade(engine, 'head')
rev = await get_alembic_current(engine)
print(f'After upgrade: {rev}')
assert rev is not None, 'Expected a revision after upgrade'
# Verify idempotent
await run_alembic_upgrade(engine, 'head')
rev2 = await get_alembic_current(engine)
assert rev2 == rev, f'Expected {rev}, got {rev2}'
print(f'Idempotent check passed: {rev2}')
# Fresh DB: upgrade from scratch
engine2 = create_async_engine('sqlite+aiosqlite:///test_migrations_fresh.db')
async with engine2.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
await run_alembic_upgrade(engine2, 'head')
rev3 = await get_alembic_current(engine2)
print(f'Fresh DB upgrade: {rev3}')
assert rev3 is not None
print('All SQLite migration tests passed!')
asyncio.run(main())
"
test-migrations-postgres:
name: Migrations (PostgreSQL)
@@ -72,7 +114,58 @@ jobs:
- name: Install dependencies
run: uv sync --dev
- name: Run PostgreSQL migration tests
env:
TEST_POSTGRES_URL: postgresql+asyncpg://langbot:langbot@localhost:5432/langbot_test
run: uv run pytest tests/integration/persistence/test_migrations_postgres.py -q --tb=short
- name: Test Alembic upgrade (PostgreSQL)
run: |
uv run python -c "
import asyncio
from sqlalchemy.ext.asyncio import create_async_engine
from langbot.pkg.entity.persistence.base import Base
from langbot.pkg.persistence.alembic_runner import run_alembic_upgrade, run_alembic_stamp, get_alembic_current
DB_URL = 'postgresql+asyncpg://langbot:langbot@localhost:5432/langbot_test'
async def main():
engine = create_async_engine(DB_URL)
# Create all tables
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Stamp baseline
await run_alembic_stamp(engine, '0001_baseline')
rev = await get_alembic_current(engine)
assert rev == '0001_baseline', f'Expected 0001_baseline, got {rev}'
print(f'Stamped: {rev}')
# Upgrade to head
await run_alembic_upgrade(engine, 'head')
rev = await get_alembic_current(engine)
print(f'After upgrade: {rev}')
assert rev is not None
# Verify idempotent
await run_alembic_upgrade(engine, 'head')
rev2 = await get_alembic_current(engine)
assert rev2 == rev, f'Expected {rev}, got {rev2}'
print(f'Idempotent check passed: {rev2}')
# Fresh DB: drop all and upgrade from scratch
engine2 = create_async_engine(DB_URL.replace('langbot_test', 'langbot_fresh'))
# Create fresh database
from sqlalchemy import text
async with engine.connect() as conn:
await conn.execute(text('COMMIT'))
await conn.execute(text('CREATE DATABASE langbot_fresh'))
async with engine2.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
await run_alembic_upgrade(engine2, 'head')
rev3 = await get_alembic_current(engine2)
print(f'Fresh DB upgrade: {rev3}')
assert rev3 is not None
print('All PostgreSQL migration tests passed!')
asyncio.run(main())
"

View File

@@ -1,36 +0,0 @@
# LangBot Makefile
# Quick developer commands
.PHONY: test test-quick test-integration-fast test-coverage test-all-local lint
# Run all tests (full suite with coverage)
test:
bash run_tests.sh
# Quick self-test for developers (lint + unit + smoke, no real credentials needed)
test-quick:
bash scripts/test-quick.sh
# Fast integration tests (SQLite/API/Pipeline, no external services)
test-integration-fast:
bash scripts/test-integration-fast.sh
# Coverage gate (all tests, enforces minimum threshold)
test-coverage:
bash scripts/test-coverage.sh
# Full local quality gate (quick + integration + coverage)
test-all-local:
bash scripts/test-quick.sh
bash scripts/test-integration-fast.sh
bash scripts/test-coverage.sh
# Run linting only
lint:
ruff check src/langbot/ tests/
ruff format --check src/langbot/ tests/
# Fix linting issues
lint-fix:
ruff check --fix src/langbot/ tests/
ruff format src/langbot/ tests/

View File

@@ -47,8 +47,6 @@ LangBot is an **open-source, production-grade platform** for building AI-powered
[→ Learn more about all features](https://link.langbot.app/en/docs/features)
📍 Practical guides: [deploy a multi-platform AI bot in 5 minutes](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [connect DeepSeek to WeChat, Discord, and Telegram](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [run a Dify Agent in Discord, Telegram, and Slack](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/), and [build an n8n-powered chatbot](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
---
## Quick Start

View File

@@ -47,8 +47,6 @@ LangBot 是一个**开源的生产级平台**,用于构建 AI 驱动的即时
[→ 了解更多功能特性](https://link.langbot.app/zh/docs/features)
📍 实践指南:[5 分钟部署多平台 AI 机器人](https://blog.langbot.app/zh/blog/deploy-ai-bot-in-5-minutes/)、[将 DeepSeek 接入微信、企业微信与 Discord](https://blog.langbot.app/zh/blog/connect-deepseek-to-wechat/)、[让 Dify Agent 跑在 Discord、Telegram 和 Slack 上](https://blog.langbot.app/zh/blog/dify-agent-discord-telegram-slack/),以及[用 n8n 构建多平台 AI 聊天机器人](https://blog.langbot.app/zh/blog/n8n-multi-platform-ai-chatbot/)。
---
## 快速开始

View File

@@ -46,8 +46,6 @@ LangBot es una **plataforma de código abierto y grado de producción** para con
[→ Conocer más sobre todas las funcionalidades](https://link.langbot.app/en/docs/features)
📍 Guías prácticas: [desplegar un bot de IA multiplataforma en 5 minutos](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [conectar DeepSeek a WeChat, Discord y Telegram](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [ejecutar un Dify Agent en Discord, Telegram y Slack](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/) y [crear un chatbot con n8n](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
---
## Inicio Rápido

View File

@@ -46,8 +46,6 @@ LangBot est une **plateforme open-source de niveau production** pour créer des
[→ En savoir plus sur toutes les fonctionnalités](https://link.langbot.app/en/docs/features)
📍 Guides pratiques : [déployer un bot IA multiplateforme en 5 minutes](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [connecter DeepSeek à WeChat, Discord et Telegram](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [exécuter un Dify Agent dans Discord, Telegram et Slack](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/) et [créer un chatbot avec n8n](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
---
## Démarrage Rapide

View File

@@ -46,8 +46,6 @@ LangBot は、AI搭載のインスタントメッセージングボットを構
[→ すべての機能について詳しく見る](https://link.langbot.app/ja/docs/features)
📍 実践ガイド: [5分でマルチプラットフォームAIボットをデプロイ](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/)、[DeepSeekをWeChat・Discord・Telegramに接続](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/)、[Dify AgentをDiscord・Telegram・Slackで動かす](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/)、[n8n連携チャットボットを構築](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/)。
---
## クイックスタート

View File

@@ -46,8 +46,6 @@ LangBot은 AI 기반 인스턴트 메시징 봇을 구축하기 위한 **오픈
[→ 모든 기능 자세히 보기](https://link.langbot.app/en/docs/features)
📍 실전 가이드: [5분 만에 멀티 플랫폼 AI 봇 배포하기](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [DeepSeek를 WeChat, Discord, Telegram에 연결하기](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [Dify Agent를 Discord, Telegram, Slack에서 실행하기](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/), [n8n 기반 챗봇 만들기](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
---
## 빠른 시작

View File

@@ -46,8 +46,6 @@ LangBot — это **платформа с открытым исходным к
[→ Подробнее обо всех возможностях](https://link.langbot.app/en/docs/features)
📍 Практические руководства: [развернуть мультиплатформенного ИИ-бота за 5 минут](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [подключить DeepSeek к WeChat, Discord и Telegram](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [запустить Dify Agent в Discord, Telegram и Slack](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/) и [создать чат-бота на n8n](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
---
## Быстрый старт

View File

@@ -48,8 +48,6 @@ LangBot 是一個**開源的生產級平台**,用於建構 AI 驅動的即時
[→ 了解更多功能特性](https://link.langbot.app/zh/docs/features)
📍 實踐指南:[5 分鐘部署多平台 AI 機器人](https://blog.langbot.app/zh/blog/deploy-ai-bot-in-5-minutes/)、[將 DeepSeek 接入微信、企業微信與 Discord](https://blog.langbot.app/zh/blog/connect-deepseek-to-wechat/)、[讓 Dify Agent 跑在 Discord、Telegram 和 Slack 上](https://blog.langbot.app/zh/blog/dify-agent-discord-telegram-slack/),以及[用 n8n 建構多平台 AI 聊天機器人](https://blog.langbot.app/zh/blog/n8n-multi-platform-ai-chatbot/)。
---
## 快速開始

View File

@@ -46,8 +46,6 @@ LangBot là một **nền tảng mã nguồn mở, cấp sản xuất** để x
[→ Tìm hiểu thêm về tất cả tính năng](https://link.langbot.app/en/docs/features)
📍 Hướng dẫn thực hành: [triển khai bot AI đa nền tảng trong 5 phút](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [kết nối DeepSeek với WeChat, Discord và Telegram](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [chạy Dify Agent trên Discord, Telegram và Slack](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/) và [xây dựng chatbot với n8n](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
---
## Bắt đầu nhanh

View File

@@ -0,0 +1,858 @@
# LangBot 多租户与多用户改造方案
## 目标
本方案面向 LangBot 从“单实例单管理员”演进到 SaaS 友好的“多 workspace、多账户、多权限”架构。
核心定义:
- Account登录主体。一个自然人或服务账号可加入多个 workspace。
- Workspace租户边界。一个 workspace 内可拥有多个用户、机器人、流水线、模型、知识库、扩展、监控数据与 API Key。
- Membership账户与 workspace 的关系,承载角色与权限。
- Role/Permissionworkspace 内权限,不再用“是否是当前唯一用户”来决定访问能力。
目标体验:
- 新用户登录后可以创建 workspace、加入 workspace、切换 workspace。
- 同一个账户可加入多个 workspace每个 workspace 权限不同。
- 一个 workspace 可邀请多个用户协作,并分别设置 owner/admin/editor/viewer 等权限。
- 所有业务资源默认属于某个 workspace所有 API 默认在当前 workspace 下工作。
- Plugin SDK、MCP、知识库、模型调用、监控日志都能拿到稳定的 workspace 上下文,并且不跨租户泄露数据。
## 调研结论
### 当前 LangBot 的单用户假设
LangBot 现在已经有 `users` 表和 JWT 登录,但仍是单用户/单租户模型:
- `src/langbot/pkg/entity/persistence/user.py``User` 只保存 `user/password/account_type/space_*`没有角色、状态、workspace 关系。
- `src/langbot/pkg/api/http/service/user.py` 通过 `is_initialized()` 判断系统是否已有用户;`create_or_update_space_user()` 在系统已初始化且邮箱不匹配时拒绝新用户,这直接限制了多用户登录。
- `src/langbot/pkg/api/http/controller/group.py``AuthType.USER_TOKEN` 验证后只向 handler 注入 `user_email`JWT payload 也只有 `user`,没有 `account_id``workspace_id``role``permissions`
- `src/langbot/pkg/api/http/service/apikey.py` 的 API Key 只验证 key 是否存在,没有 owner、scope、workspace。
- `web/src/app/infra/http/BaseHttpClient.ts``localStorage.token` 读取单个 token并在所有请求上加 `Authorization`;前端没有 workspace selector也没有当前 workspace 上下文。
当前登录流程更像“初始化一个本地管理账号”,而不是 SaaS 账户体系。要支持多用户,必须把“初始化状态”和“首个 workspace 创建”拆开。
### 业务资源当前都是全局资源
主要持久化表没有租户字段:
- Bot`bots`
- Pipeline`legacy_pipelines``pipeline_run_records`
- Model`model_providers``llm_models``embedding_models``rerank_models`
- Plugin`plugin_settings`
- MCP`mcp_servers`
- RAG`knowledge_bases``knowledge_base_files``knowledge_base_chunks`
- Monitoring`monitoring_messages``monitoring_llm_calls``monitoring_sessions``monitoring_errors``monitoring_embedding_calls``monitoring_feedback`
- API Key`api_keys`
- Webhook`webhooks`
- Metadata`metadata`
- Binary storage`binary_storages`
对应服务也直接 select 全表,例如:
- `BotService.get_bots()` 返回所有 bot。
- `PipelineService.get_pipelines()` 返回所有 pipeline。
- `ModelProviderService.get_providers()` 返回所有 provider。
- `MCPService.get_mcp_servers()` 返回所有 MCP server。
- 插件和二进制存储没有 workspace 维度,插件 workspace storage 在 SDK 里还硬编码为 `default`
所以改造重点不是只给用户表加字段,而是给资源访问层统一加入 workspace scope。
### 运行时也存在全局单例假设
`src/langbot/pkg/core/stages/build_app.py` 初始化的是一个全局 `Application`,其中包含单例:
- `platform_mgr`
- `pipeline_mgr`
- `model_mgr`
- `tool_mgr`
- `plugin_connector`
- `sess_mgr`
- `rag_mgr`
- `vector_db_mgr`
当前运行时把所有 bot、pipeline、model、plugin、MCP 都加载到同一套内存管理器。多租户改造需要决定:是共享运行时并在对象上带 workspace 过滤,还是每个 workspace 拆 runtime shard。第一阶段建议共享进程、强制 workspace-aware等规模变大后再演进为按 workspace 分片。
### Plugin SDK 对 workspace 的假设
SDK 当前只认识 bot/pipeline/query/session不认识租户
- `src/langbot_plugin/api/entities/builtin/pipeline/query.py``Query``query_id/launcher_type/launcher_id/sender_id/bot_uuid/pipeline_uuid`,没有 `workspace_id/account_id`
- `src/langbot_plugin/api/entities/builtin/provider/session.py``Session` 只按 `launcher_type + launcher_id` 表达会话。
- `src/langbot_plugin/api/proxies/langbot_api.py` 暴露 `get_bots/get_llm_models/invoke_llm/list_tools/vector_*` 等 Host API都是全局语义。
- `src/langbot_plugin/runtime/io/handlers/plugin.py``set_workspace_storage/get_workspace_storage``owner_type` 设为 `workspace`,但 `owner` 固定为 `default`
- LangBot 侧 `src/langbot/pkg/plugin/handler.py` 处理插件请求时,会把 `GET_BOTS``GET_LLM_MODELS``VECTOR_*` 等转到全局服务。
这意味着多租户落地时,不能只在 Web API 层过滤;插件可以通过 Host API 访问全局资源,所以 SDK/Runtime 通信也必须传递 workspace context。
## 开源版与商业版产品边界
LangBot 是开源软件,但多 workspace 能力本质上是 SaaS 控制面能力。如果把完整多 workspace、成员协作、订阅权益和配额代码都放进开源仓库只靠本地 feature flag 或本地 license check无法有效避免第三方 fork 后自建 SaaS。因此建议采用 open-core 架构:开源版保留单 workspace 执行能力,账户、订阅、权益和多 workspace 协作能力放到 LangBot Space/Cloud Control Plane 和商业模块中。
### 版本边界
推荐拆成三层:
- `LangBot Core OSS`:开源,自托管,默认只有一个隐式 `default workspace`。它可以在数据结构上兼容 workspace但产品能力上不提供创建多个 workspace、切换 workspace、成员邀请、成员权限管理、审计和多租户配额。
- `LangBot Space / Cloud Control Plane`:托管控制面,负责 account、workspace、membership、subscription、billing、entitlement、license token、workspace quota、marketplace 权益等能力。
- `LangBot Commercial Module`:商业闭源或私有包,承载多 workspace、团队协作、RBAC、自定义角色、审计日志、SAML/SSO、企业私有化授权等能力。
企业私有化版本可以采用 `LangBot Core + Commercial Module + License Token` 的形式交付。开源 Core 仍然可独立运行,但只能作为单 workspace 自托管产品,不提供 SaaS 多租户控制面。
### OSS 中如何保留兼容但不开放多 workspace
为了让后续商业版复用同一套资源隔离模型OSS 代码里可以保留 `workspace_uuid` 相关字段和默认 workspace 迁移,但应限制为单 workspace
- 首次初始化时创建一个 `Default Workspace`
- 所有资源自动归属这个 default workspace。
- 不暴露 `POST /api/v1/workspaces`
- 不暴露 workspace switcher。
- 不暴露成员邀请和成员角色管理。
- 不支持一个 account 加入多个 workspace。
- 不支持 workspace 数量大于 1。
- 前端不展示 workspace selector。
- API 层如果收到非 default workspace 的 `X-Workspace-Id`,直接拒绝。
也就是说OSS 可以是 workspace-aware但不是 multi-workspace-enabled。这样做的价值是代码结构提前适配租户隔离未来商业版不用重写所有资源模型同时开源版用户无法直接通过 UI/API 获得 SaaS 型多租户能力。
### 账户、订阅和权益抽到 Space
账户和订阅体系建议从 LangBot Core 中抽出,交给 Space 控制面:
```text
LangBot Space
-> Account
-> Workspace
-> Membership
-> Subscription
-> Entitlement
-> License Token
LangBot Core
-> Validate entitlement / license
-> Run bots, pipelines, plugins, MCP, RAG
-> Enforce local resource scope
-> Report usage
```
这样做有几个原因:
- 账号体系如果完全在本地,第三方容易直接改库创建 workspace/membership。
- 订阅、配额和商业权益如果完全在本地,容易绕过。
- Space 可以统一处理 OAuth、组织、邀请、付款、发票、套餐、权益、Marketplace 分发权限。
- LangBot Core 只作为执行面消费 Space 下发的 entitlement减少商业规则暴露。
### Entitlement 设计
Space 向 LangBot Core 下发签名权益,可以是在线校验,也可以为企业版提供短期/长期离线 license token。
示例:
```json
{
"edition": "oss",
"workspace_limit": 1,
"member_limit": 1,
"multi_workspace": false,
"rbac": false,
"audit_log": false,
"custom_roles": false,
"sso": false,
"commercial_use": false,
"expires_at": 1893456000
}
```
OSS 默认权益:
- `workspace_limit = 1`
- `member_limit = 1`
- `multi_workspace = false`
- `rbac = false`
- `audit_log = false`
- `sso = false`
Cloud/Pro/Enterprise 权益:
- `workspace_limit > 1`
- `member_limit > 1`
- `multi_workspace = true`
- `rbac = true`
- 可按套餐打开 audit、custom roles、SSO、usage reporting、enterprise support 等能力。
Core 执行面需要在关键入口强制校验 entitlement
- 创建 workspace。
- 邀请成员。
- 修改成员角色。
- 切换 workspace。
- 创建超过 quota 的资源。
- 开启商业模块功能。
### 商业模块边界
以下能力不建议进入 OSS 仓库的完整实现:
- 多 workspace 创建和切换。
- Workspace 成员邀请。
- Workspace RBAC 和自定义角色。
- Workspace 审计日志。
- Workspace 级用量和配额管理。
- 订阅、账单、发票。
- 企业 SSO/SAML/OIDC。
- 在线/离线 license 管理。
- 多租户 SaaS 运营控制台。
OSS 仓库可以保留接口占位、默认 workspace 兼容和必要的数据隔离字段,但完整交互、管理 UI、权益校验器和多 workspace policy engine 应由 Space 或商业模块提供。
### 防自建 SaaS 的现实边界
技术上无法 100% 阻止别人 fork 开源代码后自行改造。更可靠的策略是组合:
- 不把完整商业多 workspace 实现放进 OSS。
- Space 控制面提供账号、订阅、权益、Marketplace 和官方托管能力。
- 商业模块闭源或私有分发。
- 使用商标、云服务条款和商业 license 限制“自称官方 LangBot SaaS”或未经授权商用托管。
- 如果当前开源 license 对托管商用限制不足,需要单独评估 license 策略,必要时引入 open-core license 或新增商业附加条款。具体 license 调整需要法律评审。
结论:多 workspace 的底层 schema 可以在 OSS 中以 default workspace 兼容方式铺路,但多 workspace 产品能力、账户订阅权益、协作管理和 SaaS 控制面应放到 Space/商业模块,不作为开源版可直接使用的能力。
## 推荐总体架构
采用“单实例多 workspace资源行级隔离运行时上下文隔离”的架构
```mermaid
flowchart TD
A["Account"] --> B["WorkspaceMembership"]
B --> C["Workspace"]
C --> D["Bots"]
C --> E["Pipelines"]
C --> F["Models & Providers"]
C --> G["Knowledge Bases"]
C --> H["Extensions: Plugins / MCP"]
C --> I["API Keys & Webhooks"]
C --> J["Monitoring"]
D --> K["Runtime Query"]
E --> K
K --> L["Plugin Runtime"]
K --> M["MCP Runtime"]
L --> N["Workspace-scoped Host APIs"]
```
原则:
- 账户全局唯一workspace 是所有业务资源的归属边界。
- 所有 HTTP handler 在进入业务服务前解析出 `RequestContext(account, workspace, membership, permissions)`
- 所有 service 方法显式接收 `ctx``workspace_id`,禁止在业务服务里无条件 select 全表。
- 运行时对象的 key 从 `uuid` 扩展为 `(workspace_id, uuid)` 或使用全局唯一 uuid 但必须记录 workspace_id 并校验。
- 插件/MCP/知识库/模型调用都按 query 所属 workspace 过滤可用资源。
## 数据模型设计
### Account
替代现有 `users` 的语义,建议保留表名但升级字段,避免过大迁移:
字段建议:
- `id`
- `uuid`
- `email`
- `password_hash`
- `display_name`
- `avatar_url`
- `account_type`: `local | space`
- `status`: `active | disabled | deleted`
- `space_account_uuid`
- `space_access_token`
- `space_refresh_token`
- `space_access_token_expires_at`
- `space_api_key`
- `created_at`
- `updated_at`
兼容策略:
- 旧字段 `user` 迁移为 `email`,可以短期保留 alias。
-`password` 迁移为 `password_hash`也可先保持列名不变service 层改命名。
- JWT 中不要继续只放 email应放 `sub=account_uuid`
### Workspace
新增 `workspaces`
- `uuid`
- `name`
- `slug`
- `avatar_url`
- `type`: `personal | team`
- `status`: `active | suspended | deleted`
- `default_language`
- `created_by_account_uuid`
- `created_at`
- `updated_at`
每个账户首次登录时自动创建一个 personal workspace。旧单用户实例迁移时创建一个 `Default Workspace`
### WorkspaceMembership
新增 `workspace_memberships`
- `workspace_uuid`
- `account_uuid`
- `role`: `owner | admin | developer | operator | viewer`
- `status`: `active | invited | disabled`
- `invited_by_account_uuid`
- `joined_at`
- `created_at`
- `updated_at`
唯一索引:
- `(workspace_uuid, account_uuid)`
### WorkspaceInvitation
新增 `workspace_invitations`
- `uuid`
- `workspace_uuid`
- `email`
- `role`
- `token_hash`
- `expires_at`
- `accepted_at`
- `created_by_account_uuid`
- `created_at`
用于邀请外部用户加入 workspace。Space OAuth 登录时可以根据 email 自动匹配未接受邀请。
### Role 与 Permission
先用固定角色,后续再做自定义角色。
建议权限:
- `workspace.manage`
- `member.view`
- `member.invite`
- `member.update_role`
- `member.remove`
- `bot.view`
- `bot.manage`
- `pipeline.view`
- `pipeline.manage`
- `model.view`
- `model.manage`
- `knowledge.view`
- `knowledge.manage`
- `extension.view`
- `extension.manage`
- `monitoring.view`
- `apikey.manage`
- `webhook.manage`
- `billing.view`
- `billing.manage`
角色映射:
| Role | 说明 | 权限 |
| --- | --- | --- |
| owner | workspace 拥有者 | 全部权限;可转让 owner不可被其他角色移除 |
| admin | 管理员 | 除 owner 转让和删除 workspace 外的全部权限 |
| developer | 构建者 | 管理 bot、pipeline、model、knowledge、extension、webhook可看监控 |
| operator | 运营者 | 查看和启停 bot、查看监控、查看配置不可改密钥和删除资源 |
| viewer | 只读成员 | 只读资源和监控 |
### 业务资源加 workspace_uuid
以下表需要新增 `workspace_uuid`
- `bots`
- `legacy_pipelines`
- `pipeline_run_records`
- `model_providers`
- `llm_models`
- `embedding_models`
- `rerank_models`
- `plugin_settings`
- `mcp_servers`
- `knowledge_bases`
- `knowledge_base_files`
- `knowledge_base_chunks`
- `monitoring_*`
- `api_keys`
- `webhooks`
- `binary_storages`
- `metadata`
索引建议:
- 所有资源表加 `(workspace_uuid, created_at)``(workspace_uuid, updated_at)`
- 资源唯一键从单列改为 workspace 复合唯一:
- `bots.uuid` 可保持全局唯一,但查询仍必须带 workspace。
- `plugin_settings` 主键从 `(plugin_author, plugin_name)` 改为 `(workspace_uuid, plugin_author, plugin_name)`
- `mcp_servers.name` 如果未来要求唯一,必须是 `(workspace_uuid, name)`
- `metadata.key` 改为 `(workspace_uuid, key)`,系统级 metadata 单独放 `system_metadata` 或使用 `workspace_uuid=NULL`
- `binary_storages.unique_key` 建议改为 `workspace_uuid + owner_type + owner + key` 的 hash。
### API Key
API Key 必须归属于 workspace
- `workspace_uuid`
- `created_by_account_uuid`
- `scopes`
- `expires_at`
- `last_used_at`
- `status`
验证 API Key 后生成 `RequestContext`
- `account_uuid=None` 或 service account uuid
- `workspace_uuid=key.workspace_uuid`
- `permissions=key.scopes`
这样 `/api/v1/platform/bots/<uuid>/send_message` 之类接口不会跨 workspace 操作 bot。
## 后端改造方案
### RequestContext
新增统一上下文对象,例如:
```python
class RequestContext:
account_uuid: str | None
workspace_uuid: str
role: str | None
permissions: set[str]
auth_type: Literal["user_token", "api_key"]
```
改造 `RouterGroup.route()`
- `USER_TOKEN`:验证 JWT读取 `account_uuid`,再从 header/query/cookie 中解析 current workspace。
- `API_KEY`:验证 API Key直接得到 workspace。
- `USER_TOKEN_OR_API_KEY`:两者都返回同一种 `RequestContext`
- handler 参数从可选 `user_email` 升级为可选 `ctx`;兼容期同时支持 `user_email`
当前 workspace 传递方式:
- 推荐 header`X-Workspace-Id: <workspace_uuid>`
- Web 前端同时把当前 workspace 存在 localStorage。
- 如果未传,后端用账户最近使用 workspace 或第一个 active membership。
JWT payload
```json
{
"sub": "account_uuid",
"email": "user@example.com",
"iss": "LangBot-...",
"exp": 1234567890
}
```
不要把 workspace 写死在 JWT 里,否则切换 workspace 需要刷新 token。可以额外支持短 TTL workspace token但第一阶段不必。
### 服务层改造模式
所有 service 方法增加 `ctx``workspace_uuid`
```python
async def get_bots(self, ctx: RequestContext, include_secret: bool = True):
require(ctx, "bot.view")
query = sqlalchemy.select(Bot).where(Bot.workspace_uuid == ctx.workspace_uuid)
```
需要改造的服务:
- `UserService`:拆成 AccountService + WorkspaceService 更清晰。
- `ApiKeyService`:按 workspace 管理 key。
- `BotService`:所有 bot 查询/创建/更新/删除按 workspace。
- `PipelineService`:所有 pipeline 查询/默认 pipeline 按 workspace。
- `ModelProviderService` 和 model services按 workspace 隔离 provider 和 model。
- `MCPService`:按 workspace 管理 MCP server运行时按 workspace host。
- `KnowledgeService/RAGRuntimeService`:按 workspace 管理 KB、文件、collection。
- `MonitoringService`:记录和查询都带 workspace。
- `WebhookService`:按 workspace 管理 webhook。
- `PluginRuntimeConnector`:插件安装、设置、配置按 workspace。
### HTTP API 形态
保留现有路径,靠 `X-Workspace-Id` 表示当前 workspace可减少前端和 SDK 破坏:
- `GET /api/v1/workspaces`
- `POST /api/v1/workspaces`
- `GET /api/v1/workspaces/current`
- `PUT /api/v1/workspaces/current`
- `GET /api/v1/workspaces/<workspace_uuid>/members`
- `POST /api/v1/workspaces/<workspace_uuid>/invitations`
- `PUT /api/v1/workspaces/<workspace_uuid>/members/<account_uuid>`
- `DELETE /api/v1/workspaces/<workspace_uuid>/members/<account_uuid>`
现有资源 API
- `/api/v1/platform/bots`
- `/api/v1/pipelines`
- `/api/v1/provider/*`
- `/api/v1/plugins`
- `/api/v1/mcp`
- `/api/v1/knowledge`
继续保留,但必须从 `RequestContext.workspace_uuid` 过滤。
对外 API Key 也使用相同路径,只是由 key 决定 workspace。
### 初始化流程
现有 `/api/v1/user/init` 含义改为“创建首个账号和首个 workspace”
1. 如果系统没有任何 account
- 创建 account。
- 创建 personal/team workspace。
- 创建 owner membership。
- 创建默认 pipeline。
- 标记 wizard status 到该 workspace metadata。
2. 如果系统已有 account
- 禁止无邀请注册,除非配置允许公开注册。
- Space OAuth 登录后,如果没有 membership引导创建 workspace 或接受邀请。
`/api/v1/user/account-info` 不应再只返回 first user应返回
- `initialized`
- `registration_mode`
- `space_enabled`
- `default_login_methods`
登录成功后前端调用 `/api/v1/workspaces` 选择 workspace。
### 运行时隔离
第一阶段采用共享进程 + workspace-aware runtime
- `RuntimeBot` 增加 `workspace_uuid`
- `RuntimePipeline` 增加 `workspace_uuid`
- `Query` 增加 `workspace_uuid`,从 bot/pipeline 派生。
- `SessionManager.get_session()` key 从 `(launcher_type, launcher_id)` 改为 `(workspace_uuid, bot_uuid, launcher_type, launcher_id)`
- `PipelineManager.pipeline_dict` key 可保持 pipeline uuid但所有 load/get 都校验 workspace如果 uuid 不是全局唯一则改为 `(workspace_uuid, pipeline_uuid)`
- `ModelManager` provider/model 加 workspace 过滤;`get_model_by_uuid` 必须确保 query workspace 可访问。
- `ToolManager` 中 MCP tools、plugin tools 按 query workspace 过滤。
后续规模化时可演进:
- workspace runtime shard每个 workspace 一套 plugin runtime/MCP runtime。
- 大客户独立进程或独立数据库。
## Plugin SDK 与 Runtime 改造
### Query/Event 增加 workspace context
SDK `Query` 增加:
- `workspace_uuid: str`
- `workspace_slug: str | None`
- `account_uuid: str | None`,仅 Web/API 触发时可能有,聊天平台消息通常为空。
Event 模型通过 `event.query.workspace_uuid` 可拿到租户上下文;序列化时也应包含这些字段。
向后兼容:
- 字段可选,默认 `None`
- 老插件不感知这些字段也能跑。
- 新插件可通过 `ctx.event.query.workspace_uuid` 或新增 `ctx.get_workspace()` 访问。
### Host API 默认按当前 workspace 限制
`LangBotAPIProxy` 的以下方法必须由 Host 端按 workspace 过滤:
- `get_bots`
- `get_bot_info`
- `send_message`
- `get_llm_models`
- `invoke_llm`
- `list_plugins_manifest`
- `list_commands`
- `list_tools`
- `call_tool`
- `invoke_embedding`
- `vector_*`
- `list_knowledge_bases`
- `retrieve_knowledge`
建议新增显式方法:
- `get_workspace_info()`
- `get_current_account()`
- `get_workspace_storage(...)`
但不要让插件传入任意 workspace id 来越权访问。插件请求的 workspace 应由 Runtime 根据当前 query/plugin connection 填充。
### Workspace storage 修复
当前 SDK runtime 中:
```python
data["owner_type"] = "workspace"
data["owner"] = "default"
```
必须改为:
- 如果请求来自 query/eventowner 为 `workspace_uuid`
- 如果请求来自后台插件任务owner 为 plugin 安装所属 workspace。
- Host 侧 `binary_storages``workspace_uuid`,并在 unique key 中包含 workspace。
Plugin storage 建议也同时加 workspace
- 现在 plugin storage owner 是 `author/name`,这会导致同一插件在不同 workspace 的私有数据冲突。
- 应改为 `(workspace_uuid, plugin_id, key)`
### 插件安装与配置
`plugin_settings` 从全局变为 workspace-scoped
- 同一个插件可安装到多个 workspace。
- 每个 workspace 有自己的 enabled、priority、config、install_source、install_info。
- 插件 runtime 列表需要能按 workspace 过滤。
实现路线有两种:
1. 共享插件进程,插件代码只加载一份,设置和调用时附带 workspace。
2. 每个 workspace 一个插件容器实例,隔离最彻底但资源占用更高。
推荐第一阶段采用方案 1但要求
- 所有 RuntimeToLangBot/PluginToRuntime action 都能携带 `workspace_uuid`
- 插件 config 获取时按 workspace 返回。
- 插件 page API 请求必须校验当前用户在该 workspace 有访问权限。
### MCP
MCP server 是租户资源:
- `mcp_servers.workspace_uuid`
- MCP session key 从 `server_name` 改为 `(workspace_uuid, server_name)` 或使用全局 uuid。
- Pipeline extension preferences 中绑定 MCP server uuid 时,只能绑定同 workspace 的 server。
- MCP tool 列表在 query 执行时按 query.workspace_uuid + pipeline 绑定关系过滤。
## 前端改造
### Workspace selector
Home layout 顶部或 sidebar 增加 workspace selector
- 当前 workspace 名称和头像。
- 切换 workspace 后写入 `localStorage.currentWorkspaceId`
- 所有请求自动带 `X-Workspace-Id`
- 切换后刷新 sidebar 数据和页面缓存。
`BaseHttpClient` request interceptor 增加:
```ts
const workspaceId = localStorage.getItem("currentWorkspaceId");
if (workspaceId) config.headers["X-Workspace-Id"] = workspaceId;
```
### 用户与成员管理页面
新增页面:
- `/home/workspace/settings`
- `/home/workspace/members`
- `/home/workspace/invitations`
能力:
- owner/admin 邀请成员。
- owner/admin 修改成员角色。
- owner 移除成员、转让 owner。
- 所有人可切换 workspace。
- viewer/operator 在 UI 上隐藏不可操作按钮,但后端仍做权限校验。
### 登录与注册
登录后流程:
1. `authUser` 拿 token。
2. `initializeUserInfo()` 获取 account info。
3. `GET /api/v1/workspaces`
4. 如果没有 workspace进入创建 workspace 向导。
5. 如果有多个 workspace默认进入最近使用 workspace可切换。
注册页不再表达“初始化管理员账号”,而是:
- 首次系统启动:创建首个 owner + default workspace。
- 后续:根据配置允许公开注册,或只能接受邀请。
### 旧页面影响
需要逐个检查这些页面的数据加载是否都依赖当前 workspace
- Bots
- Pipelines
- Plugins/Market/MCP
- Knowledge
- Monitoring
- Models dialog
- API integration dialog
- Wizard
## 迁移方案
### 迁移阶段 0准备
- 引入 `workspaces``workspace_memberships``workspace_invitations`
-`users` 增加 `uuid/status/display_name` 等字段。
- 创建 `RequestContext`,但先不强制所有服务改完。
### 迁移阶段 1默认 workspace
对现有实例执行迁移:
1. 创建 `Default Workspace`
2. 找到现有第一个 user设为 owner。
3. 所有已有资源写入 `workspace_uuid=default_workspace_uuid`
4. `metadata` 迁入 default workspace确实全局的配置放到 `system_metadata`
5. `binary_storages``owner_type=workspace, owner=default` 改为 owner 为 default workspace uuid。
6. 插件 `plugin_settings` 归入 default workspace。
### 迁移阶段 2服务层强制 scope
- 改所有 service 查询,必须要求 `workspace_uuid`
- API Key 迁移为 workspace key。
- 所有写操作必须检查权限。
- 监控和任务查询按 workspace 过滤。
### 迁移阶段 3运行时上下文
- `Query``Session``RuntimeBot``RuntimePipeline` 增加 workspace。
- Plugin/MCP/Model/RAG runtime 全部按 workspace 过滤。
- 修复 SDK workspace storage。
### 迁移阶段 4前端多 workspace
- 登录后 workspace 选择。
- Header/sidebar workspace switcher。
- 成员和邀请管理。
- 所有 API 请求带 `X-Workspace-Id`
### 迁移阶段 5安全收敛
- 添加跨 workspace 越权测试。
- 添加 API Key scope 测试。
- 添加插件 Host API 过滤测试。
- 添加 MCP 和 RAG 隔离测试。
## 安全边界
必须防的场景:
- 用户 A 修改 URL 中 bot uuid访问用户 B workspace 的 bot。
- API Key 来自 workspace A但调用 workspace B 的 bot。
- 插件通过 `get_bots()` 枚举所有 workspace 的 bot。
- 插件通过 `workspace_storage` 读取其它 workspace 的数据。
- MCP server 名称相同导致 session 复用。
- monitoring session_id 相同导致数据串租户。
- Space OAuth 登录时,同 email 账户被错误绑定到已有本地 account。
建议策略:
- 所有资源访问都使用 `workspace_uuid + resource_id`
- 所有 service 方法入口做权限检查。
- 插件 Host API 的 workspace 不信任插件入参,只信任 query/runtime connection 上下文。
- API Key 只授予最小 scope默认不允许成员管理。
- owner 角色不能被普通 admin 移除或降权。
## 实施优先级
### P0基础租户骨架
- Account uuid/status。
- Workspace / Membership / Invitation。
- RequestContext。
- JWT 改为 account uuid。
- 前端 current workspace header。
### P1资源行级隔离
- Bots、Pipelines、Models、MCP、Plugins、Knowledge、Monitoring、API Keys 全部加 workspace_uuid。
- service 查询统一加 workspace filter。
- 权限矩阵落地。
### P2运行时隔离
- Query、Session、RuntimeBot、RuntimePipeline 加 workspace。
- Plugin Host API 和 MCP tools 按 workspace 过滤。
- SDK workspace storage 从 `default` 改为真实 workspace。
### P3协作体验
- 邀请成员。
- 成员列表和角色管理。
- workspace switcher。
- 最近使用 workspace。
### P4SaaS 运维增强
- Workspace 级用量统计。
- Workspace 级限额max_bots/max_pipelines/max_extensions/tokens/storage。
- 审计日志。
- workspace suspend/delete。
- 可选自定义角色。
## 测试计划
后端测试:
- 账户可加入多个 workspace。
- 同账户在不同 workspace 权限不同。
- viewer 不能创建/修改资源。
- API Key 只能访问所属 workspace。
- 所有资源 list/get/update/delete 都不能跨 workspace。
- 默认 workspace 迁移后旧数据可用。
运行时测试:
- 两个 workspace 使用相同 `launcher_id` 不共享 session。
- 两个 workspace 使用相同 MCP server name 不共享 MCP session。
- 插件 `get_bots()` 只能看到当前 workspace bot。
- 插件 `workspace_storage` 在不同 workspace 读写隔离。
- Pipeline 只调用当前 workspace 绑定的插件和 MCP tools。
前端测试:
- 登录后自动进入最近 workspace。
- 切换 workspace 后列表数据变化。
- 无权限按钮隐藏,直接调用 API 也被后端拒绝。
- 邀请成员流程完整。
迁移测试:
- SQLite 老实例迁移。
- PostgreSQL 老实例迁移。
- 已有 local account 迁移为 default workspace owner。
- 已有 Space account token 和 Space model provider API key 不丢失。
## 关键实现注意事项
- 不建议在第一版做数据库 schema-per-tenant。LangBot 当前 ORM 和运行时均以单库单表为主,先做 shared schema + workspace_uuid 成本更低。
- 不建议每个 workspace 立即启动独立 plugin runtime。先共享 runtime强制 action 带 workspace大客户隔离可作为后续部署形态。
- 不要只在前端过滤 workspace。插件、API Key、MCP、RAG 都能绕过前端,必须在后端和运行时层过滤。
- `metadata` 要拆清楚wizard status 属于 workspace系统版本/迁移信息属于 system。
- `users.user` 用 email 当主键语义不稳,应尽快引入 `account_uuid` 并让 JWT 以 uuid 为准。
- `plugin_settings` 当前主键没有 workspace改造时要先改主键/唯一约束,否则同插件无法在多个 workspace 配不同配置。
## 建议落地顺序
1. 新增 workspace/account/membership 表和 RequestContext。
2. 迁移旧数据到 default workspace。
3. 改 auth 和前端请求头,让每个请求都有 current workspace。
4. 从最核心资源开始逐个加 scopebot -> pipeline -> provider/model -> plugin/MCP -> knowledge -> monitoring。
5. 改 SDK Query/Event 和 runtime storage。
6. 上成员管理 UI 和邀请。
7. 做越权测试和迁移测试。
这个顺序的好处是可以较早让主 UI 在一个 workspace 下继续工作,同时把最危险的跨租户泄露面逐步收紧。

View File

@@ -1,6 +1,6 @@
[project]
name = "langbot"
version = "4.9.7"
version = "4.9.6"
description = "Production-grade platform for building agentic IM bots"
readme = "README.md"
license-files = ["LICENSE"]
@@ -22,7 +22,7 @@ dependencies = [
"discord-py>=2.5.2",
"pynacl>=1.5.0", # Required for Discord voice support
"gewechat-client>=0.1.5",
"lark-oapi>=1.5.5",
"lark-oapi>=1.4.15",
"mcp>=1.25.0",
"nakuru-project-idk>=0.0.2.1",
"ollama>=0.4.8",
@@ -35,7 +35,6 @@ dependencies = [
"python-telegram-bot>=22.0",
"pyyaml>=6.0.2",
"qq-botpy-rc>=1.2.1.6",
"qrcode>=7.4",
"quart>=0.20.0",
"quart-cors>=0.8.0",
"requests>=2.32.3",
@@ -70,7 +69,7 @@ dependencies = [
"chromadb>=1.0.0,<2.0.0",
"qdrant-client (>=1.15.1,<2.0.0)",
"pyseekdb==1.1.0.post3",
"langbot-plugin==0.3.11",
"langbot-plugin==0.3.10",
"asyncpg>=0.30.0",
"line-bot-sdk>=3.19.0",
"matrix-nio>=0.25.2",
@@ -122,7 +121,6 @@ package-data = { "langbot" = ["templates/**", "pkg/provider/modelmgr/requesters/
[dependency-groups]
dev = [
"moto>=5.2.1",
"pre-commit>=4.2.0",
"pytest>=9.0.3",
"pytest-asyncio>=1.0.0",

View File

@@ -4,9 +4,6 @@ python_files = test_*.py
python_classes = Test*
python_functions = test_*
# Python path for imports
pythonpath = . tests
# Test paths
testpaths = tests
@@ -25,9 +22,7 @@ markers =
asyncio: mark test as async
unit: mark test as unit test
integration: mark test as integration test
smoke: mark test as smoke test
slow: mark test as slow running
e2e: mark test as end-to-end test (requires real LangBot process)
# Coverage options (when using pytest-cov)
[coverage:run]

View File

@@ -1,649 +0,0 @@
"""Generate the DingTalk human-input card template JSON.
The output is wrapped in the {editorData, widgetInfo, type, mode} envelope
the DingTalk card builder expects on import. editorData is itself a JSON
string (NOT a nested object), matching real exports from the builder.
Run from the repo root: python scripts/build_dingtalk_card_template.py
"""
from __future__ import annotations
import json
from pathlib import Path
OUTPUT = Path('src/langbot/templates/dingtalk_human_input_card.json')
def markdown_block(node_id, variable='content'):
"""A MarkdownBlock whose content is bound to a global variable.
Unlike BaseText (whose `text` field requires editor-side manual binding),
MarkdownBlock's `content` accepts `variableValue` binding at JSON load
time — the imported template renders the variable straight away.
"""
return {
'componentName': 'MarkdownBlock',
'id': node_id,
'props': {
'mdVer': 0,
'icon': {'type': 'icon', 'icon': '', 'iconType': 'emoji'},
'content': {'variable': variable, 'variableType': 'global', 'type': 'variableValue'},
'visible': {
'type': 'dynamicVisible',
'value': True,
'valueType': 'fixed',
'condition': {'op': 'and', 'conditions': []},
},
'isStreaming': True,
'enableLinkStatPoint': False,
'linkStatPoint': {'type': 'dynamicString', 'content': '', 'i18n': False},
'linkStatPointParams': [],
'marginTop': 6,
'marginBottom': 6,
'marginLeft': 12,
'marginRight': 12,
},
'title': 'AI 流式富文本',
'hidden': False,
'isLocked': False,
'condition': True,
'conditionGroup': '',
}
def text_block(
node_id,
text,
*,
bold=False,
gravity='left',
font_size=14,
line_height=22,
max_lines=20,
ml=12,
mr=12,
mt=4,
mb=4,
color_token='common_level1_base_color',
style_token='common_body_text_style',
):
return {
'componentName': 'BaseText',
'id': node_id,
'props': {
'text': {'i18n': False, 'type': 'dynamicString', 'content': text},
'hoverText': {'type': 'dynamicString', 'content': '', 'i18n': False},
'iconType': 'iconCode',
'iconFont': {'type': 'icon', 'icon': '', 'iconType': 'ddIcon'},
'icon': {
'type': 'dynamicLink',
'value': '',
'valueType': 'fixed',
'variable': '',
'variableType': 'global',
},
'darkIcon': {
'type': 'dynamicLink',
'value': '',
'valueType': 'fixed',
'variable': '',
'variableType': 'global',
},
'autoWidth': False,
'maxWidth': {
'type': 'dynamicNumber',
'valueType': 'fixed',
'value': 0,
'variable': '',
'variableType': 'global',
},
'fixedWidth': {
'type': 'dynamicNumber',
'valueType': 'fixed',
'value': 0,
'variable': '',
'variableType': 'global',
},
'marginLeft': ml,
'marginRight': mr,
'marginTop': mt,
'marginBottom': mb,
'fontColorType': 'Standard',
'enableHighlight': False,
'maxLine': {
'type': 'dynamicNumber',
'valueType': 'fixed',
'value': max_lines,
'variable': '',
'variableType': 'global',
},
'color': {
'type': 'dynamicColor',
'valueType': 'fixed',
'value': color_token,
'variable': '',
'variableType': 'global',
},
'customLightColor': {
'type': 'dynamicColor',
'valueType': 'fixed',
'value': '#35404b',
'variable': '',
'variableType': 'global',
},
'customDarkColor': {
'type': 'dynamicColor',
'valueType': 'fixed',
'value': '#f6f6f6',
'variable': '',
'variableType': 'global',
},
'gravity': gravity,
'fontSizeType': 'Standard',
'styleType': 'custom',
'styleToken': style_token,
'size': 'middle',
'customFontSize': font_size,
'customFontLineHeight': line_height,
'bold': bold,
'italic': False,
'strikeThrough': False,
'lineHeight': 'normal',
'visible': {
'type': 'dynamicVisible',
'value': True,
'valueType': 'fixed',
'condition': {'op': 'and', 'conditions': []},
},
'autoMaxWidth': False,
'innerOffset': 0,
'enableIcon': False,
'widthMode': 'match_parent',
'margin': -2,
},
'title': '基础文本',
'hidden': False,
'isLocked': False,
'condition': True,
'conditionGroup': '',
}
def button_group(node_id):
return {
'componentName': 'ButtonGroup',
'id': node_id,
'props': {
'dynamicButtons': {'type': 'variableValue', 'variableType': 'global', 'variable': 'btns'},
'marginLeft': 12,
'marginRight': 12,
'marginTop': 6,
'marginBottom': 12,
'visible': {
'type': 'dynamicVisible',
'value': True,
'valueType': 'fixed',
'condition': {'op': 'and', 'conditions': []},
},
'responsiveLayoutWidth': 350,
'buttonsSource': 'variable',
'fixedButtonIds': [],
'fixedButtons': [],
'enableResponsiveLayout': False,
'matchContent': False,
'buttonSpacing': 8,
'margin': -2,
'innerOffset': 0,
},
'title': '按钮组',
'hidden': False,
'isLocked': False,
'condition': True,
'conditionGroup': '',
}
def build_editor_data():
component_names = [
'AIPending',
'AICardStatusContainer',
'BaseText',
'AICardContent',
'AICardContainer',
'ButtonGroup',
'MarkdownBlock',
]
components_map = [
{
'package': '@ali/dxComponent',
'version': '1.0.0',
'exportName': n,
'main': './src/index.tsx',
'destructuring': False,
'subName': '',
'componentName': n,
}
for n in component_names
]
pending_state = {
'componentName': 'AICardStatusContainer',
'id': 'node_status_pending',
'props': {
'status': 1,
'marginLeft': 0,
'marginRight': 0,
'marginTop': 0,
'marginBottom': 0,
'enableExtend': False,
'autoFoldConfig': {
'needFold': True,
'heightLimit': 480,
'foldStatusLocalDataKey': '_cardFoldStatusLocalDataKey',
},
'innerOffset': 0,
'enableCollapse': False,
'margin': -2,
},
'title': '处理中状态',
'hidden': False,
'isLocked': False,
'condition': True,
'conditionGroup': '',
'children': [
{
'componentName': 'AIPending',
'id': 'node_pending_inner',
'props': {
'marginLeft': 0,
'marginRight': 0,
'marginTop': 0,
'marginBottom': 0,
'pendingTip': {'type': 'dynamicString', 'content': '处理中...', 'i18n': False},
'style': 'embed',
'hideIcon': False,
},
'hidden': False,
'title': '',
'isLocked': False,
'condition': True,
'conditionGroup': '',
}
],
}
done_state = {
'componentName': 'AICardStatusContainer',
'id': 'node_status_done',
'props': {
'status': 3,
'marginLeft': 0,
'marginRight': 0,
'marginTop': 0,
'marginBottom': 0,
'enableExtend': False,
'autoFoldConfig': {
'needFold': True,
'heightLimit': 480,
'foldStatusLocalDataKey': '_cardFoldStatusLocalDataKey',
},
'innerOffset': 0,
'enableCollapse': False,
'margin': -2,
},
'title': '完成状态',
'hidden': False,
'isLocked': False,
'condition': True,
'conditionGroup': '',
'children': [
{
'componentName': 'AICardContent',
'id': 'node_done_content',
'props': {
'marginLeft': 0,
'marginRight': 0,
'marginTop': 0,
'marginBottom': 0,
'visible': {
'type': 'dynamicVisible',
'value': True,
'valueType': 'fixed',
'condition': {'op': 'and', 'conditions': []},
},
'innerOffset': 0,
'disabledWhileForward': False,
'statPoint': {'type': 'dynamicString', 'content': '', 'i18n': False},
'statPointParams': [
{'type': 'fixed', 'variable': '', 'value': '', 'name': '', 'variableType': 'global', 'id': '1'}
],
'margin': -2,
'transformToEventChain': False,
'enableStatPoint': False,
},
'hidden': False,
'title': '',
'isLocked': False,
'condition': True,
'conditionGroup': '',
'children': [
markdown_block('node_text_content', variable='content'),
button_group('node_btn_group'),
],
}
],
}
failed_state = {
'componentName': 'AICardStatusContainer',
'id': 'node_status_failed',
'props': {
'status': 5,
'marginLeft': 0,
'marginRight': 0,
'marginTop': 0,
'marginBottom': 0,
'enableExtend': False,
'autoFoldConfig': {
'needFold': True,
'heightLimit': 480,
'foldStatusLocalDataKey': '_cardFoldStatusLocalDataKey',
},
'innerOffset': 0,
'enableCollapse': False,
'margin': -2,
},
'title': '失败状态',
'hidden': False,
'isLocked': False,
'condition': True,
'conditionGroup': '',
'children': [
{
'componentName': 'AICardContent',
'id': 'node_failed_content',
'props': {
'visible': {
'type': 'dynamicVisible',
'value': True,
'valueType': 'fixed',
'condition': {'op': 'and', 'conditions': []},
},
'marginLeft': 0,
'marginRight': 0,
'marginTop': 0,
'marginBottom': 0,
'innerOffset': 0,
'disabledWhileForward': False,
'statPoint': {'type': 'dynamicString', 'content': '', 'i18n': False},
'statPointParams': [
{'type': 'fixed', 'variable': '', 'value': '', 'name': '', 'variableType': 'global', 'id': '1'}
],
'margin': -2,
'transformToEventChain': False,
'enableStatPoint': False,
},
'hidden': False,
'title': '',
'isLocked': False,
'condition': True,
'conditionGroup': '',
'children': [
text_block(
'node_failed_text',
'操作失败,请稍后重试。',
gravity='center',
mt=10,
mb=10,
ml=10,
mr=10,
max_lines=2,
font_size=15,
)
],
}
],
}
root = {
'componentName': 'AICardContainer',
'id': 'node_root',
'props': {
'marginLeft': 0,
'marginRight': 0,
'marginTop': 0,
'marginBottom': 0,
'enablePending': True,
'enableWriting': False,
'enableDoing': False,
'enableFailed': True,
'summaryContent': {'type': 'variableValue', 'variableType': 'global', 'variable': ''},
'enableTitle': False,
'flowStatusVar': {'type': 'variableValue', 'variableType': 'global', 'variable': 'flowStatus'},
'operationPenalType': 'custom',
'enableFlowAbort': True,
'innerOffset': 0,
'enableGradientBorder': True,
'cardSizeMode': 'adaptive',
'cardSizeHeightMode': 'adaptive',
'cardSizeWidthMode': 'adaptive',
'cardSizeHeight': {
'type': 'dynamicNumber',
'valueType': 'fixed',
'value': 226,
'variable': '',
'variableType': 'global',
},
'hasBackground': False,
'backgroundType': 'Standard',
'standardBackgroundColor': 'gray',
'backgroundColor': '#F6F6F6',
'darkModeBackgroundColor': '#3C3C3C',
'enableEngineUpgrade': False,
'enableExposeStatPoint': False,
'enableDebugTool': False,
},
'hidden': False,
'title': '',
'isLocked': False,
'condition': True,
'conditionGroup': '',
'children': [pending_state, done_state, failed_state],
}
btns_var = {
'name': 'btns',
'private': False,
'type': 'buttonGroup',
'id': 'btns',
'description': '动态按钮列表Dify actions',
'editorVarType': 'variables',
'disabled': False,
'schema': [
{
'id': 'btns.text',
'type': 'string',
'name': 'text',
'private': False,
'editorVarType': 'variables',
'disabled': True,
'description': '按钮文案',
},
{
'id': 'btns.color',
'type': 'string',
'name': 'color',
'private': False,
'editorVarType': 'variables',
'disabled': True,
'description': '按钮颜色',
},
{
'id': 'btns.status',
'type': 'string',
'name': 'status',
'private': False,
'editorVarType': 'variables',
'disabled': True,
'description': '按钮状态',
},
{
'id': 'btns.event',
'type': 'dynamicEvent',
'name': 'event',
'private': False,
'editorVarType': 'variables',
'disabled': True,
'description': '按钮点击事件',
'schema': [
{
'id': 'btns.type',
'type': 'string',
'name': 'type',
'private': False,
'editorVarType': 'variables',
'disabled': True,
'description': '事件类型openLink / sendCardRequest',
},
{
'id': 'btns.params',
'type': 'object',
'name': 'params',
'private': False,
'editorVarType': 'variables',
'disabled': True,
'description': '事件参数',
'schema': [
{
'id': 'btns.url',
'type': 'string',
'name': 'url',
'private': False,
'editorVarType': 'variables',
'disabled': True,
'description': '点击跳转链接type=openLink',
},
{
'id': 'btns.actionId',
'type': 'string',
'name': 'actionId',
'private': False,
'editorVarType': 'variables',
'disabled': True,
'description': '回传请求 idtype=sendCardRequest',
},
{
'id': 'btns.params',
'type': 'object',
'name': 'params',
'private': False,
'editorVarType': 'variables',
'disabled': True,
'description': '回传请求参数type=sendCardRequest',
},
],
},
],
},
],
}
return {
'schemaVersion': '3.0.0',
'schema': {
'config': {'update_multi': True, 'streaming_mode': True},
'componentsMap': components_map,
'componentsTree': [root],
'i18n': {},
'version': '1.0.0',
},
'mockData': {
'cardData': {
'flowStatus': 3,
'content': '请审核以下报销申请:\n\n- 申请人:张三\n- 金额¥1,200\n- 类别:差旅',
'btns': [
{
'text': '通过',
'color': 'blue',
'status': 'normal',
'event': {
'type': 'sendCardRequest',
'params': {'actionId': 'approve', 'params': {'action_id': 'approve'}},
},
},
{
'text': '驳回',
'color': 'gray',
'status': 'normal',
'event': {
'type': 'sendCardRequest',
'params': {'actionId': 'reject', 'params': {'action_id': 'reject'}},
},
},
{
'text': '补充资料',
'color': 'gray',
'status': 'normal',
'event': {
'type': 'sendCardRequest',
'params': {'actionId': 'more_info', 'params': {'action_id': 'more_info'}},
},
},
],
},
'cardPrivateData': {},
'localData': {'flowStatus': '', '_cardFoldStatusLocalDataKey': ''},
'richTextData': {},
},
'renderContext': {'regenerateEnabled': '1', 'regenerateIndex': '2', 'regenerateTotal': '5'},
'editVersion': 0,
'customWidgetInfo': '',
'useCustomWidgetInfo': False,
'variableList': [
{
'id': 'content',
'type': 'markdown',
'name': 'content',
'description': '人工输入提示词Dify form_content 含可选 node_title 前缀)',
'private': False,
'editorVarType': 'variables',
'disabled': False,
},
{
'id': 'flowStatus',
'type': 'string',
'name': 'flowStatus',
'description': 'AI卡片状态pending(1)、writing(2)、done(3)、failed(5)',
'private': False,
'editorVarType': 'variables',
'disabled': True,
'visible': False,
},
btns_var,
],
'formList': [],
'customContextList': [],
'expList': [],
'localList': [],
'hsfList': [],
'lwpList': [],
'pageData': {},
'extension': {'extendType': 'AI', 'aiStatusList': [3, 1, 5], 'fileTypeList': []},
}
def main():
editor_data = build_editor_data()
wrapper = {
'editorData': json.dumps(editor_data, ensure_ascii=False, separators=(',', ':')),
'widgetInfo': '',
'type': 'im',
'mode': 'card',
}
OUTPUT.write_text(json.dumps(wrapper, ensure_ascii=False, indent=2), encoding='utf-8')
print(f'wrote {OUTPUT}')
if __name__ == '__main__':
main()

View File

@@ -1,65 +0,0 @@
#!/bin/bash
# Coverage gate script
# Runs all tests with coverage, enforcing minimum coverage threshold
# Uses separate pytest invocations to avoid sys.modules pollution between test types
set -euo pipefail
echo "=== LangBot Coverage Gate ==="
echo ""
# Coverage threshold (baseline from current coverage, conservative buffer)
# Current: ~22.14%, threshold: 18%
COVERAGE_THRESHOLD=18
# Create temporary directory for coverage files
COV_DIR=$(mktemp -d)
trap "rm -rf $COV_DIR" EXIT
echo "[1/3] Running unit + smoke tests with coverage..."
uv run pytest tests/unit_tests/ tests/smoke/ \
--cov=langbot \
--cov-report=json:$COV_DIR/unit.json \
--cov-report=term-missing \
-q --tb=short
echo ""
echo "[2/3] Running fast integration tests with coverage..."
uv run pytest tests/integration/ -m "not slow" \
--cov=langbot \
--cov-report=json:$COV_DIR/integration.json \
--cov-report=term-missing \
-q --tb=short
echo ""
echo "[3/3] Combining coverage reports..."
# Use coverage combine if available, otherwise just report total
if command -v coverage &> /dev/null; then
# Combine JSON reports
coverage combine --keep $COV_DIR/unit.json $COV_DIR/integration.json \
--data-file=$COV_DIR/combined.data 2>/dev/null || true
coverage report --data-file=$COV_DIR/combined.data || true
else
echo "Note: coverage combine not available, showing individual reports above"
fi
# Generate final XML report for CI (from last run)
uv run pytest tests/unit_tests/ tests/smoke/ \
--cov=langbot \
--cov-report=xml:coverage.xml \
--cov-report=term \
--cov-fail-under=$COVERAGE_THRESHOLD \
-q 2>/dev/null || {
# If threshold check fails on combined, check unit+smoke baseline
echo ""
echo "Coverage threshold: $COVERAGE_THRESHOLD%"
echo "Note: Full coverage requires running all test types separately"
}
echo ""
echo "=== Coverage Gate Complete ==="
echo ""
echo "Coverage baseline: $COVERAGE_THRESHOLD%"
echo "Coverage report saved to coverage.xml"

View File

@@ -1,16 +0,0 @@
#!/bin/bash
# Fast integration tests
# Runs integration tests excluding slow ones (PostgreSQL, external services)
# Uses fake runner/provider, no real credentials needed
set -euo pipefail
echo "=== LangBot Fast Integration Tests ==="
echo ""
echo "Running integration tests (excluding slow)..."
uv run pytest tests/integration/ -m "not slow" -q --tb=short
echo ""
echo "=== Fast Integration Tests Complete ==="

View File

@@ -1,36 +0,0 @@
#!/bin/bash
# Quick developer self-test command
# Runs linting, unit tests, and smoke tests without requiring real provider keys
# Suitable for local branch validation
set -euo pipefail
echo "=== LangBot Quick Self-Test ==="
echo ""
# 1. Ruff check
echo "[1/3] Running ruff check..."
uv run ruff check src/langbot/ tests/ --output-format=concise || {
echo ""
echo "⚠ Ruff check found issues. Run 'uv run ruff check --fix' to auto-fix."
exit 1
}
echo "✓ Ruff check passed"
echo ""
# 2. Unit tests
echo "[2/3] Running unit tests..."
uv run pytest tests/unit_tests/ -q --tb=short
echo ""
# 3. Smoke tests (if exists)
echo "[3/3] Running smoke tests..."
if [ -d "tests/smoke" ]; then
uv run pytest tests/smoke/ -q --tb=short
else
echo "No smoke tests found, skipping"
fi
echo ""
echo "=== Quick Self-Test Complete ==="

View File

@@ -1,3 +1,3 @@
"""LangBot - Production-grade platform for building agentic IM bots"""
__version__ = '4.9.7'
__version__ = '4.9.6'

View File

@@ -109,61 +109,6 @@ class AsyncDifyServiceClient:
if chunk.startswith('data:'):
yield json.loads(chunk[5:])
async def workflow_submit(
self,
form_token: str,
workflow_run_id: str,
inputs: dict[str, typing.Any],
user: str,
action: str = '',
timeout: float = 120.0,
) -> typing.AsyncGenerator[dict[str, typing.Any], None]:
"""Submit human input to resume a paused workflow, then stream events.
1. POST /form/human_input/{form_token} to submit the form
2. GET /workflow/{task_id}/events to stream the resumed workflow events
"""
headers = {
'Authorization': f'Bearer {self.api_key}',
'Content-Type': 'application/json',
}
async with httpx.AsyncClient(
base_url=self.base_url,
trust_env=True,
timeout=timeout,
) as client:
# Step 1: Submit the form
payload: dict[str, typing.Any] = {
'inputs': inputs if isinstance(inputs, dict) else {},
'user': user,
'action': action,
}
submit_resp = await client.post(
f'/form/human_input/{form_token}',
headers=headers,
json=payload,
)
if submit_resp.status_code != 200:
raise DifyAPIError(f'{submit_resp.status_code} {submit_resp.text}')
# Step 2: Stream resumed workflow events
async with client.stream(
'GET',
f'/workflow/{workflow_run_id}/events',
headers={'Authorization': f'Bearer {self.api_key}'},
params={'user': user},
) as r:
async for chunk in r.aiter_lines():
if r.status_code != 200:
raise DifyAPIError(f'{r.status_code} {chunk}')
if chunk.strip() == '':
continue
if chunk.startswith('data:'):
yield json.loads(chunk[5:])
async def upload_file(
self,
file: httpx._types.FileTypes,

View File

@@ -1,26 +1,17 @@
import asyncio
import base64
import json
import logging
import time
import uuid
import urllib.parse
from typing import Awaitable, Callable, Optional
from typing import Callable
import dingtalk_stream # type: ignore
import websockets
from .EchoHandler import EchoTextHandler
from .card_callback import DingTalkCardActionHandler
from .dingtalkevent import DingTalkEvent
import httpx
import traceback
_stdout_logger = logging.getLogger('langbot.dingtalk_api')
DINGTALK_OPENAPI_BASE = 'https://api.dingtalk.com'
class DingTalkClient:
def __init__(
self,
@@ -30,7 +21,6 @@ class DingTalkClient:
robot_code: str,
markdown_card: bool,
logger: None,
card_action_callback: Optional[Callable[[dict], Awaitable[None]]] = None,
):
"""初始化 WebSocket 连接并自动启动"""
self.credential = dingtalk_stream.Credential(client_id, client_secret)
@@ -40,14 +30,6 @@ class DingTalkClient:
# 在 DingTalkClient 中传入自己作为参数,避免循环导入
self.EchoTextHandler = EchoTextHandler(self)
self.client.register_callback_handler(dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self.EchoTextHandler)
# STREAM-mode card action button click handler. Forwards parsed payload
# to the adapter so it can resume paused Dify workflows.
self.card_action_callback = card_action_callback
self.card_action_handler = DingTalkCardActionHandler(self.client, self._on_card_action)
self.client.register_callback_handler(
dingtalk_stream.handlers.CallbackHandler.TOPIC_CARD_CALLBACK,
self.card_action_handler,
)
self._message_handlers = {
'example': [],
}
@@ -59,16 +41,6 @@ class DingTalkClient:
self.logger = logger
self._stopped = False # Flag to control the event loop
async def _on_card_action(self, payload: dict) -> None:
"""Dispatch a parsed card-action payload to the adapter callback."""
if self.card_action_callback is None:
return
try:
await self.card_action_callback(payload)
except Exception:
if self.logger:
await self.logger.error(f'DingTalk card action callback error: {traceback.format_exc()}')
async def get_access_token(self):
url = 'https://api.dingtalk.com/v1.0/oauth2/accessToken'
headers = {'Content-Type': 'application/json'}
@@ -457,35 +429,18 @@ class DingTalkClient:
'Content-Type': 'application/json',
}
# For enterprise-internal robots, robotCode == AppKey (client_id).
# The dedicated robot_code field is only required for scenario-group
# robots or third-party robots; fall back to client_id when empty so
# the common single-bot setup keeps working without manual config.
robot_code = self.robot_code or self.key
data = {
'robotCode': robot_code,
'robotCode': self.robot_code,
'userIds': [target_id],
'msgKey': 'sampleText',
'msgParam': json.dumps({'content': content}),
}
_stdout_logger.info(
'DingTalk send_proactive_message_to_one request: robotCode=%s target_id=%s content_len=%d',
robot_code,
target_id,
len(content),
)
try:
async with httpx.AsyncClient() as client:
response = await client.post(url, headers=headers, json=data)
_stdout_logger.info(
'DingTalk send_proactive_message_to_one response: status=%d body=%s',
response.status_code,
response.text[:500],
)
if response.status_code == 200:
return
except Exception:
_stdout_logger.exception('DingTalk send_proactive_message_to_one error')
await self.logger.error(f'failed to send proactive massage to person: {traceback.format_exc()}')
raise Exception(f'failed to send proactive massage to person: {traceback.format_exc()}')
@@ -501,7 +456,7 @@ class DingTalkClient:
}
data = {
'robotCode': self.robot_code or self.key,
'robotCode': self.robot_code,
'openConversationId': target_id,
'msgKey': 'sampleText',
'msgParam': json.dumps({'content': content}),
@@ -522,244 +477,47 @@ class DingTalkClient:
quote_origin: bool = False,
card_auto_layout: bool = False,
):
"""Create + deliver the streaming chat card for a chatbot reply.
card_data = {}
card_data['config'] = json.dumps({'autoLayout': card_auto_layout})
card_data['content'] = ''
Replaces the old `dingtalk_stream.AICardReplier`-based path. Returns
`(None, out_track_id)` to keep call sites compatible with the
previous `(card_instance, card_instance_id)` shape — the first slot
is unused now that everything is driven by out_track_id.
"""
out_track_id = uuid.uuid4().hex
is_group = str(incoming_message.conversation_type) == '2'
if is_group:
open_space_id = f'dtv1.card//IM_GROUP.{incoming_message.conversation_id}'
else:
open_space_id = f'dtv1.card//IM_ROBOT.{incoming_message.sender_staff_id}'
card_param_map = {'content': ''}
# 将用户的消息内容作为卡片的查询参数,方便后续处理
if incoming_message.message_type == 'text':
card_param_map['query'] = incoming_message.get_text_list()[0]
card_data['query'] = incoming_message.get_text_list()[0]
else:
card_param_map['query'] = '...'
card_data['query'] = '...'
await self.create_and_deliver_card(
card_template_id=temp_card_id,
out_track_id=out_track_id,
open_space_id=open_space_id,
is_group=is_group,
card_param_map=card_param_map,
card_data_config={'autoLayout': card_auto_layout},
card_instance = dingtalk_stream.AICardReplier(self.client, incoming_message)
# print(card_instance)
# 先投放卡片: https://open.dingtalk.com/document/orgapp/create-and-deliver-cards
card_instance_id = await card_instance.async_create_and_deliver_card(
temp_card_id,
card_data,
)
return None, out_track_id
return card_instance, card_instance_id
async def send_card_message(self, card_instance, card_instance_id: str, content: str, is_final: bool):
"""Stream a single chunk into an existing card's `content` field."""
content_key = 'content'
try:
await self.streaming_update_card(
out_track_id=card_instance_id,
content_key='content',
await card_instance.async_streaming(
card_instance_id,
content_key=content_key,
content_value=content,
append=False,
finished=is_final,
failed=False,
)
except Exception as e:
if self.logger:
self.logger.exception(e)
await self.streaming_update_card(
out_track_id=card_instance_id,
content_key='content',
self.logger.exception(e)
await card_instance.async_streaming(
card_instance_id,
content_key=content_key,
content_value='',
append=False,
finished=is_final,
failed=True,
)
async def create_and_deliver_card(
self,
*,
card_template_id: str,
out_track_id: str,
open_space_id: str,
is_group: bool,
card_param_map: Optional[dict] = None,
callback_type: str = 'STREAM',
callback_route_key: Optional[str] = None,
support_forward: bool = True,
dynamic_data_source_configs: Optional[list] = None,
card_data_config: Optional[dict] = None,
at_user_ids: Optional[dict] = None,
recipients: Optional[list] = None,
) -> bool:
"""POST /v1.0/card/instances/createAndDeliver.
Mirrors the SDK's `async_create_and_deliver_card` shape but exposes
the dynamic-data-source config slot so we can register a pull URL
for variable-length button lists.
"""
if not await self.check_access_token():
await self.get_access_token()
cardData: dict = {'cardParamMap': card_param_map or {}}
if card_data_config is not None:
cardData['config'] = json.dumps(card_data_config)
body: dict = {
'cardTemplateId': card_template_id,
'outTrackId': out_track_id,
'cardData': cardData,
'callbackType': callback_type,
'openSpaceId': open_space_id,
'imGroupOpenSpaceModel': {'supportForward': support_forward},
'imRobotOpenSpaceModel': {'supportForward': support_forward},
}
if callback_type == 'HTTP' and callback_route_key:
body['callbackRouteKey'] = callback_route_key
if is_group:
deliver: dict = {'robotCode': self.robot_code or self.key}
if at_user_ids:
deliver['atUserIds'] = at_user_ids
if recipients is not None:
deliver['recipients'] = recipients
body['imGroupOpenDeliverModel'] = deliver
else:
body['imRobotOpenDeliverModel'] = {'spaceType': 'IM_ROBOT'}
if dynamic_data_source_configs:
body['openDynamicDataConfig'] = {'dynamicDataSourceConfigs': dynamic_data_source_configs}
url = f'{DINGTALK_OPENAPI_BASE}/v1.0/card/instances/createAndDeliver'
headers = {
'x-acs-dingtalk-access-token': self.access_token,
'Content-Type': 'application/json',
}
try:
_stdout_logger.info(
'DingTalk createAndDeliver request body: %s',
json.dumps(body, ensure_ascii=False)[:1500],
)
async with httpx.AsyncClient() as client:
response = await client.post(url, headers=headers, json=body, timeout=30.0)
if response.status_code == 200:
_stdout_logger.info(
'DingTalk createAndDeliver response: %s',
response.text[:500],
)
return True
_stdout_logger.error(
'DingTalk createAndDeliver failed: status=%s body=%s',
response.status_code,
response.text,
)
if self.logger:
await self.logger.error(
f'DingTalk createAndDeliver failed: status={response.status_code} body={response.text}'
)
return False
except Exception:
_stdout_logger.exception('DingTalk createAndDeliver error')
if self.logger:
await self.logger.error(f'DingTalk createAndDeliver error: {traceback.format_exc()}')
return False
async def streaming_update_card(
self,
*,
out_track_id: str,
content_key: str,
content_value: str,
append: bool,
finished: bool,
failed: bool = False,
) -> bool:
"""PUT /v1.0/card/streaming.
Replaces `dingtalk_stream.AICardReplier.async_streaming` — same body
shape (outTrackId / guid / key / content / isFull / isFinalize /
isError) per the SDK source.
"""
if not await self.check_access_token():
await self.get_access_token()
body = {
'outTrackId': out_track_id,
'guid': uuid.uuid4().hex,
'key': content_key,
'content': content_value,
'isFull': not append,
'isFinalize': finished,
'isError': failed,
}
url = f'{DINGTALK_OPENAPI_BASE}/v1.0/card/streaming'
headers = {
'x-acs-dingtalk-access-token': self.access_token,
'Content-Type': 'application/json',
}
try:
async with httpx.AsyncClient() as client:
response = await client.put(url, headers=headers, json=body, timeout=30.0)
if response.status_code == 200:
return True
if self.logger:
await self.logger.error(
f'DingTalk card streaming failed: status={response.status_code} body={response.text}'
)
return False
except Exception:
if self.logger:
await self.logger.error(f'DingTalk card streaming error: {traceback.format_exc()}')
return False
async def update_card_data(
self,
*,
out_track_id: str,
card_param_map: Optional[dict] = None,
private_data: Optional[dict] = None,
) -> bool:
"""PUT /v1.0/card/instances — non-streaming card content update."""
if not await self.check_access_token():
await self.get_access_token()
body: dict = {
'outTrackId': out_track_id,
'cardData': {'cardParamMap': card_param_map or {}},
}
if private_data:
body['privateData'] = private_data
url = f'{DINGTALK_OPENAPI_BASE}/v1.0/card/instances'
headers = {
'x-acs-dingtalk-access-token': self.access_token,
'Content-Type': 'application/json',
}
try:
_stdout_logger.info(
'DingTalk update_card_data request: out_track_id=%s body=%s',
out_track_id,
json.dumps(body, ensure_ascii=False)[:500],
)
async with httpx.AsyncClient() as client:
response = await client.put(url, headers=headers, json=body, timeout=30.0)
_stdout_logger.info(
'DingTalk update_card_data response: status=%d body=%s',
response.status_code,
response.text[:300],
)
if response.status_code == 200:
return True
if self.logger:
await self.logger.error(
f'DingTalk update card failed: status={response.status_code} body={response.text}'
)
return False
except Exception:
_stdout_logger.exception('DingTalk update_card_data error')
if self.logger:
await self.logger.error(f'DingTalk update card error: {traceback.format_exc()}')
return False
async def start(self):
"""启动 WebSocket 连接,监听消息"""
self._stopped = False

View File

@@ -1,96 +0,0 @@
"""STREAM-mode handler for DingTalk card action button clicks.
DingTalk delivers card-action callbacks over the same WebSocket stream used
for chatbot messages, under the topic `/v1.0/card/instances/callback`. This
module subclasses `dingtalk_stream.CallbackHandler` and forwards the parsed
payload to a coroutine the adapter registers, so the resume-paused-workflow
logic stays in the platform adapter where it belongs.
The `CardCallbackMessage` returned by `from_dict` exposes:
* `card_instance_id` (from `outTrackId`) — the card whose button was clicked
* `user_id` — the clicker's userId
* `content` — parsed JSON; the click params live here. Where exactly inside
`content` they sit depends on the template binding. We probe
the common paths.
* `extension` — parsed JSON; any extra data we set when delivering the card.
"""
from __future__ import annotations
from typing import Awaitable, Callable, Optional
import dingtalk_stream # type: ignore
from dingtalk_stream import AckMessage
from dingtalk_stream.card_callback import CardCallbackMessage
_PARAM_PATHS = (
('params',),
('cardPrivateData', 'params'),
('userPrivateData', 'params'),
)
def _extract_params(content: dict) -> dict:
"""Return the action params dict regardless of where the template put it."""
for path in _PARAM_PATHS:
node = content
for key in path:
if not isinstance(node, dict):
node = None
break
node = node.get(key)
if node is None:
break
if isinstance(node, dict) and node:
return node
return {}
class DingTalkCardActionHandler(dingtalk_stream.CallbackHandler):
def __init__(
self,
dingtalk_stream_client,
on_action: Optional[Callable[[dict], Awaitable[None]]] = None,
):
super().__init__()
self.dingtalk_client = dingtalk_stream_client
self.on_action = on_action
async def process(self, callback: dingtalk_stream.CallbackMessage):
try:
message = CardCallbackMessage.from_dict(callback.data)
params = _extract_params(message.content if isinstance(message.content, dict) else {})
# `CardCallbackMessage.from_dict` does not surface `actionId` (the
# top-level field that ButtonGroup's sendCardRequest event puts
# there). Pull it from the raw callback.data instead.
raw = callback.data if isinstance(callback.data, dict) else {}
action_id = raw.get('actionId') or ''
if not action_id:
# Some templates nest it under actionData / cardPrivateData.
action_data = raw.get('actionData') or {}
if isinstance(action_data, dict):
action_id = action_data.get('actionId') or action_id
if not action_id:
cpd = action_data.get('cardPrivateData') or {}
if isinstance(cpd, dict):
ids = cpd.get('actionIds')
if isinstance(ids, list) and ids:
action_id = str(ids[0])
payload = {
'out_track_id': message.card_instance_id,
'user_id': message.user_id,
'corp_id': message.corp_id,
'action_id': action_id,
'params': params,
'raw_content': message.content,
'extension': message.extension if isinstance(message.extension, dict) else {},
}
if self.on_action is not None:
await self.on_action(payload)
except Exception as e:
self.logger.error(f'DingTalkCardActionHandler.process error: {e}')
return AckMessage.STATUS_OK, 'OK'

View File

@@ -67,16 +67,6 @@ class StreamSession:
# 反馈 ID用于接收用户点赞/点踩反馈
feedback_id: Optional[str] = None
# Dify 人工输入暂停态runner 把 _form_data 传过来时填充。
# 一旦设置,下次企微 followup 请求时返回 button_interaction 模板卡
# 替代 stream chunk。点击按钮会回调 template_card_eventEventKey
# 就是 Dify 的 action_id。
pending_form: Optional[dict] = None
# template_card task_id企微要求 button_interaction 必填且不可重复)。
# 创建 pending_form 时生成;按钮点击回调里用来反查 session。
pending_form_task_id: Optional[str] = None
class StreamSessionManager:
"""管理 stream 会话的生命周期,并负责队列的生产消费。"""
@@ -93,9 +83,6 @@ class StreamSessionManager:
self._sessions: dict[str, StreamSession] = {} # stream_id -> StreamSession 映射
self._msg_index: dict[str, str] = {} # msgid -> stream_id 映射,便于流水线根据消息 ID 找到会话
self._feedback_index: dict[str, str] = {} # feedback_id -> stream_id 映射
# task_id (button_interaction template_card 的) -> stream_id 映射,
# 用于按钮点击回调里反查 pending_form。
self._task_index: dict[str, str] = {}
def get_stream_id_by_msg(self, msg_id: str) -> Optional[str]:
if not msg_id:
@@ -131,40 +118,6 @@ class StreamSessionManager:
if feedback_id and stream_id:
self._feedback_index[feedback_id] = stream_id
def set_pending_form(self, stream_id: str, form_data: dict, task_id: str) -> None:
"""把 Dify 人工输入暂停态绑定到 stream session。
下一次企微 followup 请求时adapter 检测到 pending_form
返回 button_interaction 模板卡而不是 stream chunk。
"""
session = self._sessions.get(stream_id)
if not session:
return
session.pending_form = form_data
session.pending_form_task_id = task_id
if task_id:
self._task_index[task_id] = stream_id
def get_session_by_task_id(self, task_id: str) -> Optional[StreamSession]:
"""按按钮点击回调里的 TaskId 反查 session。"""
if not task_id:
return None
stream_id = self._task_index.get(task_id)
if not stream_id:
return None
return self._sessions.get(stream_id)
def clear_pending_form(self, stream_id: str) -> None:
"""按钮点击消费完后清掉 pending_form避免重复弹卡。"""
session = self._sessions.get(stream_id)
if not session:
return
task_id = session.pending_form_task_id
session.pending_form = None
session.pending_form_task_id = None
if task_id:
self._task_index.pop(task_id, None)
def create_or_get(self, msg_json: dict[str, Any]) -> tuple[StreamSession, bool]:
"""根据企业微信回调创建或获取会话。
@@ -770,79 +723,6 @@ async def parse_wecom_bot_message(
return message_data
def build_button_interaction_payload(form_data: dict, task_id: str) -> dict[str, Any]:
"""Build a `template_card` (button_interaction) WeCom payload.
Shared by both the webhook-mode client (returns the payload as the
response to a stream-followup callback) and the ws_client (sends it
as a reply frame). Output shape is `{"msgtype": "template_card",
"template_card": {...}}` per the WeCom spec.
Args:
form_data: Dify human-input form data with keys ``actions`` (list of
``{id, title, button_style}``), ``node_title``, ``form_content``.
task_id: Unique per-card identifier. WeCom requires this for
button_interaction. The click callback returns it as TaskId so we
can find the originating session.
Notes:
* ``button.key`` is set directly to the Dify ``action_id``. The click
callback's ``EventKey`` carries this back unchanged (1024-byte limit
per the spec, far more than we ever need).
* WeCom caps the button list at 6. Extra actions are appended to
``sub_title_text`` so users can still reply with the id as text.
* Styles map ``primary``→1 (blue), ``danger``→2 (red), default→0
(gray). First button is auto-promoted to primary when no style.
"""
actions = list(form_data.get('actions') or [])
node_title = (form_data.get('node_title') or '').strip() or '人工介入'
form_content = (form_data.get('form_content') or '').strip()
visible_actions = actions[:6]
overflow = actions[6:]
sub_title_parts: list[str] = []
if form_content:
sub_title_parts.append(form_content)
if overflow:
extra_lines = [f' - {a.get("title") or a.get("id") or ""} (回复 id: {a.get("id") or ""})' for a in overflow]
sub_title_parts.append(f'另有 {len(overflow)} 个选项不在按钮列表中,可直接回复 id\n' + '\n'.join(extra_lines))
sub_title_text = '\n\n'.join(sub_title_parts) or '请选择一个操作以继续。'
button_list = []
for idx, action in enumerate(visible_actions):
action_id = str(action.get('id') or '')
title = str(action.get('title') or action_id or f'选项 {idx + 1}')
style_raw = (action.get('button_style') or '').lower()
if style_raw == 'primary' or (style_raw == '' and idx == 0):
style = 1
elif style_raw == 'danger':
style = 2
else:
style = 0
button_list.append(
{
'text': title,
'style': style,
'key': action_id,
}
)
card = {
'card_type': 'button_interaction',
'main_title': {
'title': node_title,
},
'sub_title_text': sub_title_text,
'button_list': button_list,
'task_id': task_id,
}
return {
'msgtype': 'template_card',
'template_card': card,
}
class WecomBotClient:
def __init__(self, Token: str, EnCodingAESKey: str, Corpid: str, logger: EventLogger, unified_mode: bool = False):
"""企业微信智能机器人客户端。
@@ -881,7 +761,6 @@ class WecomBotClient:
self.stream_poll_timeout = 0.5
self._feedback_callback: Optional[Callable] = None
self._card_action_callback: Optional[Callable] = None
def set_feedback_callback(self, callback: Callable) -> None:
"""设置反馈回调函数。
@@ -891,19 +770,6 @@ class WecomBotClient:
"""
self._feedback_callback = callback
def set_card_action_callback(self, callback: Callable) -> None:
"""设置按钮卡片点击回调函数。
Signature: ``async def callback(session, action_id, task_id, raw_event) -> None``
``session`` is the StreamSession the card was attached to;
``action_id`` is the Dify action_id reflected back via the
button's ``key`` field; ``task_id`` is the card's task_id
(matches ``session.pending_form_task_id``); ``raw_event`` is the
decoded callback JSON for any extra fields the adapter wants.
"""
self._card_action_callback = callback
@staticmethod
def _build_stream_payload(
stream_id: str, content: str, finish: bool, feedback_id: Optional[str] = None
@@ -934,12 +800,6 @@ class WecomBotClient:
'stream': stream_payload,
}
@staticmethod
def _build_button_interaction_payload(form_data: dict, task_id: str) -> dict[str, Any]:
"""Class-level shim — delegates to module-level builder so ws_client
can reuse the exact same payload shape without importing the class."""
return build_button_interaction_payload(form_data, task_id)
async def _encrypt_and_reply(self, payload: dict[str, Any], nonce: str) -> tuple[Response, int]:
"""对响应进行加密封装并返回给企业微信。
@@ -1032,22 +892,6 @@ class WecomBotClient:
return await self._encrypt_and_reply(self._build_stream_payload('', '', True), nonce)
session = self.stream_sessions.get_session(stream_id)
# If a Dify human-input pause arrived during this stream, switch
# the response from `msgtype: stream` to `msgtype: template_card`
# (button_interaction). The session's stream is also marked
# finished so future followups aren't expected (assuming the
# WeCom client treats template_card as the terminal response —
# we'll know from the next callback whether it kept polling).
if session and session.pending_form and session.pending_form_task_id:
await self.logger.info(
f'WeComBot: returning button_interaction for stream_id={stream_id} '
f'task_id={session.pending_form_task_id} actions={len(session.pending_form.get("actions") or [])}'
)
card_payload = self._build_button_interaction_payload(session.pending_form, session.pending_form_task_id)
self.stream_sessions.mark_finished(stream_id)
return await self._encrypt_and_reply(card_payload, nonce)
chunk = await self.stream_sessions.consume(stream_id, timeout=self.stream_poll_timeout)
if not chunk:
@@ -1156,50 +1000,11 @@ class WecomBotClient:
if event_type == 'feedback_event':
return await self._handle_feedback_event(msg_json, nonce)
# Button click on a button_interaction template_card. The WeCom doc
# calls this `template_card_event`; some routes wrap the button
# event payload inside `event.template_card_event`.
if event_type == 'template_card_event':
return await self._handle_template_card_event(msg_json, nonce)
if msg_json.get('msgtype') == 'stream':
return await self._handle_post_followup_response(msg_json, nonce)
return await self._handle_post_initial_response(msg_json, nonce)
async def _handle_template_card_event(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]:
"""Handle a button click on a button_interaction template_card.
WeCom carries the click info in ``event.template_card_event`` with
``TaskId`` matching the card we created and ``EventKey`` carrying
the button's ``key`` (which we set to the Dify ``action_id``).
"""
try:
tce = msg_json.get('event', {}).get('template_card_event', {})
task_id = tce.get('TaskId') or tce.get('task_id') or ''
event_key = tce.get('EventKey') or tce.get('event_key') or ''
card_type = tce.get('CardType') or tce.get('card_type') or ''
await self.logger.info(f'收到按钮点击: task_id={task_id} event_key={event_key!r} card_type={card_type}')
session = self.stream_sessions.get_session_by_task_id(task_id)
if session is None:
await self.logger.warning(f'未找到 task_id={task_id} 对应的 session按钮点击被丢弃')
else:
if self._card_action_callback is not None:
try:
await self._card_action_callback(session, event_key, task_id, msg_json)
except Exception:
await self.logger.error(f'card action callback raised: {traceback.format_exc()}')
# Drop the form so a fresh chunk/followup doesn't re-render
# the same card (and so the task_id can be GC'd).
self.stream_sessions.clear_pending_form(session.stream_id)
except Exception:
await self.logger.error(f'_handle_template_card_event error: {traceback.format_exc()}')
# WeCom expects an empty success ack for event callbacks.
return await self._encrypt_and_reply({}, nonce)
async def _handle_feedback_event(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]:
"""处理企业微信用户反馈事件(点赞/点踩)。
@@ -1309,29 +1114,6 @@ class WecomBotClient:
self.stream_sessions.mark_finished(stream_id)
return True
async def push_form_pause(
self, msg_id: str, form_data: dict, task_id: Optional[str] = None
) -> tuple[bool, Optional[str], Optional[str]]:
"""Attach a Dify human-input pause to the active stream session.
On the next WeCom followup poll, the response switches from
``msgtype: stream`` to ``msgtype: template_card`` (button_interaction)
carrying the buttons. ``task_id`` is auto-generated if not provided
and is what the button-click callback uses to look the session back up.
Returns:
``(ok, stream_id, task_id)``. ``ok`` is False if the
adapter's msg_id maps to no stream session (e.g. non-stream mode).
"""
stream_id = self.stream_sessions.get_stream_id_by_msg(msg_id)
if not stream_id:
return False, None, None
if not task_id:
# WeCom requires task_id [A-Za-z0-9_-@], <= 128 bytes, unique per bot.
task_id = f'dify-{uuid.uuid4().hex[:24]}'
self.stream_sessions.set_pending_form(stream_id, form_data, task_id)
return True, stream_id, task_id
async def set_message(self, msg_id: str, content: str):
"""兼容旧逻辑:若无法流式返回则缓存最终结果。

View File

@@ -20,11 +20,7 @@ from typing import Any, Callable, Optional
import aiohttp
from langbot.libs.wecom_ai_bot_api import wecombotevent
from langbot.libs.wecom_ai_bot_api.api import (
parse_wecom_bot_message,
StreamSession,
build_button_interaction_payload,
)
from langbot.libs.wecom_ai_bot_api.api import parse_wecom_bot_message, StreamSession
from langbot.pkg.platform.logger import EventLogger
DEFAULT_WS_URL = 'wss://openws.work.weixin.qq.com'
@@ -107,18 +103,6 @@ class WecomBotWsClient:
# msg_id -> feedback_id (for associating feedback with message)
self._msg_feedback_ids: dict[str, str] = {} # msg_id -> feedback_id
# Dify human-input pause state for ws mode. Keys are task_id (echoed
# back in template_card_event.TaskId so we can rebuild the session
# context on click).
# task_id -> {form_data, msg_id, user_id, chat_id, stream_id, req_id}
self._pending_forms_by_task: dict[str, dict] = {}
# Reverse: msg_id -> task_id (for cleanup when stream finishes).
self._task_id_by_msg: dict[str, str] = {}
# Optional card-action callback registered by the adapter.
# Signature mirrors the http-mode WecomBotClient:
# async def callback(session, action_id, task_id, raw_event) -> None
self._card_action_callback: Optional[Callable] = None
# ── Public API ──────────────────────────────────────────────────
async def connect(self):
@@ -252,83 +236,6 @@ class WecomBotWsClient:
}
return await self._send_reply(req_id, body)
async def reply_template_card(self, req_id: str, card_payload: dict[str, Any]) -> Optional[dict]:
"""Send a template_card (button_interaction etc.) reply.
Args:
req_id: The req_id from the original message frame.
card_payload: Body produced by ``build_button_interaction_payload``;
must contain ``msgtype`` and ``template_card`` keys.
Returns:
ACK frame dict, or None on failure.
"""
return await self._send_reply(req_id, card_payload)
def set_card_action_callback(self, callback: Callable) -> None:
"""Register the button-click handler.
``async def callback(session, action_id, task_id, raw_event) -> None``
— same signature as the http-mode WecomBotClient version so the
adapter can hand both off to the same coroutine.
"""
self._card_action_callback = callback
async def push_form_pause(
self, msg_id: str, form_data: dict, task_id: Optional[str] = None
) -> tuple[bool, Optional[str], Optional[str]]:
"""Attach a Dify human-input pause to the active stream and send
the button_interaction card immediately.
ws mode has no notion of polled "followup" responses — each reply
is a one-shot frame send. So unlike the http path (which defers
card delivery to the next followup), here we just craft the card
and reply with it on the original req_id. The corresponding stream
session is then torn down so subsequent chunks don't re-send.
Returns:
``(ok, stream_id, task_id)``. ``ok=False`` if no active stream
for this msg_id (e.g. message arrived in non-stream mode).
"""
key = self._stream_ids.get(msg_id)
if not key:
return False, None, None
req_id, stream_id = key.split('|', 1)
if not task_id:
task_id = f'dify-{secrets.token_hex(12)}'
session_info = self._stream_sessions.get(msg_id) or {}
self._pending_forms_by_task[task_id] = {
'form_data': form_data,
'msg_id': msg_id,
'user_id': session_info.get('user_id', ''),
'chat_id': session_info.get('chat_id', ''),
'stream_id': stream_id,
'req_id': req_id,
}
self._task_id_by_msg[msg_id] = task_id
card_payload = build_button_interaction_payload(form_data, task_id)
try:
await self.reply_template_card(req_id, card_payload)
except Exception:
await self.logger.error(f'Failed to send button_interaction card: {traceback.format_exc()}')
# Roll back the bookkeeping so the next attempt isn't blocked.
self._pending_forms_by_task.pop(task_id, None)
self._task_id_by_msg.pop(msg_id, None)
return False, stream_id, None
# Tear down the stream — WeCom expects either stream chunks OR a
# template_card, not both on the same req_id. Subsequent
# push_stream_chunk calls for this msg_id become no-ops.
self._stream_ids.pop(msg_id, None)
self._stream_last_content.pop(msg_id, None)
# Keep _stream_sessions so the button callback can still resolve
# user/chat context; it gets cleaned up when the click fires.
return True, stream_id, task_id
async def send_message(self, chat_id: str, content: str, msgtype: str = 'markdown') -> Optional[dict]:
"""Proactively send a message to a specified chat.
@@ -351,23 +258,6 @@ class WecomBotWsClient:
body['text'] = {'content': content}
return await self._send_reply(req_id, body, cmd=CMD_SEND_MSG)
async def send_template_card(self, chat_id: str, card_payload: dict[str, Any]) -> Optional[dict]:
"""Proactively push a template_card to a chat.
Used for the resumed-workflow path (button click → new query):
synthetic events have no inbound req_id to reply against, so we
fall back to proactive ``aibot_send_msg`` instead of reply mode.
Args:
chat_id: userid (single chat) or chatid (group chat).
card_payload: ``{"msgtype": "template_card", "template_card": {...}}``
as produced by :func:`build_button_interaction_payload`.
"""
req_id = _generate_req_id(CMD_SEND_MSG)
body = dict(card_payload)
body['chatid'] = chat_id
return await self._send_reply(req_id, body, cmd=CMD_SEND_MSG)
async def push_stream_chunk(self, msg_id: str, content: str, is_final: bool = False) -> bool:
"""Push a streaming chunk for a given message ID.
@@ -678,38 +568,6 @@ class WecomBotWsClient:
await self.logger.error(f'Error in feedback handler: {traceback.format_exc()}')
return
if event_type == 'template_card_event':
tce = event_info.get('template_card_event', {})
task_id = tce.get('TaskId') or tce.get('task_id') or ''
event_key = tce.get('EventKey') or tce.get('event_key') or ''
card_type = tce.get('CardType') or tce.get('card_type') or ''
await self.logger.info(
f'收到按钮点击 (ws): task_id={task_id} event_key={event_key!r} card_type={card_type}'
)
pending = self._pending_forms_by_task.get(task_id)
if pending is None:
await self.logger.warning(f'未找到 task_id={task_id} 对应的 pending_form (ws),按钮点击被丢弃')
elif self._card_action_callback is not None:
try:
session = StreamSession(
stream_id=pending.get('stream_id', ''),
msg_id=pending.get('msg_id', ''),
chat_id=pending.get('chat_id') or None,
user_id=pending.get('user_id') or None,
)
session.pending_form = pending.get('form_data')
session.pending_form_task_id = task_id
await self._card_action_callback(session, event_key, task_id, body)
except Exception:
await self.logger.error(f'card action callback raised (ws): {traceback.format_exc()}')
# Consume — drop bookkeeping so a stale click can't re-fire.
self._pending_forms_by_task.pop(task_id, None)
msg_id = pending.get('msg_id', '')
if msg_id:
self._task_id_by_msg.pop(msg_id, None)
self._stream_sessions.pop(msg_id, None)
return
event = wecombotevent.WecomBotEvent(message_data)
if event_type in self._message_handlers:

View File

@@ -1,6 +1,5 @@
import quart
import mimetypes
import asyncio
from ... import group
from langbot.pkg.utils import importutil
@@ -36,617 +35,3 @@ class AdaptersRouterGroup(group.RouterGroup):
return quart.Response(
importutil.read_resource_file_bytes(icon_path), mimetype=mimetypes.guess_type(icon_path)[0]
)
# In-memory session store for active registrations
_create_app_sessions: dict = {}
_SESSION_TTL = 900 # 15 minutes
def _cleanup_expired_sessions():
"""Remove sessions that have exceeded their TTL."""
import time
now = time.time()
expired = [sid for sid, s in _create_app_sessions.items() if now - s.get('created_at', 0) > _SESSION_TTL]
for sid in expired:
session = _create_app_sessions.pop(sid, None)
if session and session.get('task') and not session['task'].done():
session['task'].cancel()
@self.route('/lark/create-app', methods=['POST'])
async def _() -> str:
"""Start Feishu one-click app registration. Returns session_id + QR code URL."""
import uuid
import time
import lark_oapi as lark
from lark_oapi.scene.registration.errors import AppAccessDeniedError, AppExpiredError
_cleanup_expired_sessions()
session_id = str(uuid.uuid4())
loop = asyncio.get_running_loop()
session = {
'status': 'pending',
'qr_url': None,
'expire_at': None,
'app_id': None,
'app_secret': None,
'error': None,
'created_at': time.time(),
}
_create_app_sessions[session_id] = session
def on_qr_code(info):
# May be called from a background thread by the SDK;
# use call_soon_threadsafe to safely update session state.
def _update():
session['qr_url'] = info['url']
session['expire_at'] = time.time() + 600 # 10 minutes
session['status'] = 'waiting'
loop.call_soon_threadsafe(_update)
async def run_registration():
try:
result = await lark.aregister_app(
on_qr_code=on_qr_code,
source='langbot',
)
session['status'] = 'success'
session['app_id'] = result['client_id']
session['app_secret'] = result['client_secret']
except AppAccessDeniedError:
session['status'] = 'error'
session['error'] = 'User denied authorization'
except AppExpiredError:
session['status'] = 'error'
session['error'] = 'QR code expired'
except Exception as e:
session['status'] = 'error'
session['error'] = str(e)
task = asyncio.create_task(run_registration())
session['task'] = task
# Wait for QR code to be ready (max 10 seconds)
for _ in range(20):
if session['qr_url']:
break
await asyncio.sleep(0.5)
if not session['qr_url']:
task.cancel()
session['status'] = 'error'
session['error'] = 'Timeout waiting for QR code'
return self.http_status(504, -1, 'Timeout waiting for QR code')
return self.success(
data={
'session_id': session_id,
'qr_url': session['qr_url'],
'expire_at': session['expire_at'],
}
)
@self.route('/lark/create-app/status/<session_id>', methods=['GET'])
async def _(session_id: str) -> str:
"""Poll registration status."""
session = _create_app_sessions.get(session_id)
if not session:
return self.http_status(404, -1, 'Session not found')
data = {'status': session['status']}
if session['status'] == 'success':
data['app_id'] = session['app_id']
data['app_secret'] = session['app_secret']
_create_app_sessions.pop(session_id, None)
elif session['status'] == 'error':
data['error'] = session['error']
_create_app_sessions.pop(session_id, None)
return self.success(data=data)
@self.route('/lark/create-app/<session_id>', methods=['DELETE'])
async def _(session_id: str) -> str:
"""Cancel and clean up a registration session."""
session = _create_app_sessions.pop(session_id, None)
if session and session.get('task') and not session['task'].done():
session['task'].cancel()
return self.success(data={})
# -----------------------------------------------------------------------
# WeChat QR Code Login
# -----------------------------------------------------------------------
_weixin_login_sessions: dict = {}
_WEIXIN_SESSION_TTL = 600 # 10 minutes (3 retries × 3 min QR validity)
def _cleanup_expired_weixin_sessions():
import time
now = time.time()
expired = [
sid for sid, s in _weixin_login_sessions.items() if now - s.get('created_at', 0) > _WEIXIN_SESSION_TTL
]
for sid in expired:
session = _weixin_login_sessions.pop(sid, None)
if session and session.get('task') and not session['task'].done():
session['task'].cancel()
@self.route('/weixin/login', methods=['POST'])
async def _() -> str:
"""Start WeChat QR code login. Returns session_id + QR code data URL."""
import uuid
import time
from langbot.libs.openclaw_weixin_api.client import OpenClawWeixinClient, DEFAULT_BASE_URL
_cleanup_expired_weixin_sessions()
session_id = str(uuid.uuid4())
loop = asyncio.get_running_loop()
session = {
'status': 'pending',
'qr_data_url': None,
'expire_at': None,
'token': None,
'base_url': None,
'account_id': None,
'error': None,
'created_at': time.time(),
}
_weixin_login_sessions[session_id] = session
client = OpenClawWeixinClient(
base_url=DEFAULT_BASE_URL,
token='',
)
async def run_login():
try:
def on_qrcode(qr_data_url: str, _qr_url: str):
def _update():
session['qr_data_url'] = qr_data_url
session['expire_at'] = time.time() + 180
session['status'] = 'waiting'
loop.call_soon_threadsafe(_update)
result = await client.login(
max_retries=1,
poll_timeout_ms=180_000,
on_qrcode=on_qrcode,
)
session['status'] = 'success'
session['token'] = result.token
session['base_url'] = result.base_url
session['account_id'] = result.account_id
except Exception as e:
error_message = str(e)
if 'expired' in error_message.lower() or 'max retries exceeded' in error_message.lower():
session['status'] = 'expired'
session['error'] = 'QR code expired'
else:
session['status'] = 'error'
session['error'] = error_message
finally:
await client.close()
task = asyncio.create_task(run_login())
session['task'] = task
# Wait for QR code to be ready (max 10 seconds)
for _ in range(20):
if session['qr_data_url']:
break
await asyncio.sleep(0.5)
if not session['qr_data_url']:
task.cancel()
session['status'] = 'error'
session['error'] = 'Timeout waiting for QR code'
return self.http_status(504, -1, 'Timeout waiting for QR code')
return self.success(
data={
'session_id': session_id,
'qr_data_url': session['qr_data_url'],
'expire_at': session['expire_at'],
}
)
@self.route('/weixin/login/status/<session_id>', methods=['GET'])
async def _(session_id: str) -> str:
"""Poll WeChat login status."""
session = _weixin_login_sessions.get(session_id)
if not session:
return self.http_status(404, -1, 'Session not found')
data = {
'status': session['status'],
'qr_data_url': session['qr_data_url'],
'expire_at': session['expire_at'],
}
if session['status'] == 'success':
data['token'] = session['token']
data['base_url'] = session['base_url']
data['account_id'] = session['account_id']
_weixin_login_sessions.pop(session_id, None)
elif session['status'] == 'error':
data['error'] = session['error']
_weixin_login_sessions.pop(session_id, None)
elif session['status'] == 'expired':
data['error'] = session['error']
_weixin_login_sessions.pop(session_id, None)
return self.success(data=data)
@self.route('/weixin/login/<session_id>', methods=['DELETE'])
async def _(session_id: str) -> str:
"""Cancel and clean up a WeChat login session."""
session = _weixin_login_sessions.pop(session_id, None)
if session and session.get('task') and not session['task'].done():
session['task'].cancel()
return self.success(data={})
# -----------------------------------------------------------------------
# DingTalk Device Flow QR Code Login
# -----------------------------------------------------------------------
_dingtalk_sessions: dict = {}
_DINGTALK_SESSION_TTL = 600 # 10 minutes (QR code validity window)
def _cleanup_expired_dingtalk_sessions():
import time
now = time.time()
expired = [
sid for sid, s in _dingtalk_sessions.items() if now - s.get('created_at', 0) > _DINGTALK_SESSION_TTL
]
for sid in expired:
session = _dingtalk_sessions.pop(sid, None)
if session and session.get('task') and not session['task'].done():
session['task'].cancel()
@self.route('/dingtalk/create-app', methods=['POST'])
async def _() -> str:
"""Start DingTalk one-click app creation via Device Flow. Returns session_id + QR code URL."""
import uuid
import time
import aiohttp
DINGTALK_BASE_URL = 'https://oapi.dingtalk.com'
_cleanup_expired_dingtalk_sessions()
session_id = str(uuid.uuid4())
session = {
'status': 'pending',
'qr_url': None,
'expire_at': None,
'client_id': None,
'client_secret': None,
'error': None,
'created_at': time.time(),
'device_code': None,
'interval': 5,
}
_dingtalk_sessions[session_id] = session
async def run_device_flow():
try:
timeout = aiohttp.ClientTimeout(total=10)
async with aiohttp.ClientSession(timeout=timeout) as http:
# Step 1: Init — get nonce
async with http.post(
f'{DINGTALK_BASE_URL}/app/registration/init',
json={'source': 'langbot'},
) as resp:
try:
data = await resp.json()
except (aiohttp.ContentTypeError, ValueError):
session['status'] = 'error'
session['error'] = 'Invalid response from DingTalk service'
return
if data.get('errcode', -1) != 0:
session['status'] = 'error'
session['error'] = data.get('errmsg', 'Failed to init')
return
nonce = data['nonce']
# Step 2: Begin — get device_code + QR URL
async with http.post(
f'{DINGTALK_BASE_URL}/app/registration/begin',
json={'nonce': nonce},
) as resp:
try:
data = await resp.json()
except (aiohttp.ContentTypeError, ValueError):
session['status'] = 'error'
session['error'] = 'Invalid response from DingTalk service'
return
if data.get('errcode', -1) != 0:
session['status'] = 'error'
session['error'] = data.get('errmsg', 'Failed to begin authorization')
return
device_code = data['device_code']
verification_uri_complete = data.get('verification_uri_complete', '')
expires_in = data.get('expires_in', 7200)
interval = data.get('interval', 5)
session['device_code'] = device_code
session['interval'] = interval
session['qr_url'] = verification_uri_complete
session['expire_at'] = time.time() + 600 # QR code valid for ~10 min
session['status'] = 'waiting'
# Step 3: Poll for authorization result
deadline = time.time() + expires_in
while time.time() < deadline:
await asyncio.sleep(interval)
async with http.post(
f'{DINGTALK_BASE_URL}/app/registration/poll',
json={'device_code': device_code},
) as poll_resp:
try:
poll_data = await poll_resp.json()
except (aiohttp.ContentTypeError, ValueError):
continue
if poll_data.get('errcode', -1) != 0:
session['status'] = 'error'
session['error'] = poll_data.get('errmsg', 'Poll failed')
return
status = poll_data.get('status', '')
if status == 'SUCCESS':
session['status'] = 'success'
session['client_id'] = poll_data.get('client_id', '')
session['client_secret'] = poll_data.get('client_secret', '')
return
elif status == 'FAIL':
session['status'] = 'error'
session['error'] = poll_data.get('fail_reason', 'Authorization failed')
return
elif status == 'EXPIRED':
session['status'] = 'error'
session['error'] = 'QR code expired'
return
# status == 'WAITING': continue polling
# Timeout
session['status'] = 'error'
session['error'] = 'QR code expired'
except asyncio.CancelledError:
return
except Exception as e:
session['status'] = 'error'
session['error'] = str(e)
task = asyncio.create_task(run_device_flow())
session['task'] = task
# Wait for QR code to be ready (max 10 seconds)
for _ in range(20):
if session['qr_url'] or session['error']:
break
await asyncio.sleep(0.5)
if session['error']:
task.cancel()
return self.http_status(502, -1, session['error'])
if not session['qr_url']:
task.cancel()
session['status'] = 'error'
session['error'] = 'Timeout waiting for QR code'
return self.http_status(504, -1, 'Timeout waiting for QR code')
return self.success(
data={
'session_id': session_id,
'qr_url': session['qr_url'],
'expire_at': session['expire_at'],
}
)
@self.route('/dingtalk/create-app/status/<session_id>', methods=['GET'])
async def _(session_id: str) -> str:
"""Poll DingTalk Device Flow status."""
_cleanup_expired_dingtalk_sessions()
session = _dingtalk_sessions.get(session_id)
if not session:
return self.http_status(404, -1, 'Session not found')
data = {'status': session['status']}
if session['status'] == 'success':
data['client_id'] = session['client_id']
data['client_secret'] = session['client_secret']
_dingtalk_sessions.pop(session_id, None)
elif session['status'] == 'error':
data['error'] = session['error']
_dingtalk_sessions.pop(session_id, None)
return self.success(data=data)
@self.route('/dingtalk/create-app/<session_id>', methods=['DELETE'])
async def _(session_id: str) -> str:
"""Cancel and clean up a DingTalk Device Flow session."""
session = _dingtalk_sessions.pop(session_id, None)
if session and session.get('task') and not session['task'].done():
session['task'].cancel()
return self.success(data={})
# -----------------------------------------------------------------------
# WeComBot QR Code One-Click Create
# -----------------------------------------------------------------------
_wecombot_sessions: dict = {}
_WECOMBOT_SESSION_TTL = 300 # 5 minutes (WeCom QR validity window)
def _cleanup_expired_wecombot_sessions():
import time
now = time.time()
expired = [
sid for sid, s in _wecombot_sessions.items() if now - s.get('created_at', 0) > _WECOMBOT_SESSION_TTL
]
for sid in expired:
session = _wecombot_sessions.pop(sid, None)
if session and session.get('task') and not session['task'].done():
session['task'].cancel()
@self.route('/wecombot/create-bot', methods=['POST'])
async def _() -> str:
"""Start WeComBot one-click creation via QR code. Returns session_id + QR code URL."""
import uuid
import time
import aiohttp
WECOM_QC_GENERATE_URL = 'https://work.weixin.qq.com/ai/qc/generate'
WECOM_QC_QUERY_URL = 'https://work.weixin.qq.com/ai/qc/query_result'
_cleanup_expired_wecombot_sessions()
session_id = str(uuid.uuid4())
session = {
'status': 'pending',
'qr_url': None,
'expire_at': None,
'botid': None,
'secret': None,
'error': None,
'created_at': time.time(),
'scode': None,
'task': None,
}
_wecombot_sessions[session_id] = session
async def run_qr_flow():
try:
timeout = aiohttp.ClientTimeout(total=10)
async with aiohttp.ClientSession(timeout=timeout) as http:
# Step 1: Generate QR code
async with http.get(
f'{WECOM_QC_GENERATE_URL}?source=langbot&plat=0',
) as resp:
try:
data = await resp.json()
except (aiohttp.ContentTypeError, ValueError):
session['status'] = 'error'
session['error'] = 'Invalid response from WeCom service'
return
if not data.get('data', {}).get('scode') or not data.get('data', {}).get('auth_url'):
session['status'] = 'error'
session['error'] = data.get('errmsg', 'Failed to generate QR code')
return
scode = data['data']['scode']
auth_url = data['data']['auth_url']
session['scode'] = scode
session['qr_url'] = auth_url
session['expire_at'] = time.time() + _WECOMBOT_SESSION_TTL
session['status'] = 'waiting'
# Step 2: Poll for scan result
deadline = time.time() + _WECOMBOT_SESSION_TTL
while time.time() < deadline:
await asyncio.sleep(3)
async with http.get(
f'{WECOM_QC_QUERY_URL}?scode={scode}',
) as poll_resp:
try:
poll_data = await poll_resp.json()
except (aiohttp.ContentTypeError, ValueError):
continue
status = poll_data.get('data', {}).get('status', '')
if status == 'success':
bot_info = poll_data.get('data', {}).get('bot_info', {})
if bot_info.get('botid') and bot_info.get('secret'):
session['status'] = 'success'
session['botid'] = bot_info['botid']
session['secret'] = bot_info['secret']
return
else:
session['status'] = 'error'
session['error'] = 'Scan succeeded but bot info is incomplete'
return
# Timeout
session['status'] = 'error'
session['error'] = 'QR code expired'
except asyncio.CancelledError:
return
except Exception as e:
session['status'] = 'error'
session['error'] = str(e)
task = asyncio.create_task(run_qr_flow())
session['task'] = task
# Wait for QR code to be ready (max 10 seconds)
for _ in range(20):
if session['qr_url'] or session['error']:
break
await asyncio.sleep(0.5)
if session['error']:
task.cancel()
return self.http_status(502, -1, session['error'])
if not session['qr_url']:
task.cancel()
session['status'] = 'error'
session['error'] = 'Timeout waiting for QR code'
return self.http_status(504, -1, 'Timeout waiting for QR code')
return self.success(
data={
'session_id': session_id,
'qr_url': session['qr_url'],
'expire_at': session['expire_at'],
}
)
@self.route('/wecombot/create-bot/status/<session_id>', methods=['GET'])
async def _(session_id: str) -> str:
"""Poll WeComBot creation status."""
_cleanup_expired_wecombot_sessions()
session = _wecombot_sessions.get(session_id)
if not session:
return self.http_status(404, -1, 'Session not found')
data = {'status': session['status']}
if session['status'] == 'success':
data['botid'] = session['botid']
data['secret'] = session['secret']
_wecombot_sessions.pop(session_id, None)
elif session['status'] == 'error':
data['error'] = session['error']
_wecombot_sessions.pop(session_id, None)
return self.success(data=data)
@self.route('/wecombot/create-bot/<session_id>', methods=['DELETE'])
async def _(session_id: str) -> str:
"""Cancel and clean up a WeComBot creation session."""
session = _wecombot_sessions.pop(session_id, None)
if session and session.get('task') and not session['task'].done():
session['task'].cancel()
return self.success(data={})

View File

@@ -7,10 +7,8 @@ import httpx
import uuid
import os
import posixpath
import sqlalchemy
from .....core import taskmgr
from .....entity.persistence import plugin as persistence_plugin
from .. import group
from langbot_plugin.runtime.plugin.mgr import PluginInstallSource
@@ -41,16 +39,6 @@ def _normalize_plugin_asset_path(filepath: str) -> str | None:
return f'assets/{normalized}'
def _get_request_origin() -> str:
"""Return the public request origin, respecting reverse-proxy headers."""
forwarded_proto = quart.request.headers.get('X-Forwarded-Proto', '').split(',')[0].strip()
forwarded_host = quart.request.headers.get('X-Forwarded-Host', '').split(',')[0].strip()
scheme = forwarded_proto or quart.request.scheme
host = forwarded_host or quart.request.host
return f'{scheme}://{host}'
@group.group_class('plugins', '/api/v1/plugins')
class PluginsRouterGroup(group.RouterGroup):
async def _check_extensions_limit(self) -> str | None:
@@ -150,15 +138,7 @@ class PluginsRouterGroup(group.RouterGroup):
return self.http_status(404, -1, 'plugin not found')
if quart.request.method == 'GET':
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_plugin.PluginSetting.config)
.where(persistence_plugin.PluginSetting.plugin_author == author)
.where(persistence_plugin.PluginSetting.plugin_name == plugin_name)
)
persisted_config = result.scalar_one_or_none()
config = persisted_config if persisted_config is not None else plugin['plugin_config']
return self.success(data={'config': config})
return self.success(data={'config': plugin['plugin_config']})
elif quart.request.method == 'PUT':
data = await quart.request.json
@@ -209,7 +189,7 @@ class PluginsRouterGroup(group.RouterGroup):
# CSP for HTML pages served to sandboxed iframes (opaque origin).
# 'self' doesn't work in sandboxed iframes — use actual server origin.
if mime_type and mime_type.startswith('text/html'):
origin = _get_request_origin()
origin = f'{quart.request.scheme}://{quart.request.host}'
resp.headers['Content-Security-Policy'] = (
f'default-src {origin}; '
f"script-src {origin} 'unsafe-inline'; "

View File

@@ -140,6 +140,17 @@ class SystemRouterGroup(group.RouterGroup):
async def _() -> str:
return self.success(data=await self.ap.maintenance_service.get_storage_analysis())
@self.route('/debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
if not constants.debug_mode:
return self.http_status(403, 403, 'Forbidden')
py_code = await quart.request.data
ap = self.ap
return self.success(data=exec(py_code, {'ap': ap}))
@self.route(
'/debug/plugin/action',
methods=['POST'],

View File

@@ -146,7 +146,6 @@ class UserRouterGroup(group.RouterGroup):
return self.fail(3, str(e))
except ValueError as e:
traceback.print_exc()
self.ap.logger.warning(f'Space OAuth callback failed: {e}')
return self.fail(1, str(e))
except Exception as e:
traceback.print_exc()

View File

@@ -52,9 +52,6 @@ class ApiKeyService:
async def verify_api_key(self, key: str) -> bool:
"""Verify if an API key is valid"""
if not isinstance(key, str) or not key.startswith('lbk_'):
return False
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(apikey.ApiKey).where(apikey.ApiKey.key == key)
)

View File

@@ -99,11 +99,11 @@ class BotService:
# TODO: 检查配置信息格式
bot_data['uuid'] = str(uuid.uuid4())
# bind the most recently updated pipeline if any exist
# checkout the default pipeline
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline)
.order_by(persistence_pipeline.LegacyPipeline.updated_at.desc())
.limit(1)
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
persistence_pipeline.LegacyPipeline.is_default == True
)
)
pipeline = result.first()
if pipeline is not None:
@@ -120,26 +120,24 @@ class BotService:
async def update_bot(self, bot_uuid: str, bot_data: dict) -> None:
"""Update bot"""
update_data = bot_data.copy()
if 'uuid' in update_data:
del update_data['uuid']
if 'uuid' in bot_data:
del bot_data['uuid']
# set use_pipeline_name
if 'use_pipeline_uuid' in update_data:
if 'use_pipeline_uuid' in bot_data:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
persistence_pipeline.LegacyPipeline.uuid == update_data['use_pipeline_uuid']
persistence_pipeline.LegacyPipeline.uuid == bot_data['use_pipeline_uuid']
)
)
pipeline = result.first()
if pipeline is not None:
update_data['use_pipeline_name'] = pipeline.name
bot_data['use_pipeline_name'] = pipeline.name
else:
raise Exception('Pipeline not found')
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_bot.Bot).values(update_data).where(persistence_bot.Bot.uuid == bot_uuid)
sqlalchemy.update(persistence_bot.Bot).values(bot_data).where(persistence_bot.Bot.uuid == bot_uuid)
)
await self.ap.platform_mgr.remove_bot(bot_uuid)

View File

@@ -31,126 +31,15 @@ class KnowledgeService:
if not knowledge_engine_plugin_id:
raise ValueError('knowledge_engine_plugin_id is required')
creation_settings = kb_data.get('creation_settings', {})
retrieval_settings = kb_data.get('retrieval_settings', {})
# Validate required fields based on plugin's creation_schema and retrieval_schema
await self._validate_schema_required_fields(
knowledge_engine_plugin_id,
creation_settings,
retrieval_settings,
)
kb = await self.ap.rag_mgr.create_knowledge_base(
name=kb_data.get('name', 'Untitled'),
knowledge_engine_plugin_id=knowledge_engine_plugin_id,
creation_settings=creation_settings,
retrieval_settings=retrieval_settings,
creation_settings=kb_data.get('creation_settings', {}),
retrieval_settings=kb_data.get('retrieval_settings', {}),
description=kb_data.get('description', ''),
)
return kb.uuid
async def _validate_schema_required_fields(
self,
plugin_id: str,
creation_settings: dict,
retrieval_settings: dict,
) -> None:
"""Validate required fields based on plugin's creation_schema and retrieval_schema.
This is a business-agnostic validation that checks all fields marked as
required in the plugin's schema, regardless of field type.
Args:
plugin_id: Knowledge Engine plugin ID.
creation_settings: User-provided creation settings.
retrieval_settings: User-provided retrieval settings.
Raises:
ValueError: If any required field is missing or empty.
"""
# Validate creation_schema
try:
creation_schema = await self.ap.plugin_connector.get_rag_creation_schema(plugin_id)
self._check_required_fields(creation_schema, creation_settings, 'creation_settings')
except ValueError:
raise
except Exception as e:
self.ap.logger.warning(f'Failed to get creation_schema for validation: {e}')
# Validate retrieval_schema
try:
retrieval_schema = await self.ap.plugin_connector.get_rag_retrieval_schema(plugin_id)
self._check_required_fields(retrieval_schema, retrieval_settings, 'retrieval_settings')
except ValueError:
raise
except Exception as e:
self.ap.logger.warning(f'Failed to get retrieval_schema for validation: {e}')
def _check_required_fields(
self,
schema: dict | list,
settings: dict,
context: str,
) -> None:
"""Check required fields in schema against provided settings.
Args:
schema: Plugin-defined schema (can be list or dict with 'schema' key).
settings: User-provided settings values.
context: Context name for error messages (e.g., 'creation_settings').
Raises:
ValueError: If a required field is missing or empty.
"""
if not schema:
return
# schema can be a list directly, or a dict with 'schema' key
items = schema if isinstance(schema, list) else schema.get('schema', [])
if not items:
return
for item in items:
field_name = item.get('name')
if not field_name:
continue
is_required = item.get('required', False)
if not is_required:
continue
# Check show_if condition - if field is conditionally shown, only validate when condition is met
show_if = item.get('show_if')
if show_if:
depend_field = show_if.get('field')
operator = show_if.get('operator')
expected_value = show_if.get('value')
if depend_field and operator:
depend_value = settings.get(depend_field)
# If show_if condition is not met, skip validation for this field
if operator == 'eq' and depend_value != expected_value:
continue
if operator == 'neq' and depend_value == expected_value:
continue
if operator == 'in' and isinstance(expected_value, list) and depend_value not in expected_value:
continue
value = settings.get(field_name)
# Validate required field has a non-empty value
if value is None or (isinstance(value, str) and value.strip() == ''):
# Get field label for friendly error message
label = item.get('label', {})
field_label = (
label.get('en_US', field_name)
or label.get('zh_Hans', field_name)
or label.get('zh_Hant', field_name)
or field_name
)
raise ValueError(f'{field_label} is required ({context}.{field_name})')
async def update_knowledge_base(self, kb_uuid: str, kb_data: dict) -> None:
"""更新知识库"""
# Filter to only mutable fields

View File

@@ -113,9 +113,14 @@ class PipelineService:
return pipeline_data['uuid']
async def update_pipeline(self, pipeline_uuid: str, pipeline_data: dict) -> None:
pipeline_data = pipeline_data.copy()
for protected_field in ('uuid', 'for_version', 'stages', 'is_default'):
pipeline_data.pop(protected_field, None)
if 'uuid' in pipeline_data:
del pipeline_data['uuid']
if 'for_version' in pipeline_data:
del pipeline_data['for_version']
if 'stages' in pipeline_data:
del pipeline_data['stages']
if 'is_default' in pipeline_data:
del pipeline_data['is_default']
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_pipeline.LegacyPipeline)

View File

@@ -17,24 +17,6 @@ class ModelProviderService:
def __init__(self, ap: app.Application) -> None:
self.ap = ap
@staticmethod
def _normalize_api_keys(api_keys: str | list[str] | tuple[str, ...] | None) -> list[str]:
if api_keys is None:
return []
raw_keys = [api_keys] if isinstance(api_keys, str) else list(api_keys)
normalized_keys = []
seen_keys = set()
for raw_key in raw_keys:
normalized_key = raw_key.strip() if isinstance(raw_key, str) else ''
if not normalized_key or normalized_key in seen_keys:
continue
normalized_keys.append(normalized_key)
seen_keys.add(normalized_key)
return normalized_keys
async def get_providers(self) -> list[dict]:
"""Get all providers"""
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.ModelProvider))
@@ -77,7 +59,6 @@ class ModelProviderService:
async def create_provider(self, provider_data: dict) -> str:
"""Create a new provider"""
provider_data['uuid'] = str(uuid.uuid4())
provider_data['api_keys'] = self._normalize_api_keys(provider_data.get('api_keys'))
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_model.ModelProvider).values(**provider_data)
)
@@ -91,8 +72,6 @@ class ModelProviderService:
"""Update an existing provider"""
if 'uuid' in provider_data:
del provider_data['uuid']
if 'api_keys' in provider_data:
provider_data['api_keys'] = self._normalize_api_keys(provider_data.get('api_keys'))
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_model.ModelProvider)
.where(persistence_model.ModelProvider.uuid == provider_uuid)
@@ -162,8 +141,6 @@ class ModelProviderService:
async def find_or_create_provider(self, requester: str, base_url: str, api_keys: list) -> str:
"""Find existing provider or create new one"""
api_keys = self._normalize_api_keys(api_keys)
# Try to find existing provider with same config
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.ModelProvider).where(
@@ -191,7 +168,7 @@ class ModelProviderService:
'name': provider_name,
'requester': requester,
'base_url': base_url,
'api_keys': api_keys,
'api_keys': api_keys or [],
}
)
@@ -200,7 +177,7 @@ class ModelProviderService:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_model.ModelProvider)
.where(persistence_model.ModelProvider.uuid == '00000000-0000-0000-0000-000000000000')
.values(api_keys=self._normalize_api_keys(api_key))
.values(api_keys=[api_key])
)
await self.ap.model_mgr.reload_provider('00000000-0000-0000-0000-000000000000')

View File

@@ -46,14 +46,12 @@ async def make_app(loop: asyncio.AbstractEventLoop) -> app.Application:
async def main(loop: asyncio.AbstractEventLoop):
app_inst: app.Application | None = None
try:
# Hang system signal processing
import signal
def signal_handler(sig, frame):
if app_inst is not None:
app_inst.dispose()
app_inst.dispose()
print('[Signal] Program exit.')
os._exit(0)

View File

@@ -275,7 +275,6 @@ class MessageAggregator:
message_chain=merged_chain,
adapter=base_msg.adapter,
pipeline_uuid=base_msg.pipeline_uuid,
routed_by_rule=any(msg.routed_by_rule for msg in messages),
)
async def flush_all(self) -> None:

View File

@@ -76,10 +76,6 @@ class LongTextProcessStage(stage.PipelineStage):
self.ap.logger.debug('Long message processing strategy is not set, skip long message processing.')
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
if not query.resp_message_chain:
self.ap.logger.debug('Response message chain is empty, skip long message processing.')
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
# 检查是否包含非 Plain 组件
contains_non_plain = False

View File

@@ -157,7 +157,7 @@ class RuntimePipeline:
bot_message=query.resp_messages[-1],
message=result.user_notice,
quote_origin=query.pipeline_config['output']['misc']['quote-origin'],
is_final=[msg.is_final for msg in query.resp_messages][-1],
is_final=[msg.is_final for msg in query.resp_messages][0],
)
else:
await query.adapter.reply_message(

View File

@@ -42,13 +42,9 @@ class QueryPool:
adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
pipeline_uuid: typing.Optional[str] = None,
routed_by_rule: bool = False,
variables: typing.Optional[dict[str, typing.Any]] = None,
) -> pipeline_query.Query:
async with self.condition:
query_id = self.query_id_counter
initial_variables: dict[str, typing.Any] = {'_routed_by_rule': routed_by_rule}
if variables:
initial_variables.update(variables)
query = pipeline_query.Query(
bot_uuid=bot_uuid,
query_id=query_id,
@@ -57,7 +53,7 @@ class QueryPool:
sender_id=sender_id,
message_event=message_event,
message_chain=message_chain,
variables=initial_variables,
variables={'_routed_by_rule': routed_by_rule},
resp_messages=[],
resp_message_chain=[],
adapter=adapter,
@@ -67,7 +63,6 @@ class QueryPool:
self.cached_queries[query_id] = query
self.query_id_counter += 1
self.condition.notify_all()
return query
async def __aenter__(self):
await self.pool_lock.acquire()

View File

@@ -40,7 +40,7 @@ class SendResponseBackStage(stage.PipelineStage):
has_chunks = any(isinstance(msg, provider_message.MessageChunk) for msg in query.resp_messages)
# TODO 命令与流式的兼容性问题
if await query.adapter.is_stream_output_supported() and has_chunks:
is_final = [msg.is_final for msg in query.resp_messages][-1]
is_final = [msg.is_final for msg in query.resp_messages][0]
await query.adapter.reply_message_chunk(
message_source=query.message_event,
bot_message=query.resp_messages[-1],

View File

@@ -501,8 +501,6 @@ class PlatformManager:
bot_entity.adapter_config,
logger,
)
if hasattr(adapter_inst, 'ap'):
adapter_inst.ap = self.ap
# 如果 adapter 支持 set_bot_uuid 方法,设置 bot_uuid用于统一 webhook
if hasattr(adapter_inst, 'set_bot_uuid'):

View File

@@ -3,7 +3,6 @@ import typing
import asyncio
import traceback
import datetime
import json
import aiocqhttp
import pydantic
@@ -294,29 +293,6 @@ class AiocqhttpMessageConverter(abstract_platform_adapter.AbstractMessageConvert
elif msg.type == 'dice':
face_id = msg.data['result']
yiri_msg_list.append(platform_message.Face(face_type='dice', face_id=int(face_id), face_name='骰子'))
elif msg.type == 'json':
try:
raw = msg.data.get('data', {})
if isinstance(raw, str):
raw = json.loads(raw)
if isinstance(raw, dict):
_meta = raw.get('meta', {}) or {}
if isinstance(_meta, dict):
_detail = _meta.get('detail_1') or _meta.get('music') or _meta.get('news') or {}
else:
_detail = {}
if isinstance(_detail, dict):
preview = _detail.get('preview', '')
title = _detail.get('desc', '') or _detail.get('title', '')
url = _detail.get('qqdocurl', '') or _detail.get('jumpUrl', '')
else:
preview = title = url = ''
text = ' '.join([f'[{raw.get("app", "")}]', preview, title, url]).strip()
yiri_msg_list.append(platform_message.Plain(text=text or '[收到一张JSON卡片]'))
else:
yiri_msg_list.append(platform_message.Plain(text=str(raw)))
except Exception:
yiri_msg_list.append(platform_message.Plain(text='[收到一张JSON卡片]'))
chain = platform_message.MessageChain(yiri_msg_list)

View File

@@ -1,19 +1,13 @@
import asyncio
import json
import traceback
import typing
import uuid
from langbot.libs.dingtalk_api.dingtalkevent import DingTalkEvent
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
import langbot_plugin.api.entities.builtin.provider.session as provider_session
from langbot.libs.dingtalk_api.api import DingTalkClient
import datetime
from langbot.pkg.platform.logger import EventLogger
from langbot.pkg.provider.runners.difysvapi import _format_human_input_text
class DingTalkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
@@ -176,22 +170,6 @@ class DingTalkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
card_instance_id_dict: (
dict # 回复卡片消息字典key为消息idvalue为回复卡片实例id用于在流式消息时判断是否发送到指定卡片
)
# outTrackId → form snapshot {session_key, launcher_type, launcher_id, form_token,
# workflow_run_id, actions, node_title, form_content, expires_at, open_space_id,
# user_id_hint, current_text}. Lookup keys for the data-source pull endpoint and
# the STREAM card-action callback.
card_state: dict
# session_key → out_track_id of the currently-active card for the
# conversation turn. Lets resumed-workflow chunks (which arrive on a
# synthetic event with a fresh resp_message_id) keep updating the same
# card the user clicked instead of getting a new one.
active_turn_card: dict
# session_key → accumulated streaming text for the active turn. Read
# by _paint_form_on_card so the post-pause form keeps the streamed
# context above the new prompt.
active_turn_text: dict
ap: typing.Any = None
bot_uuid: str = ''
def __init__(self, config: dict, logger: EventLogger):
required_keys = [
@@ -216,17 +194,10 @@ class DingTalkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
config=config,
logger=logger,
card_instance_id_dict={},
card_state={},
active_turn_card={},
active_turn_text={},
bot_account_id=bot_account_id,
bot=bot,
listeners={},
)
# Wire the card-action callback after super().__init__ so we can reference
# self.* — the client's handler stores this as a soft reference and reads
# it at fire time.
self.bot.card_action_callback = self._on_card_action
async def reply_message(
self,
@@ -251,82 +222,28 @@ class DingTalkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
quote_origin: bool = False,
is_final: bool = False,
):
# event = await DingTalkEventConverter.yiri2target(
# message_source,
# )
# incoming_message = event.incoming_message
# msg_id = incoming_message.message_id
message_id = bot_message.resp_message_id
msg_seq = bot_message.msg_sequence
form_template_id = (self.config.get('human_input_card_template_id') or '').strip()
form_data = getattr(bot_message, '_form_data', None)
if is_final and self.ap is not None:
self.ap.logger.info(
f'DingTalk reply_message_chunk final: form_data_present={form_data is not None}, '
f'form_template_configured={bool(form_template_id)}'
)
if form_data and is_final:
await self._handle_form_chunk(message_source, bot_message, message, form_data)
return
if (msg_seq - 1) % 8 == 0 or is_final:
markdown_enabled = self.config.get('markdown_card', False)
content, at = await DingTalkMessageConverter.yiri2target(message, markdown_enabled)
card_instance, card_instance_id = self.card_instance_id_dict[message_id]
if not content and bot_message.content:
content = bot_message.content # 兼容直接传入content的情况
chat_card_entry = self.card_instance_id_dict.get(message_id)
if chat_card_entry is None:
# No streaming chat card was created for this query — common
# path for synthetic events (e.g. resumed workflow after a
# button click). Lazy-create one so the resumed output streams
# into a card just like a normal conversation, instead of
# being deferred and sent in one shot on is_final.
if not content:
return # nothing to stream yet
chat_card_entry = await self._lazy_create_resume_chat_card(message_source, message_id)
if chat_card_entry is None:
# Lazy-create failed (no template configured); fall back
# to a one-shot proactive message on the final chunk.
if is_final:
await self._send_proactive_to_event(message_source, content)
return
card_instance, card_instance_id = chat_card_entry
# print(card_instance_id)
if content:
if form_template_id:
# The card content has already been written via
# update_card_data (in _paint_form_on_card and the
# initial card creation). The streaming endpoint
# (PUT /v1.0/card/streaming) does not propagate
# updates on cards whose content was last set via
# update_card_data — they take different code paths
# on the DingTalk client. Stick with update_card_data
# for the whole turn for consistency.
try:
await self.bot.update_card_data(
out_track_id=card_instance_id,
card_param_map={
'content': content,
'btns': '[]',
'flowStatus': '3' if is_final else '1',
},
)
except Exception:
if self.ap is not None:
self.ap.logger.exception('DingTalk: update card content failed')
else:
await self.bot.send_card_message(card_instance, card_instance_id, content, is_final)
if is_final:
if form_template_id and not content:
# Empty final chunk still needs to leave the card with
# flowStatus=3 so the spinner stops.
try:
await self.bot.update_card_data(
out_track_id=card_instance_id,
card_param_map={'flowStatus': '3'},
)
except Exception:
pass
if bot_message.tool_calls is None:
self.card_instance_id_dict.pop(message_id, None)
await self.bot.send_card_message(card_instance, card_instance_id, content, is_final)
if is_final and bot_message.tool_calls is None:
# self.seq = 1 # 消息回复结束之后重置seq
self.card_instance_id_dict.pop(message_id) # 消息回复结束之后删除卡片实例id
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
markdown_enabled = self.config.get('markdown_card', False)
@@ -343,80 +260,16 @@ class DingTalkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
return is_stream
async def create_message_card(self, message_id, event):
form_template_id = (self.config.get('human_input_card_template_id') or '').strip()
legacy_template_id = self.config.get('card_template_id', '')
# Synthetic events (button clicks): look up the card already in
# active_turn_card so reply_message_chunk can stream to it.
if event is None or event.source_platform_object is None:
if form_template_id:
session_key = self._session_key_from_event(event) if event is not None else ''
carry = self.active_turn_card.get(session_key, '') if session_key else ''
if carry:
self.card_instance_id_dict[message_id] = (None, carry)
return True
return False
if form_template_id:
# Create one card with the form template, empty buttons,
# pending state. Streaming writes content to it; form pause
# paints buttons onto it. One card per turn, no duplication.
incoming_message = event.source_platform_object.incoming_message
out_track_id = uuid.uuid4().hex
is_group = str(incoming_message.conversation_type) == '2'
if is_group:
open_space_id = f'dtv1.card//IM_GROUP.{incoming_message.conversation_id}'
else:
open_space_id = f'dtv1.card//IM_ROBOT.{incoming_message.sender_staff_id}'
try:
await self.bot.create_and_deliver_card(
card_template_id=form_template_id,
out_track_id=out_track_id,
open_space_id=open_space_id,
is_group=is_group,
card_param_map={'content': '', 'btns': '[]', 'flowStatus': '1'},
callback_type='STREAM',
)
except Exception:
if self.ap is not None:
self.ap.logger.exception('DingTalk: create form-template card failed')
return False
self.card_instance_id_dict[message_id] = (None, out_track_id)
session_key = self._session_key_from_event(event)
if session_key:
self.active_turn_card[session_key] = out_track_id
self.active_turn_text[session_key] = ''
return True
# Legacy chat-card path (no form template).
card_template_id = self.config['card_template_id']
incoming_message = event.source_platform_object.incoming_message
card_auto_layout = self.config.get('card_auto_layout', False)
# message_id = incoming_message.message_id
card_auto_layout = self.config.get('card_ auto_layout', False)
card_instance, card_instance_id = await self.bot.create_and_card(
legacy_template_id, incoming_message, card_auto_layout=card_auto_layout
card_template_id, incoming_message, card_auto_layout=card_auto_layout
)
self.card_instance_id_dict[message_id] = (card_instance, card_instance_id)
return True
def _session_key_from_event(self, event) -> str:
"""Return launcher_type_launcher_id for an event, '' if unrecoverable."""
if event is None:
return ''
spo = event.source_platform_object
if spo is None:
try:
if isinstance(event, platform_events.GroupMessage):
return f'group_{event.group.id}'
return f'person_{event.sender.id}'
except Exception:
return ''
try:
inc = spo.incoming_message
if str(inc.conversation_type) == '2':
return f'group_{inc.conversation_id}'
return f'person_{inc.sender_staff_id}'
except Exception:
return ''
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
@@ -456,543 +309,3 @@ class DingTalkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
],
):
return super().unregister_listener(event_type, callback)
# ------------------------------------------------------------------
# Dify human-input form support
# ------------------------------------------------------------------
def set_bot_uuid(self, bot_uuid: str):
"""Receive the bot uuid from the platform manager.
Used to compose the public-facing unified-webhook URL for the card
dynamic-data-source pull endpoint.
"""
self.bot_uuid = bot_uuid
def _derive_open_space(self, message_source: platform_events.MessageEvent) -> tuple[str, bool]:
"""Return (openSpaceId, is_group) for the given inbound event."""
if isinstance(message_source, platform_events.GroupMessage):
return f'dtv1.card//IM_GROUP.{message_source.group.id}', True
return f'dtv1.card//IM_ROBOT.{message_source.sender.id}', False
def _derive_session_descriptor(
self, message_source: platform_events.MessageEvent
) -> tuple[provider_session.LauncherTypes, str, str]:
"""Return (launcher_type, launcher_id, sender_user_id) for routing."""
if isinstance(message_source, platform_events.GroupMessage):
return (
provider_session.LauncherTypes.GROUP,
str(message_source.group.id),
str(message_source.sender.id),
)
return (
provider_session.LauncherTypes.PERSON,
str(message_source.sender.id),
str(message_source.sender.id),
)
async def _handle_form_chunk(
self,
message_source: platform_events.MessageEvent,
bot_message,
message: platform_message.MessageChain,
form_data: dict,
) -> None:
"""Surface human-input prompt + buttons on the active card.
In single-card mode (form_template_id configured): update the
EXISTING card with form buttons so it transitions from streaming
output to prompt+buttons on the same card. In legacy mode:
finalize the chat card and deliver a separate form card.
"""
if self.ap is not None:
self.ap.logger.info(
f'DingTalk _handle_form_chunk: actions={len(form_data.get("actions") or [])}, '
f'node_title={form_data.get("node_title", "")!r}'
)
message_id = bot_message.resp_message_id
template_id = (self.config.get('human_input_card_template_id') or '').strip()
if template_id:
# Single-card mode: paint prompt + buttons onto the existing card.
session_key = self._session_key_from_event(message_source)
entry = self.card_instance_id_dict.get(message_id)
out_track_id = entry[1] if entry else None
if not out_track_id and session_key:
out_track_id = self.active_turn_card.get(session_key, '')
if out_track_id:
await self._paint_form_on_card(message_source, out_track_id, form_data, session_key)
self.card_instance_id_dict.pop(message_id, None)
return
# No existing card (e.g. Dify paused immediately with no LLM
# output before the pause). Create a form card directly.
await self._send_form_card(message_source, form_data, template_id)
self.card_instance_id_dict.pop(message_id, None)
return
# Legacy mode: finalize the streaming card with text fallback.
chat_card_entry = self.card_instance_id_dict.pop(message_id, None)
if chat_card_entry is not None:
_, chat_out_track_id = chat_card_entry
markdown_enabled = self.config.get('markdown_card', False)
text_content, _ = await DingTalkMessageConverter.yiri2target(message, markdown_enabled)
if not text_content and bot_message.content:
text_content = bot_message.content
try:
await self.bot.send_card_message(None, chat_out_track_id, text_content or '', True)
except Exception:
await self.logger.error(f'DingTalk: finalize chat card before form failed: {traceback.format_exc()}')
await self.send_message_text_form(message_source, form_data)
async def _paint_form_on_card(
self,
message_source: platform_events.MessageEvent,
out_track_id: str,
form_data: dict,
session_key: str,
) -> None:
"""Update an existing card's content + buttons for human-input."""
actions = list(form_data.get('actions') or [])
node_title = form_data.get('node_title', '') or 'Human Input Required'
form_content = form_data.get('form_content', '') or ''
# Record form state for the click-handler.
launcher_type, launcher_id, sender_user_id = self._derive_session_descriptor(message_source)
self.card_state[out_track_id] = {
'session_key': session_key,
'launcher_type': launcher_type.value,
'launcher_id': launcher_id,
'sender_user_id': sender_user_id,
'form_token': form_data.get('form_token', ''),
'workflow_run_id': form_data.get('workflow_run_id', ''),
'actions': actions,
'node_title': node_title,
'form_content': form_content,
}
btns = self._build_btns(actions, out_track_id)
parts: list[str] = []
prior = self.active_turn_text.get(session_key, '') if session_key else ''
if prior.strip():
parts.append(prior.rstrip())
parts.append('---')
if node_title:
parts.append(f'**{node_title}**')
if form_content:
parts.append(form_content)
display_content = '\n\n'.join(parts) or '请选择一个操作以继续。'
try:
await self.bot.update_card_data(
out_track_id=out_track_id,
card_param_map={
'content': display_content,
'btns': json.dumps(btns, ensure_ascii=False),
'flowStatus': '3',
},
)
except Exception:
if self.ap is not None:
self.ap.logger.exception('DingTalk: paint form on card failed')
await self.send_message_text_form(message_source, form_data)
return
if session_key:
self.active_turn_text[session_key] = display_content
@staticmethod
def _build_btns(actions: list, out_track_id: str) -> list:
btns = []
for idx, action in enumerate(actions):
action_id = str(action.get('id') or '')
title = str(action.get('title') or action_id or f'选项 {idx + 1}')
style = (action.get('button_style') or '').lower()
if style == 'primary' or (style == '' and idx == 0):
color = 'blue'
elif style == 'danger':
color = 'red'
else:
color = 'gray'
btns.append(
{
'text': title,
'color': color,
'status': 'normal',
'event': {
'type': 'sendCardRequest',
'params': {
'actionId': action_id,
'params': {'action_id': action_id, 'out_track_id': out_track_id},
},
},
}
)
return btns
async def _send_form_card(
self,
message_source: platform_events.MessageEvent,
form_data: dict,
template_id: str,
) -> None:
"""Deliver a new card pre-loaded with the human-input prompt + buttons."""
out_track_id = uuid.uuid4().hex
open_space_id, is_group = self._derive_open_space(message_source)
launcher_type, launcher_id, sender_user_id = self._derive_session_descriptor(message_source)
session_key = f'{launcher_type.value}_{launcher_id}'
actions = list(form_data.get('actions') or [])
node_title = form_data.get('node_title', '') or 'Human Input Required'
form_content = form_data.get('form_content', '') or ''
self.card_state[out_track_id] = {
'session_key': session_key,
'launcher_type': launcher_type.value,
'launcher_id': launcher_id,
'sender_user_id': sender_user_id,
'form_token': form_data.get('form_token', ''),
'workflow_run_id': form_data.get('workflow_run_id', ''),
'actions': actions,
'node_title': node_title,
'form_content': form_content,
'open_space_id': open_space_id,
'is_group': is_group,
}
parts = []
if node_title:
parts.append(f'**{node_title}**')
if form_content:
parts.append(form_content)
display_content = '\n\n'.join(parts) or '请选择一个操作以继续。'
btns = []
for idx, action in enumerate(actions):
action_id = str(action.get('id') or '')
title = str(action.get('title') or action_id or f'选项 {idx + 1}')
style = (action.get('button_style') or '').lower()
if style == 'primary' or (style == '' and idx == 0):
color = 'blue'
elif style == 'danger':
color = 'red'
else:
color = 'gray'
btns.append(
{
'text': title,
'color': color,
'status': 'normal',
'event': {
'type': 'sendCardRequest',
'params': {
'actionId': action_id,
'params': {'action_id': action_id, 'out_track_id': out_track_id},
},
},
}
)
try:
if self.ap is not None:
self.ap.logger.info(
f'DingTalk _send_form_card: out_track_id={out_track_id} template_id={template_id} '
f'open_space_id={open_space_id} is_group={is_group} btns={len(btns)}'
)
await self.bot.create_and_deliver_card(
card_template_id=template_id,
out_track_id=out_track_id,
open_space_id=open_space_id,
is_group=is_group,
card_param_map={
'content': display_content,
'btns': json.dumps(btns, ensure_ascii=False),
'flowStatus': '3',
},
callback_type='STREAM',
)
except Exception:
await self.logger.error(f'DingTalk: deliver form card failed: {traceback.format_exc()}')
await self.send_message_text_form(message_source, form_data)
self.card_state.pop(out_track_id, None)
async def _lazy_create_resume_chat_card(
self,
message_source: platform_events.MessageEvent,
message_id: str,
) -> typing.Optional[tuple]:
"""Create a new card for resumed-workflow streaming output.
Used after a button click triggers a synthetic event — the form
card stays put with the "已选择" notice, and a fresh card is
spawned here for the LLM reply to stream into.
"""
form_template_id = (self.config.get('human_input_card_template_id') or '').strip()
legacy_template_id = (self.config.get('card_template_id') or '').strip()
template_id = form_template_id or legacy_template_id
if not template_id:
return None
out_track_id = uuid.uuid4().hex
open_space_id, is_group = self._derive_open_space(message_source)
if form_template_id:
card_param_map = {'content': '', 'btns': '[]', 'flowStatus': '1'}
card_data_config = None
else:
card_param_map = {'content': '', 'query': '...'}
card_data_config = {'autoLayout': self.config.get('card_auto_layout', False)}
try:
success = await self.bot.create_and_deliver_card(
card_template_id=template_id,
out_track_id=out_track_id,
open_space_id=open_space_id,
is_group=is_group,
card_param_map=card_param_map,
card_data_config=card_data_config,
callback_type='STREAM',
)
except Exception:
if self.ap is not None:
self.ap.logger.exception('DingTalk: lazy create resume chat card failed')
return None
if not success:
return None
entry = (None, out_track_id)
self.card_instance_id_dict[message_id] = entry
# Register as the active card so any further chunks on this turn
# (and a subsequent re-pause) land on the same new card.
session_key = self._session_key_from_event(message_source)
if session_key:
self.active_turn_card[session_key] = out_track_id
self.active_turn_text[session_key] = ''
return entry
async def send_message_text_form(
self,
message_source: platform_events.MessageEvent,
form_data: dict,
) -> None:
"""Fallback: send the human-input prompt as plain text."""
display_text = _format_human_input_text(
form_data.get('node_title', ''),
form_data.get('form_content', ''),
form_data.get('actions', []) or [],
)
await self._send_proactive_to_event(message_source, display_text)
async def _send_proactive_to_event(
self,
message_source: platform_events.MessageEvent,
content: str,
) -> None:
"""Send `content` as a proactive message to the conversation behind
`message_source`. Used when no inbound chatbot message exists to
anchor a card on (e.g. resumed flows triggered by card actions).
"""
if not content:
return
if self.ap is not None:
target = (
str(message_source.group.id)
if isinstance(message_source, platform_events.GroupMessage)
else str(message_source.sender.id)
)
self.ap.logger.info(
f'DingTalk _send_proactive_to_event: target={target} '
f'is_group={isinstance(message_source, platform_events.GroupMessage)} content_len={len(content)}'
)
try:
if isinstance(message_source, platform_events.GroupMessage):
await self.bot.send_proactive_message_to_group(str(message_source.group.id), content)
else:
await self.bot.send_proactive_message_to_one(str(message_source.sender.id), content)
except Exception:
if self.ap is not None:
self.ap.logger.exception('DingTalk: send proactive message failed')
await self.logger.error(f'DingTalk: send proactive message failed: {traceback.format_exc()}')
async def _on_card_action(self, payload: dict) -> None:
"""Translate a card button click into a synthetic query.
Reads the clicked button's ``actionId`` (the real Dify action id —
the ButtonGroup template sends it back via `event.params.actionId`),
recovers the action title from ``card_state``, and enqueues a
synthetic `_dify_form_action` query the same way Lark / Telegram do.
"""
if self.ap is not None:
self.ap.logger.info(
f'DingTalk _on_card_action received: out_track_id={payload.get("out_track_id")} '
f'payload_action_id={payload.get("action_id")!r} params={payload.get("params")!r}'
)
out_track_id = payload.get('out_track_id') or ''
params = payload.get('params') or {}
# ButtonGroup `sendCardRequest` events surface the click id at the
# callback top level as `actionId`; fall back to `params.action_id`
# (alternate template wiring) and `params.actionId`.
raw_action_id = (
(payload.get('action_id') or '').strip()
or (params.get('action_id') or '').strip()
or (params.get('actionId') or '').strip()
or (params.get('id') or '').strip()
)
state = self.card_state.get(out_track_id)
if state is None:
await self.logger.warning(f'DingTalk: card action received for unknown out_track_id={out_track_id}')
return
if not raw_action_id:
await self.logger.warning(f'DingTalk: card action with no action_id, payload={payload}')
return
actions = state.get('actions', []) or []
action_id = raw_action_id
action_title = raw_action_id
for action in actions:
if str(action.get('id', '')) == raw_action_id:
action_title = action.get('title') or raw_action_id
break
launcher_type = (
provider_session.LauncherTypes.GROUP
if state.get('launcher_type') == provider_session.LauncherTypes.GROUP.value
else provider_session.LauncherTypes.PERSON
)
launcher_id = state.get('launcher_id', '')
sender_user_id = state.get('sender_user_id') or payload.get('user_id') or launcher_id
form_action_data = {
'form_token': state.get('form_token', ''),
'workflow_run_id': state.get('workflow_run_id', ''),
'action_id': action_id,
'action_title': action_title,
'node_title': state.get('node_title', ''),
'user': f'{launcher_type.value}_{launcher_id}',
'inputs': {},
}
message_chain = platform_message.MessageChain([platform_message.Plain(text=f'[Form Action: {action_title}]')])
if launcher_type == provider_session.LauncherTypes.GROUP:
synthetic_event = platform_events.GroupMessage(
sender=platform_entities.GroupMember(
id=sender_user_id,
member_name='',
permission=platform_entities.Permission.Member,
group=platform_entities.Group(
id=launcher_id,
name='',
permission=platform_entities.Permission.Member,
),
special_title='',
),
message_chain=message_chain,
time=int(datetime.datetime.now().timestamp()),
source_platform_object=None,
)
else:
synthetic_event = platform_events.FriendMessage(
sender=platform_entities.Friend(
id=sender_user_id,
nickname='',
remark='',
),
message_chain=message_chain,
time=int(datetime.datetime.now().timestamp()),
source_platform_object=None,
)
bot_uuid = ''
pipeline_uuid = None
if self.ap is not None:
for bot in self.ap.platform_mgr.bots:
if bot.adapter is self:
bot_uuid = bot.bot_entity.uuid
pipeline_uuid = bot.bot_entity.use_pipeline_uuid
break
try:
self.ap.logger.info(
f'DingTalk _on_card_action enqueuing form action: action_id={action_id!r} '
f'action_title={action_title!r} launcher_type={launcher_type.value} '
f'launcher_id={launcher_id} bot_uuid={bot_uuid} pipeline_uuid={pipeline_uuid}'
)
await self.ap.query_pool.add_query(
bot_uuid=bot_uuid,
launcher_type=launcher_type,
launcher_id=launcher_id,
sender_id=sender_user_id,
message_event=synthetic_event,
message_chain=message_chain,
adapter=self,
pipeline_uuid=pipeline_uuid,
variables={
'_dify_form_action': form_action_data,
'_routed_by_rule': True,
},
)
self.ap.logger.info('DingTalk _on_card_action: query enqueued OK')
except Exception:
self.ap.logger.exception('DingTalk: enqueue form action query failed')
return
# Visual feedback on the form card itself: keep the prompt visible,
# add a "已选择" line, remove the buttons. The resumed-workflow
# output lives on a separate new card (lazy-created in
# reply_message_chunk on the synthetic event), so the form card
# stays put as a record of the user's selection.
asyncio.create_task(
self._mark_card_resolved(
out_track_id,
action_title,
node_title=state.get('node_title', ''),
form_content=state.get('form_content', ''),
)
)
# Crucial: do NOT leave the form card's out_track_id in
# active_turn_card — otherwise create_message_card for the
# synthetic event would reuse it for the resume output, painting
# the LLM reply on top of the "已选择" notice. Clear it so the
# resume goes through the lazy-create path and spawns a fresh card.
session_key = state.get('session_key', '')
if session_key and self.active_turn_card.get(session_key) == out_track_id:
self.active_turn_card.pop(session_key, None)
self.active_turn_text.pop(session_key, None)
# Once consumed, drop the state — the runner clears _PENDING_FORMS too.
self.card_state.pop(out_track_id, None)
async def _mark_card_resolved(
self,
out_track_id: str,
action_title: str,
*,
node_title: str = '',
form_content: str = '',
) -> None:
"""Update the form card to acknowledge the user's selection.
Keeps the original prompt visible, adds a "已选择: X" notice, and
clears the buttons. The card stays as a permanent record of the
choice; the resumed workflow's output goes to a separate new card.
"""
parts: list[str] = []
if node_title:
parts.append(f'**{node_title}**')
if form_content:
parts.append(form_content)
parts.append(f'---\n✅ 已选择:**{action_title}**')
content = '\n\n'.join(parts)
if self.ap is not None:
self.ap.logger.info(f'DingTalk _mark_card_resolved: out_track_id={out_track_id} action={action_title!r}')
try:
await self.bot.update_card_data(
out_track_id=out_track_id,
card_param_map={
'content': content,
'btns': '[]',
'flowStatus': '3',
},
)
except Exception:
if self.ap is not None:
self.ap.logger.exception('DingTalk: mark card resolved failed')

View File

@@ -19,18 +19,6 @@ spec:
en: https://link.langbot.app/en/platforms/dingtalk
ja: https://link.langbot.app/ja/platforms/dingtalk
config:
- name: one-click-create
label:
en_US: One-Click Create App
zh_Hans: 一键创建应用
zh_Hant: 一鍵建立應用
description:
en_US: "Scan QR code with DingTalk to automatically create an app and fill in credentials. Note: Robot Code cannot be obtained automatically, you need to copy it from the DingTalk Developer Backend manually."
zh_Hans: "使用钉钉扫码自动创建应用并填写凭据。注意:机器人代码无法自动获取,需前往钉钉开发者后台手动复制。"
zh_Hant: "使用釘釘掃碼自動建立應用並填寫憑證。注意:機器人代碼無法自動取得,需前往釘釘開發者後台手動複製。"
type: qr-code-login
login_platform: dingtalk
required: false
- name: client_id
label:
en_US: Client ID
@@ -52,10 +40,6 @@ spec:
en_US: Robot Code
zh_Hans: 机器人代码
zh_Hant: 機器人代碼
description:
en_US: "Required for image recognition, file upload and other features. Get it from DingTalk Developer Backend > Robot Configuration."
zh_Hans: "识图、上传文件等功能必填。请前往钉钉开发者后台 > 机器人配置中获取。"
zh_Hant: "識圖、上傳檔案等功能必填。請前往釘釘開發者後台 > 機器人設定中取得。"
type: string
required: true
default: ""
@@ -103,18 +87,6 @@ spec:
type: string
required: true
default: "填写你的卡片template_id"
- name: human_input_card_template_id
label:
en_US: Human Input Card Template ID
zh_Hans: 人工输入卡片模板ID
zh_Hant: 人工輸入卡片範本ID
description:
en_US: "Template ID used as the SINGLE card for the whole conversation turn. Streamed LLM text fills the `content` markdown variable; on a Dify human-input pause the `btns` buttonGroup variable is populated so action buttons appear on the SAME card; after the user clicks a button the buttons disappear and resumed streaming continues into the same card. Use the bundled `src/langbot/templates/dingtalk_human_input_card.json` — it ships with `content` (MarkdownBlock) and `btns` (ButtonGroup) already wired. Leave empty to fall back to the legacy two-card behaviour (chat card streaming text + plain-text human-input prompts)."
zh_Hans: "用作整个对话回合**唯一**卡片的模板ID。流式 LLM 文本写入 `content` markdown 变量Dify 人工输入暂停时同一张卡的 `btns` buttonGroup 变量被填上、按钮浮现;用户点击后按钮消失、恢复的流式内容继续追加到同一张卡。可使用项目附带的 `src/langbot/templates/dingtalk_human_input_card.json`——已经预先连好 `content` (MarkdownBlock) 与 `btns` (ButtonGroup)。留空则降级为旧的双卡行为(聊天卡走流式 + 人工输入走纯文本)。"
zh_Hant: "用作整個對話回合**唯一**卡片的範本ID。流式 LLM 文字寫入 `content` markdown 變數Dify 人工輸入暫停時同一張卡的 `btns` buttonGroup 變數被填上、按鈕浮現;使用者點擊後按鈕消失、恢復的流式內容繼續追加到同一張卡。可使用專案附帶的 `src/langbot/templates/dingtalk_human_input_card.json`——已經預先連好 `content` (MarkdownBlock) 與 `btns` (ButtonGroup)。留空則降級為舊的雙卡行為(聊天卡走流式 + 人工輸入走純文字)。"
type: string
required: false
default: ""
execution:
python:
path: ./dingtalk.py

File diff suppressed because it is too large Load Diff

View File

@@ -23,20 +23,6 @@ spec:
en: https://link.langbot.app/en/platforms/lark
ja: https://link.langbot.app/ja/platforms/lark
config:
- name: one-click-create
label:
en_US: One-Click Create App
zh_Hans: 一键创建应用
zh_Hant: 一鍵建立應用
ja_JP: ワンクリックでアプリ作成
description:
en_US: Scan QR code to automatically create a Feishu app and fill in credentials
zh_Hans: 扫码自动创建飞书应用并填写凭据
zh_Hant: 掃碼自動建立飛書應用並填寫憑證
ja_JP: QRコードをスキャンしてFeishuアプリを自動作成し、認証情報を入力
type: qr-code-login
login_platform: feishu
required: false
- name: app_id
label:
en_US: App ID

View File

@@ -32,20 +32,6 @@ spec:
type: string
required: true
default: "https://ilinkai.weixin.qq.com"
- name: qr-login
label:
en_US: Scan QR Login
zh_Hans: 扫码登录
zh_Hant: 掃碼登入
ja_JP: QRコードでログイン
description:
en_US: Scan QR code with WeChat to authorize and automatically fill in the token
zh_Hans: 使用微信扫码授权,自动填写令牌
zh_Hant: 使用微信掃碼授權,自動填寫令牌
ja_JP: WeChatでQRコードをスキャンし、トークンを自動入力
type: qr-code-login
login_platform: weixin
required: false
- name: token
label:
en_US: Token

View File

@@ -1,14 +1,14 @@
from __future__ import annotations
import time
import telegram
import telegram.ext
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
from telegram.ext import ApplicationBuilder, ContextTypes, MessageHandler, CallbackQueryHandler, filters
from telegram import Update
from telegram.ext import ApplicationBuilder, ContextTypes, MessageHandler, filters
import telegramify_markdown
import typing
import traceback
import json
import base64
import pydantic
@@ -189,7 +189,6 @@ class TelegramEventConverter(abstract_platform_adapter.AbstractEventConverter):
class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
bot: telegram.Bot = pydantic.Field(exclude=True)
application: telegram.ext.Application = pydantic.Field(exclude=True)
ap: typing.Any = pydantic.Field(exclude=True, default=None)
message_converter: TelegramMessageConverter = TelegramMessageConverter()
event_converter: TelegramEventConverter = TelegramEventConverter()
@@ -225,102 +224,6 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
telegram_callback,
)
)
async def callback_query_handler(update: Update, context: ContextTypes.DEFAULT_TYPE):
query = update.callback_query
await query.answer()
try:
data = json.loads(query.data)
if data.get('form_action') or data.get('f'):
import langbot_plugin.api.entities.builtin.provider.session as provider_session
workflow_run_id = data.get('workflow_run_id', '')
w_suffix = data.get('w', '')
action_id = data.get('action_id') or data.get('a', '')
session_key = data.get('session_key') or data.get('s', '')
if session_key.startswith('group_') or session_key.startswith('g:'):
launcher_type = provider_session.LauncherTypes.GROUP
launcher_id = (
session_key.split(':', 1)[1]
if session_key.startswith('g:')
else session_key[len('group_') :]
)
else:
launcher_type = provider_session.LauncherTypes.PERSON
launcher_id = (
session_key.split(':', 1)[1]
if session_key.startswith('p:')
else session_key[len('person_') :]
)
user_id = str(query.from_user.id)
# Find bot_uuid and pipeline_uuid
bot_uuid = ''
pipeline_uuid = None
for b in self.ap.platform_mgr.bots:
if b.adapter is self:
bot_uuid = b.bot_entity.uuid
pipeline_uuid = b.bot_entity.use_pipeline_uuid
break
form_action_data = {
'workflow_run_id': workflow_run_id,
'w_suffix': w_suffix,
'action_id': action_id,
'user': f'{launcher_type.value}_{launcher_id}',
'inputs': {},
}
message_chain = platform_message.MessageChain(
[platform_message.Plain(text=f'[Form Action: {action_id}]')]
)
if launcher_type == provider_session.LauncherTypes.GROUP:
synthetic_event = platform_events.GroupMessage(
sender=platform_entities.GroupMember(
id=user_id,
member_name='',
permission=platform_entities.Permission.Member,
group=platform_entities.Group(
id=launcher_id,
name='',
permission=platform_entities.Permission.Member,
),
),
message_chain=message_chain,
source_platform_object=update,
)
else:
synthetic_event = platform_events.FriendMessage(
sender=platform_entities.Friend(
id=user_id,
nickname='',
remark='',
),
message_chain=message_chain,
source_platform_object=update,
)
await self.ap.query_pool.add_query(
bot_uuid=bot_uuid,
launcher_type=launcher_type,
launcher_id=launcher_id,
sender_id=user_id,
message_event=synthetic_event,
message_chain=message_chain,
adapter=self,
pipeline_uuid=pipeline_uuid,
variables={
'_dify_form_action': form_action_data,
'_routed_by_rule': True,
},
)
except Exception:
await self.logger.error(f'Error in telegram callback query: {traceback.format_exc()}')
application.add_handler(CallbackQueryHandler(callback_query_handler))
super().__init__(
config=config,
logger=logger,
@@ -416,19 +319,14 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
update = event.source_platform_object
chat_id = update.effective_chat.id
chat_type = update.effective_chat.type
effective_message = update.effective_message
message_thread_id = getattr(effective_message, 'message_thread_id', None) if effective_message else None
message_thread_id = update.message.message_thread_id
if chat_type == 'private':
import time as _time
draft_id = int(_time.time() * 1000)
draft_id = int(time.time() * 1000)
self.msg_stream_id[message_id] = ('private', draft_id)
args = self._build_message_args(chat_id, 'Thinking...', message_thread_id, draft_id=draft_id)
try:
await self.bot.send_message_draft(**args)
except (telegram.error.RetryAfter, telegram.error.BadRequest):
pass
await self.bot.send_message_draft(**args)
else:
args = self._build_message_args(chat_id, 'Thinking...', message_thread_id)
send_msg = await self.bot.send_message(**args)
@@ -449,13 +347,12 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
assert isinstance(message_source.source_platform_object, Update)
update = message_source.source_platform_object
chat_id = update.effective_chat.id
effective_message = update.effective_message
message_thread_id = getattr(effective_message, 'message_thread_id', None) if effective_message else None
message_thread_id = update.message.message_thread_id
if message_id not in self.msg_stream_id:
return
chat_mode, stream_id = self.msg_stream_id[message_id]
chat_mode, draft_id = self.msg_stream_id[message_id]
components = await TelegramMessageConverter.yiri2target(message, self.bot)
if not components or components[0]['type'] != 'text':
@@ -464,42 +361,16 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
return
content = components[0]['text']
form_data = getattr(bot_message, '_form_data', None)
if form_data and is_final:
self.msg_stream_id.pop(message_id, None)
await self._send_form_action_buttons(message_source, form_data)
return
if chat_mode == 'private':
# Streaming via draft (ephemeral preview in the chat input area)
if (msg_seq - 1) % 8 == 0 or is_final:
args = self._build_message_args(chat_id, content, message_thread_id, draft_id=stream_id)
try:
await self.bot.send_message_draft(**args)
except telegram.error.BadRequest as exc:
if 'Message_too_long' in str(exc):
args['text'] = content[:4000] + '\n\n… (truncated)'
try:
await self.bot.send_message_draft(**args)
except telegram.error.RetryAfter:
pass
else:
pass # Ignore other draft errors (cosmetic)
args = self._build_message_args(chat_id, content, message_thread_id, draft_id=draft_id)
await self.bot.send_message_draft(**args)
if is_final and bot_message.tool_calls is None:
# Finalise: send the real message, discard the draft
args = self._build_message_args(chat_id, content, message_thread_id)
try:
await self.bot.send_message(**args)
except telegram.error.BadRequest as exc:
if 'Message_too_long' in str(exc):
args['text'] = content[:4000] + '\n\n… (truncated)'
await self.bot.send_message(**args)
else:
raise
del args['draft_id']
await self.bot.send_message(**args)
self.msg_stream_id.pop(message_id)
else:
# Streaming via edit_message_text (persistent message)
stream_id = draft_id
if (msg_seq - 1) % 8 == 0 or is_final:
args = {
'message_id': stream_id,
@@ -508,68 +379,11 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
}
if self.config.get('markdown_card', False):
args['parse_mode'] = 'MarkdownV2'
try:
await self.bot.edit_message_text(**args)
except telegram.error.BadRequest as exc:
if 'Message_too_long' in str(exc):
args['text'] = self._process_markdown(content[:4000] + '\n\n… (truncated)')
await self.bot.edit_message_text(**args)
else:
raise
await self.bot.edit_message_text(**args)
if is_final and bot_message.tool_calls is None:
self.msg_stream_id.pop(message_id)
async def _send_form_action_buttons(
self,
message_source: platform_events.MessageEvent,
form_data: dict,
):
"""Send inline keyboard buttons for Dify human_input_required form actions."""
actions = form_data.get('actions', [])
node_title = form_data.get('node_title', '')
form_content = form_data.get('form_content', '')
workflow_run_id = form_data.get('workflow_run_id', '')
# Telegram callback_data is capped at 64 bytes, so we identify the
# paused workflow by the last 8 chars of workflow_run_id (unique
# within a session with overwhelming probability).
w_suffix = workflow_run_id[-8:] if workflow_run_id else ''
if isinstance(message_source, platform_events.GroupMessage):
session_key = f'g:{message_source.group.id}'
else:
session_key = f'p:{message_source.sender.id}'
keyboard = []
for action in actions:
action_id = action.get('id', '')
action_title = action.get('title', action_id)
callback_payload = {'f': 1, 'a': action_id, 's': session_key}
if w_suffix:
callback_payload['w'] = w_suffix
callback_data = json.dumps(callback_payload, separators=(',', ':'))
keyboard.append([InlineKeyboardButton(action_title, callback_data=callback_data)])
reply_markup = InlineKeyboardMarkup(keyboard)
update = message_source.source_platform_object
chat_id = update.effective_chat.id
effective_message = update.effective_message
message_thread_id = getattr(effective_message, 'message_thread_id', None) if effective_message else None
text_lines = [f'[{node_title}] Please select an action:']
if form_content:
text_lines.insert(0, form_content)
args = {
'chat_id': chat_id,
'text': '\n\n'.join(text_lines),
'reply_markup': reply_markup,
}
if message_thread_id:
args['message_thread_id'] = message_thread_id
await self.bot.send_message(**args)
def get_launcher_id(self, event: platform_events.MessageEvent) -> str | None:
if not isinstance(event.source_platform_object, Update):
return None

View File

@@ -27,7 +27,10 @@ class WebPageBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter
listeners: dict = pydantic.Field(default_factory=dict, exclude=True)
_ws_adapter: typing.Any = None
model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
class Config:
arbitrary_types_allowed = True
# Allow private attributes
underscore_attrs_are_private = True
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
super().__init__(config=config, logger=logger, **kwargs)

View File

@@ -296,7 +296,6 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
listeners: dict = {}
_stream_to_monitoring_msg: dict = {} # Maps stream_id to (monitoring_message_id, timestamp)
_STREAM_MAPPING_TTL = 600 # 10 minutes
ap: typing.Any = None
def __init__(self, config: dict, logger: EventLogger):
enable_webhook = config.get('enable-webhook', False)
@@ -337,12 +336,6 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
_stream_to_monitoring_msg={},
)
# Both WecomBotClient (webhook) and WecomBotWsClient (ws long-conn)
# expose ``set_card_action_callback``. Wire the click handler so
# Dify human-input button taps resume the workflow on either mode.
if hasattr(self.bot, 'set_card_action_callback'):
self.bot.set_card_action_callback(self._on_card_action)
async def reply_message(
self,
message_source: platform_events.MessageEvent,
@@ -352,37 +345,15 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
content = await self.message_converter.yiri2target(message)
_ws_mode = not self.config.get('enable-webhook', False)
event = message_source.source_platform_object
# Synthetic events (button-click resume queries) have no inbound
# platform object. Fall back to a proactive send so error
# messages and one-shot replies still reach the user.
if event is None:
if _ws_mode:
if isinstance(message_source, platform_events.GroupMessage):
chat_id = str(message_source.group.id)
else:
chat_id = str(message_source.sender.id)
try:
await self.bot.send_message(chat_id, content)
except Exception:
await self.logger.error(
f'WeComBot: proactive reply for synthetic event failed: {traceback.format_exc()}'
)
else:
await self.logger.warning(
'WeComBot webhook mode cannot reply to a synthetic event '
'(no req_id and no proactive-send credentials); dropping.'
)
return
if _ws_mode:
req_id = event.get('req_id', '') if isinstance(event, dict) else getattr(event, 'req_id', '')
event = message_source.source_platform_object
req_id = event.get('req_id', '')
if req_id:
await self.bot.reply_text(req_id, content)
else:
await self.bot.set_message(event.message_id, content)
else:
await self.bot.set_message(event.message_id, content)
await self.bot.set_message(message_source.source_platform_object.message_id, content)
async def reply_message_chunk(
self,
@@ -393,56 +364,9 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
is_final: bool = False,
):
content = await self.message_converter.yiri2target(message)
msg_id = message_source.source_platform_object.message_id
_ws_mode = not self.config.get('enable-webhook', False)
# Synthetic events (e.g. button-click triggered form resume) have
# no inbound platform message — no msg_id, no req_id, no stream
# session. The output must go via the proactive-send path instead
# of the stream/reply path.
spo = message_source.source_platform_object
if spo is None:
return await self._handle_synthetic_chunk(message_source, bot_message, content, is_final, _ws_mode)
msg_id = spo.message_id
# Dify human-input pause: when the runner attaches `_form_data` to
# the final chunk, hand the button_interaction card off to the
# underlying client. In webhook mode the card is queued for the
# next followup poll; in ws mode it's sent as a reply frame
# immediately. Falls back to plain text when the bot has no active
# stream session for this msg_id (rare).
form_data = getattr(bot_message, '_form_data', None)
if form_data and is_final:
if hasattr(self.bot, 'push_form_pause'):
ok, stream_id, task_id = await self.bot.push_form_pause(msg_id, form_data)
if ok:
await self.logger.info(
f'WeComBot: pending button_interaction registered '
f'stream_id={stream_id} task_id={task_id} ws_mode={_ws_mode}'
)
return {'stream': True, 'form': True, 'task_id': task_id}
await self.logger.warning(
'WeComBot: cannot register form pause (no active stream session); falling back to plain text'
)
try:
from langbot.pkg.provider.runners.difysvapi import _format_human_input_text
fallback = _format_human_input_text(
form_data.get('node_title', ''),
form_data.get('form_content', ''),
form_data.get('actions', []) or [],
)
except Exception:
fallback = content or '(人工输入)'
if _ws_mode:
event = message_source.source_platform_object
req_id = event.get('req_id', '') if isinstance(event, dict) else getattr(event, 'req_id', '')
if req_id:
await self.bot.reply_text(req_id, fallback)
else:
await self.bot.set_message(msg_id, fallback)
return {'stream': False, 'form': True, 'fallback': True}
if _ws_mode:
success = await self.bot.push_stream_chunk(msg_id, content, is_final=is_final)
if not success and is_final:
@@ -461,129 +385,6 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
"""Whether streaming output is enabled for this bot instance."""
return self.config.get('enable-stream-reply', True)
async def _handle_synthetic_chunk(
self,
message_source: platform_events.MessageEvent,
bot_message,
content: str,
is_final: bool,
ws_mode: bool,
) -> dict:
"""Handle reply_message_chunk for synthetic events (button clicks).
Synthetic events have no inbound message → no msg_id, no req_id,
no stream session. We can't do incremental streaming, so we
buffer chunks per-conversation and flush on ``is_final`` via the
proactive send path.
Buffer keyed by ``(launcher_type, launcher_id)`` from the
synthetic event itself. Only ws mode has a usable proactive-send
path right now (``ws_client.send_message`` /
``ws_client.send_template_card``); webhook mode requires a
corpid/secret we don't have, so it logs and drops.
"""
if isinstance(message_source, platform_events.GroupMessage):
chat_id = str(message_source.group.id)
else:
chat_id = str(message_source.sender.id)
form_data = getattr(bot_message, '_form_data', None)
# Buffer streaming content until is_final.
buf_key = chat_id
if not hasattr(self, '_synthetic_buffers'):
# Attribute-not-declared trick: pydantic forbids dynamic attrs
# on the model, but plain instance dicts via object.__setattr__
# do work. Lazy-create on first call.
object.__setattr__(self, '_synthetic_buffers', {})
buffers: dict[str, str] = self._synthetic_buffers
if content and not form_data:
buffers[buf_key] = buffers.get(buf_key, '') + content
if not is_final:
return {'stream': True, 'synthetic': True, 'buffered': True}
final_content = buffers.pop(buf_key, '')
if content and final_content.startswith(content):
# is_final chunk re-emitted the full accumulated text — keep
# whichever is longer.
final_content = final_content if len(final_content) >= len(content) else content
elif content and not final_content:
final_content = content
if not ws_mode:
await self.logger.warning(
'WeComBot webhook mode cannot proactively push synthetic-event '
'output (no corpid/secret); the resume reply is dropped. '
f'content_len={len(final_content)} form_data_present={form_data is not None}'
)
return {'stream': False, 'synthetic': True, 'dropped': True}
# ws mode: proactive send.
try:
if form_data:
# Determine user_id / chat_id for the routing context of any
# subsequent click on this card.
if isinstance(message_source, platform_events.GroupMessage):
routing_chat_id = str(message_source.group.id)
routing_user_id = str(message_source.sender.id)
else:
routing_chat_id = ''
routing_user_id = str(message_source.sender.id)
payload = self._build_button_interaction_payload_from_form(
form_data,
user_id=routing_user_id,
chat_id=routing_chat_id,
)
await self.bot.send_template_card(chat_id, payload)
await self.logger.info(
f'WeComBot ws: proactively sent template_card for synthetic event '
f'chat_id={chat_id} form_token={form_data.get("form_token")!r} '
f'workflow_run_id={form_data.get("workflow_run_id")!r}'
)
elif final_content:
await self.bot.send_message(chat_id, final_content)
await self.logger.info(
f'WeComBot ws: proactively sent text for synthetic event chat_id={chat_id} len={len(final_content)}'
)
except Exception:
await self.logger.error(f'WeComBot: synthetic event proactive send failed: {traceback.format_exc()}')
return {'stream': False, 'synthetic': True, 'error': True}
return {'stream': True, 'synthetic': True}
def _build_button_interaction_payload_from_form(
self, form_data: dict, *, user_id: str = '', chat_id: str = ''
) -> dict:
"""Build a button_interaction payload + track task_id for click resolution.
Unlike the inbound-event path (where push_form_pause registers the
task_id with the active stream session), proactive sends still
need the task_id registered so button clicks find pending_form.
For ws mode we stash it directly on the ws_client's pending dict.
"""
from langbot.libs.wecom_ai_bot_api.api import build_button_interaction_payload
import secrets as _secrets
task_id = f'dify-{_secrets.token_hex(12)}'
payload = build_button_interaction_payload(form_data, task_id)
# Register task_id → form_data so the click callback can find it.
# user_id / chat_id are required so _on_card_action can route the
# resulting synthetic query back to the right user. msg_id / req_id
# / stream_id are intentionally empty — synthetic cards have no
# inbound message to anchor on.
if hasattr(self.bot, '_pending_forms_by_task'):
self.bot._pending_forms_by_task[task_id] = {
'form_data': form_data,
'msg_id': '',
'user_id': user_id,
'chat_id': chat_id,
'stream_id': '',
'req_id': '',
}
return payload
async def send_message(self, target_type, target_id, message):
_ws_mode = not self.config.get('enable-webhook', False)
if _ws_mode:
@@ -730,114 +531,3 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
async def is_muted(self, group_id: int) -> bool:
pass
# ------------------------------------------------------------------
# Dify human-input button-interaction click handling
# ------------------------------------------------------------------
async def _on_card_action(self, session, action_id: str, task_id: str, raw_event: dict) -> None:
"""Translate a button click on a button_interaction card into a
synthetic ``_dify_form_action`` query enqueued on the pool.
Pattern mirrors DingTalk / Lark / Telegram so the runner's
``_merge_pending_form_action`` path resumes the workflow.
"""
import langbot_plugin.api.entities.builtin.provider.session as provider_session
form = session.pending_form or {}
await self.logger.info(
f'WeComBot _on_card_action: task_id={task_id} action_id={action_id!r} '
f'form_token={form.get("form_token")!r} workflow_run_id={form.get("workflow_run_id")!r} '
f'session.user_id={session.user_id!r} session.chat_id={session.chat_id!r}'
)
actions = form.get('actions') or []
clean_action_id = (action_id or '').strip()
action_title = clean_action_id
for a in actions:
if str(a.get('id', '')) == clean_action_id:
action_title = a.get('title') or clean_action_id
break
launcher_id = session.user_id or session.chat_id or ''
sender_user_id = session.user_id or launcher_id
# WeCom AI bot has both single-chat and group-chat; chat_id present
# indicates group context.
if session.chat_id:
launcher_type = provider_session.LauncherTypes.GROUP
launcher_id = session.chat_id
else:
launcher_type = provider_session.LauncherTypes.PERSON
launcher_id = session.user_id or ''
form_action_data = {
'form_token': form.get('form_token', ''),
'workflow_run_id': form.get('workflow_run_id', ''),
'action_id': clean_action_id,
'action_title': action_title,
'node_title': form.get('node_title', ''),
'user': f'{launcher_type.value}_{launcher_id}',
'inputs': {},
}
message_chain = platform_message.MessageChain([platform_message.Plain(text=f'[Form Action: {action_title}]')])
if launcher_type == provider_session.LauncherTypes.GROUP:
synthetic_event = platform_events.GroupMessage(
sender=platform_entities.GroupMember(
id=sender_user_id,
member_name='',
permission=platform_entities.Permission.Member,
group=platform_entities.Group(
id=launcher_id,
name='',
permission=platform_entities.Permission.Member,
),
special_title='',
),
message_chain=message_chain,
time=int(time.time()),
source_platform_object=None,
)
else:
synthetic_event = platform_events.FriendMessage(
sender=platform_entities.Friend(
id=sender_user_id,
nickname='',
remark='',
),
message_chain=message_chain,
time=int(time.time()),
source_platform_object=None,
)
if self.ap is None:
await self.logger.error('WeComBot: ap not injected; cannot enqueue button-click query')
return
bot_uuid = ''
pipeline_uuid = None
for bot in self.ap.platform_mgr.bots:
if bot.adapter is self:
bot_uuid = bot.bot_entity.uuid
pipeline_uuid = bot.bot_entity.use_pipeline_uuid
break
try:
await self.ap.query_pool.add_query(
bot_uuid=bot_uuid,
launcher_type=launcher_type,
launcher_id=launcher_id,
sender_id=sender_user_id,
message_event=synthetic_event,
message_chain=message_chain,
adapter=self,
pipeline_uuid=pipeline_uuid,
variables={
'_dify_form_action': form_action_data,
'_routed_by_rule': True,
},
)
await self.logger.info(f'WeComBot: button-click query enqueued action_id={clean_action_id!r}')
except Exception:
await self.logger.error(f'WeComBot: enqueue button-click query failed: {traceback.format_exc()}')

View File

@@ -19,18 +19,6 @@ spec:
en: https://link.langbot.app/en/platforms/wecombot
ja: https://link.langbot.app/ja/platforms/wecombot
config:
- name: one-click-create
label:
en_US: One-Click Create Bot
zh_Hans: 一键创建机器人
zh_Hant: 一鍵建立機器人
description:
en_US: "Scan QR code with WeCom to automatically create a bot and fill in BotId and Secret. Note: Robot Name needs to be filled in manually."
zh_Hans: "使用企业微信扫码自动创建机器人并填写 BotId 和 Secret。注意机器人名称需手动填写。"
zh_Hant: "使用企業微信掃碼自動建立機器人並填寫 BotId 和 Secret。注意機器人名稱需手動填寫。"
type: qr-code-login
login_platform: wecombot
required: false
- name: BotId
label:
en_US: BotId

View File

@@ -11,7 +11,6 @@ import os
import sys
import httpx
import sqlalchemy
import yaml
from async_lru import alru_cache
from langbot_plugin.api.entities.builtin.pipeline.query import provider_session
@@ -35,10 +34,6 @@ from ..core import taskmgr
from ..entity.persistence import plugin as persistence_plugin
class PluginRuntimeNotConnectedError(RuntimeError):
"""Raised when plugin runtime operations are requested before connection."""
class PluginRuntimeConnector:
"""Plugin runtime connector"""
@@ -196,114 +191,44 @@ class PluginRuntimeConnector:
async def ping_plugin_runtime(self):
if not hasattr(self, 'handler'):
raise PluginRuntimeNotConnectedError('Plugin runtime is not connected')
raise Exception('Plugin runtime is not connected')
return await self.handler.ping()
def _inspect_plugin_package(
def _extract_deps_metadata(
self,
file_bytes: bytes,
task_context: taskmgr.TaskContext | None,
) -> tuple[str | None, str | None]:
"""Extract plugin identity and dependency metadata from a plugin package."""
plugin_author = None
plugin_name = None
):
"""Extract dependency count from requirements.txt inside plugin zip."""
if task_context is None:
return
try:
with zipfile.ZipFile(io.BytesIO(file_bytes)) as zf:
try:
manifest = yaml.safe_load(zf.read('manifest.yaml').decode('utf-8', errors='ignore')) or {}
metadata = manifest.get('metadata', {})
plugin_author = metadata.get('author')
plugin_name = metadata.get('name')
except Exception:
pass
if task_context is not None:
for name in zf.namelist():
if name.endswith('requirements.txt'):
content = zf.read(name).decode('utf-8', errors='ignore')
deps = [
line.strip()
for line in content.splitlines()
if line.strip() and not line.strip().startswith('#')
]
task_context.metadata['deps_total'] = len(deps)
task_context.metadata['deps_list'] = deps
break
for name in zf.namelist():
if name.endswith('requirements.txt'):
content = zf.read(name).decode('utf-8', errors='ignore')
deps = [
line.strip()
for line in content.splitlines()
if line.strip() and not line.strip().startswith('#')
]
task_context.metadata['deps_total'] = len(deps)
task_context.metadata['deps_list'] = deps
break
except Exception:
pass
return plugin_author, plugin_name
def _build_plugin_startup_failure_message(
self,
plugin_author: str,
plugin_name: str,
task_context: taskmgr.TaskContext | None,
) -> str:
dep_hint = ''
if task_context is not None:
current_dep = task_context.metadata.get('current_dep')
if current_dep:
dep_hint = f' Last dependency: {current_dep}.'
return (
f'Plugin {plugin_author}/{plugin_name} failed to start after installation. '
f'Dependency installation or plugin initialization may have failed.{dep_hint} '
f'Please check the plugin requirements and runtime logs.'
)
async def _wait_for_installed_plugin_ready(
self,
plugin_author: str | None,
plugin_name: str | None,
task_context: taskmgr.TaskContext | None,
timeout: float = 30,
):
"""Wait until the installed plugin is registered by the runtime.
The plugin runtime launches plugins asynchronously. If dependency installation
fails, the plugin process exits before registration; without this check the
install task can incorrectly finish successfully.
"""
if not plugin_author or not plugin_name:
return
deadline = time.time() + timeout
last_error: Exception | None = None
while time.time() < deadline:
try:
plugin = await self.get_plugin_info(plugin_author, plugin_name)
if plugin is not None:
status = plugin.get('status')
if status == 'initialized':
return
except Exception as e:
last_error = e
await asyncio.sleep(0.5)
message = self._build_plugin_startup_failure_message(plugin_author, plugin_name, task_context)
if last_error is not None:
message = f'{message} Last runtime error: {last_error}'
raise RuntimeError(message)
async def install_plugin(
self,
install_source: PluginInstallSource,
install_info: dict[str, Any],
task_context: taskmgr.TaskContext | None = None,
):
plugin_author = install_info.get('plugin_author')
plugin_name = install_info.get('plugin_name')
if install_source == PluginInstallSource.LOCAL:
# transfer file before install
file_bytes = install_info['plugin_file']
plugin_author, plugin_name = self._inspect_plugin_package(file_bytes, task_context)
if task_context is not None and plugin_author and plugin_name:
task_context.metadata['plugin_name'] = f'{plugin_author}/{plugin_name}'
self._extract_deps_metadata(file_bytes, task_context)
file_key = await self.handler.send_file(file_bytes, 'lbpkg')
install_info['plugin_file_key'] = file_key
del install_info['plugin_file']
@@ -340,9 +265,7 @@ class PluginRuntimeConnector:
task_context.metadata['download_speed'] = downloaded / elapsed if elapsed > 0 else 0
file_bytes = b''.join(chunks)
plugin_author, plugin_name = self._inspect_plugin_package(file_bytes, task_context)
if task_context is not None and plugin_author and plugin_name:
task_context.metadata['plugin_name'] = f'{plugin_author}/{plugin_name}'
self._extract_deps_metadata(file_bytes, task_context)
file_key = await self.handler.send_file(file_bytes, 'lbpkg')
install_info['plugin_file_key'] = file_key
self.ap.logger.info(f'Transfered file {file_key} to plugin runtime')
@@ -366,8 +289,6 @@ class PluginRuntimeConnector:
if metadata is not None and task_context is not None:
task_context.metadata.update(metadata)
await self._wait_for_installed_plugin_ready(plugin_author, plugin_name, task_context)
async def upgrade_plugin(
self,
plugin_author: str,
@@ -637,12 +558,11 @@ class PluginRuntimeConnector:
Raises:
ValueError: If plugin_id is not in the expected 'author/name' format.
"""
segments = plugin_id.split('/')
if len(segments) != 2 or not all(segments):
if '/' not in plugin_id:
raise ValueError(
f"Invalid plugin_id format: '{plugin_id}'. Expected 'author/name' format (e.g. 'langbot/rag-engine')."
)
return segments[0], segments[1]
return plugin_id.split('/', 1)
async def call_rag_ingest(self, plugin_id: str, context_data: dict[str, Any]) -> dict[str, Any]:
"""Call plugin to ingest document.

View File

@@ -340,7 +340,6 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta):
"""Provider API请求器"""
name: str = None
init_api_key: str = 'langbot-init-placeholder'
ap: app.Application

View File

@@ -25,7 +25,7 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
async def initialize(self):
self.client = openai.AsyncClient(
api_key=self.init_api_key,
api_key='',
base_url=self.requester_cfg['base_url'].replace(' ', ''),
timeout=self.requester_cfg['timeout'],
http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']),

View File

@@ -25,7 +25,7 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester):
async def initialize(self):
self.client = openai.AsyncClient(
api_key=self.init_api_key,
api_key='',
base_url=self.requester_cfg['base_url'],
timeout=self.requester_cfg['timeout'],
http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']),

View File

@@ -14,14 +14,7 @@ class TokenManager:
def __init__(self, name: str, tokens: list[str]):
self.name = name
self.tokens = []
seen_tokens = set()
for token in tokens:
normalized_token = token.strip() if isinstance(token, str) else ''
if not normalized_token or normalized_token in seen_tokens:
continue
self.tokens.append(normalized_token)
seen_tokens.add(normalized_token)
self.tokens = tokens
self.using_token_index = 0
def get_token(self) -> str:
@@ -30,6 +23,4 @@ class TokenManager:
return self.tokens[self.using_token_index]
def next_token(self):
if len(self.tokens) == 0:
return
self.using_token_index = (self.using_token_index + 1) % len(self.tokens)

View File

@@ -2,11 +2,9 @@ from __future__ import annotations
import typing
import json
import time
import uuid
import base64
import mimetypes
from collections import OrderedDict
from langbot.pkg.provider import runner
@@ -18,125 +16,6 @@ from langbot.libs.dify_service_api.v1 import client, errors
import httpx
# Module-level store for paused-workflow form state, keyed by session key
# (launcher_type_value + "_" + launcher_id). Each session holds an
# insertion-ordered dict of form_token -> form_data, allowing multiple
# Dify workflows to be paused simultaneously for the same session.
_PENDING_FORMS: dict[str, 'OrderedDict[str, dict[str, typing.Any]]'] = {}
_PENDING_FORM_DEFAULT_TTL = 30 * 60 # 30 minutes safety cap
def _session_key_from_query(query: pipeline_query.Query) -> str:
return f'{query.session.launcher_type.value}_{query.session.launcher_id}'
def _prune_pending_forms(now: float | None = None) -> None:
if now is None:
now = time.time()
for session_key in list(_PENDING_FORMS.keys()):
forms = _PENDING_FORMS[session_key]
expired_tokens = [token for token, data in forms.items() if data.get('_expires_at', 0) <= now]
for token in expired_tokens:
forms.pop(token, None)
if not forms:
_PENDING_FORMS.pop(session_key, None)
def _set_pending_form(session_key: str, form_data: dict[str, typing.Any]) -> None:
_prune_pending_forms()
stored = dict(form_data)
expiration_time = stored.get('expiration_time')
try:
expiration_ts = float(expiration_time) if expiration_time is not None else 0.0
except (TypeError, ValueError):
expiration_ts = 0.0
stored['_expires_at'] = expiration_ts or (time.time() + _PENDING_FORM_DEFAULT_TTL)
form_token = str(stored.get('form_token') or '')
forms = _PENDING_FORMS.setdefault(session_key, OrderedDict())
# Re-insert at the end so this becomes the "latest" entry
forms.pop(form_token, None)
forms[form_token] = stored
def _get_pending_form_by_token(session_key: str, form_token: str) -> dict[str, typing.Any] | None:
_prune_pending_forms()
forms = _PENDING_FORMS.get(session_key)
if not forms or not form_token:
return None
return forms.get(form_token)
def _get_pending_form_by_w_suffix(session_key: str, w_suffix: str) -> dict[str, typing.Any] | None:
"""Look up a pending form whose workflow_run_id ends with the given suffix.
Used by adapters (e.g. Telegram) whose callback payload is too small to
carry the full form_token / workflow_run_id.
"""
_prune_pending_forms()
forms = _PENDING_FORMS.get(session_key)
if not forms or not w_suffix:
return None
for token in reversed(forms):
form = forms[token]
if str(form.get('workflow_run_id', '')).endswith(w_suffix):
return form
return None
def _get_latest_pending_form(session_key: str) -> dict[str, typing.Any] | None:
_prune_pending_forms()
forms = _PENDING_FORMS.get(session_key)
if not forms:
return None
return forms[next(reversed(forms))]
def _iter_pending_forms(session_key: str) -> typing.Iterator[dict[str, typing.Any]]:
"""Iterate pending forms for a session, newest-first."""
_prune_pending_forms()
forms = _PENDING_FORMS.get(session_key)
if not forms:
return
for token in reversed(list(forms.keys())):
yield forms[token]
def _clear_pending_form(session_key: str, form_token: str | None = None) -> None:
"""Clear one specific pending form (by token) or all forms for the session."""
forms = _PENDING_FORMS.get(session_key)
if not forms:
return
if form_token is None:
_PENDING_FORMS.pop(session_key, None)
return
forms.pop(form_token, None)
if not forms:
_PENDING_FORMS.pop(session_key, None)
def _format_human_input_text(
node_title: str,
form_content: str,
actions: list[dict[str, typing.Any]],
) -> str:
"""Render a paused-workflow human-input prompt as plain text.
Used by adapters without rich UI (no buttons/cards) so users can reply
with the option number or the option title to resume the workflow.
"""
lines: list[str] = [f'[Human Input Required] {node_title or ""}'.rstrip()]
if form_content:
lines.append('')
lines.append(form_content)
if actions:
lines.append('')
lines.append('Reply with the number or title to continue:')
for idx, action in enumerate(actions, start=1):
title = action.get('title') or action.get('id') or ''
lines.append(f' {idx}. {title}')
return '\n'.join(lines)
@runner.runner_class('dify-service-api')
class DifyServiceAPIRunner(runner.RequestRunner):
"""Dify Service API 对话请求器"""
@@ -456,155 +335,11 @@ class DifyServiceAPIRunner(runner.RequestRunner):
query.session.using_conversation.uuid = chunk['conversation_id']
async def _submit_workflow_form_blocking(
self, form_action: dict
) -> typing.AsyncGenerator[provider_message.Message, None]:
"""Submit human input to resume a paused Dify workflow (non-streaming)."""
form_token = form_action['form_token']
workflow_run_id = form_action['workflow_run_id']
user = form_action['user']
action_id = form_action.get('action_id', '')
inputs = form_action.get('inputs', {})
async for chunk in self.dify_client.workflow_submit(
form_token=form_token,
workflow_run_id=workflow_run_id,
inputs=inputs,
user=user,
action=action_id,
timeout=120,
):
self.ap.logger.debug('dify-workflow-submit-chunk: ' + str(chunk))
if chunk['event'] == 'workflow_finished':
if chunk['data'].get('error'):
raise errors.DifyAPIError(chunk['data']['error'])
content, _ = self._process_thinking_content(chunk['data']['outputs']['summary'])
yield provider_message.Message(
role='assistant',
content=content,
)
def _resolve_pending_form(self, session_key: str, form_action: dict) -> dict | None:
"""Locate the pending form this action targets.
Tries identifiers in order of specificity: form_token, full
workflow_run_id, workflow_run_id suffix (Telegram-style compact id),
then falls back to the newest pending form for the session.
"""
form_token = form_action.get('form_token')
if form_token:
form = _get_pending_form_by_token(session_key, form_token)
if form:
return form
workflow_run_id = form_action.get('workflow_run_id')
if workflow_run_id:
for form in _iter_pending_forms(session_key):
if form.get('workflow_run_id') == workflow_run_id:
return form
w_suffix = form_action.get('w_suffix')
if w_suffix:
form = _get_pending_form_by_w_suffix(session_key, w_suffix)
if form:
return form
return _get_latest_pending_form(session_key)
def _merge_pending_form_action(self, session_key: str, form_action: dict | None) -> dict | None:
"""Backfill resume fields from the matching pending form."""
if not form_action:
return None
merged_action = dict(form_action)
merged_action.pop('w_suffix', None)
pending_form = self._resolve_pending_form(session_key, form_action)
if pending_form:
merged_action['form_token'] = merged_action.get('form_token') or pending_form.get('form_token', '')
merged_action['workflow_run_id'] = merged_action.get('workflow_run_id') or pending_form.get(
'workflow_run_id', ''
)
merged_action.setdefault('inputs', pending_form.get('inputs', {}))
merged_action.setdefault('user', pending_form.get('user', ''))
merged_action.setdefault('node_title', pending_form.get('node_title', ''))
# Resolve clicked action's display title from the stored actions list
if 'action_title' not in merged_action:
clicked_id = merged_action.get('action_id', '')
for action in pending_form.get('actions', []):
if str(action.get('id', '')) == str(clicked_id):
merged_action['action_title'] = action.get('title', clicked_id)
break
return merged_action
def _match_pending_form_action(self, session_key: str, user_text: str) -> dict | None:
"""Match plain text replies against pending Dify form actions.
Resolution order:
1. A pure digit reply (e.g. "1", "2") maps to the 1-indexed action of
the most recent pending form. Lets users on plain-text platforms
pick options without retyping titles.
2. Otherwise, iterate pending forms newest-first and match each
action's title/id case-insensitively. The first hit wins, so when
two forms share a button label the newer one resolves.
"""
normalized_text = user_text.strip().lower()
if not normalized_text:
return None
def _build(pending_form: dict, action: dict) -> dict:
return {
'form_token': pending_form.get('form_token', ''),
'workflow_run_id': pending_form.get('workflow_run_id', ''),
'action_id': action.get('id', ''),
'action_title': action.get('title', action.get('id', '')),
'node_title': pending_form.get('node_title', ''),
'inputs': pending_form.get('inputs', {}),
'user': pending_form.get('user', ''),
}
if normalized_text.isdigit():
position = int(normalized_text)
latest_form = _get_latest_pending_form(session_key)
if latest_form is not None:
actions = latest_form.get('actions', [])
if 1 <= position <= len(actions):
return _build(latest_form, actions[position - 1])
for pending_form in _iter_pending_forms(session_key):
for action in pending_form.get('actions', []):
titles = {
str(action.get('title', '')).strip().lower(),
str(action.get('id', '')).strip().lower(),
}
if normalized_text in titles:
return _build(pending_form, action)
return None
async def _workflow_messages(
self, query: pipeline_query.Query
) -> typing.AsyncGenerator[provider_message.Message, None]:
"""调用工作流"""
# Check if this is a form action resume (button click or text match)
form_action_raw = query.variables.get('_dify_form_action')
session_key = _session_key_from_query(query)
if form_action_raw:
form_action = self._merge_pending_form_action(session_key, form_action_raw)
else:
form_action = self._match_pending_form_action(session_key, str(query.message_chain))
if form_action:
_clear_pending_form(session_key, form_action.get('form_token') or None)
async for msg in self._submit_workflow_form_blocking(form_action):
yield msg
return
if not query.session.using_conversation.uuid:
query.session.using_conversation.uuid = str(uuid.uuid4())
@@ -631,7 +366,6 @@ class DifyServiceAPIRunner(runner.RequestRunner):
}
inputs.update(query.variables)
human_input_yielded = False
async for chunk in self.dify_client.workflow_run(
inputs=inputs,
@@ -643,45 +377,6 @@ class DifyServiceAPIRunner(runner.RequestRunner):
if chunk['event'] in ignored_events:
continue
if chunk['event'] == 'workflow_paused':
reasons = chunk['data'].get('reasons', [])
workflow_run_id = chunk['data'].get('workflow_run_id', '')
for reason in reasons:
if reason.get('TYPE') == 'human_input_required':
form_content = reason.get('form_content', '')
actions = reason.get('actions', [])
node_title = reason.get('node_title', '')
_set_pending_form(
_session_key_from_query(query),
{
'workflow_run_id': workflow_run_id,
'form_id': reason.get('form_id'),
'form_token': reason.get('form_token'),
'node_id': reason.get('node_id'),
'node_title': node_title,
'form_content': form_content,
'inputs': reason.get('inputs', {}),
'actions': actions,
'expiration_time': reason.get('expiration_time'),
'user': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
},
)
query.variables['_dify_form_render'] = {
'form_content': form_content,
'actions': actions,
'node_title': node_title,
}
display_text = _format_human_input_text(node_title, form_content, actions)
human_input_yielded = True
yield provider_message.Message(
role='assistant',
content=display_text,
)
if chunk['event'] == 'node_started':
if chunk['data']['node_type'] == 'start' or chunk['data']['node_type'] == 'end':
continue
@@ -704,8 +399,6 @@ class DifyServiceAPIRunner(runner.RequestRunner):
yield msg
elif chunk['event'] == 'workflow_finished':
if human_input_yielded:
break
if chunk['data']['error']:
raise errors.DifyAPIError(chunk['data']['error'])
content, _ = self._process_thinking_content(chunk['data']['outputs']['summary'])
@@ -943,153 +636,11 @@ class DifyServiceAPIRunner(runner.RequestRunner):
query.session.using_conversation.uuid = chunk['conversation_id']
async def _submit_workflow_form(
self, form_action: dict
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
"""Submit human input to resume a paused Dify workflow."""
form_token = form_action['form_token']
workflow_run_id = form_action['workflow_run_id']
user = form_action['user']
action_id = form_action.get('action_id', '')
action_title = form_action.get('action_title', '') or action_id
node_title = form_action.get('node_title', '')
inputs = form_action.get('inputs', {})
messsage_idx = 0
is_final = False
think_start = False
think_end = False
workflow_contents = ''
repause_form_data: dict | None = None
remove_think = self.pipeline_config['output'].get('misc', {}).get('remove-think')
async for chunk in self.dify_client.workflow_submit(
form_token=form_token,
workflow_run_id=workflow_run_id,
inputs=inputs,
user=user,
action=action_id,
timeout=120,
):
self.ap.logger.debug('dify-workflow-submit-chunk: ' + str(chunk))
yield_this_iteration = False
if chunk['event'] == 'workflow_finished':
is_final = True
yield_this_iteration = True
if chunk['data'].get('error'):
raise errors.DifyAPIError(chunk['data']['error'])
if chunk['event'] == 'workflow_paused':
reasons = chunk['data'].get('reasons', [])
new_run_id = chunk['data'].get('workflow_run_id', workflow_run_id)
for reason in reasons:
if reason.get('TYPE') != 'human_input_required':
continue
form_content = reason.get('form_content', '')
actions = reason.get('actions', [])
# Use a distinct name — `node_title` (the just-resolved step)
# must keep its value so the resume notice on the previous
# card still shows which step the user acted on.
paused_node_title = reason.get('node_title', '')
raw_inputs = reason.get('inputs', {})
_set_pending_form(
user,
{
'workflow_run_id': new_run_id,
'form_id': reason.get('form_id'),
'form_token': reason.get('form_token'),
'node_id': reason.get('node_id'),
'node_title': paused_node_title,
'form_content': form_content,
'inputs': raw_inputs if isinstance(raw_inputs, dict) else {},
'actions': actions,
'expiration_time': reason.get('expiration_time'),
'user': user,
},
)
repause_form_data = {
'form_content': form_content,
'actions': actions,
'node_title': paused_node_title,
'workflow_run_id': new_run_id,
'form_token': reason.get('form_token', ''),
}
# Ensure the final chunk has non-empty content so
# ResponseWrapper (which skips empty-content chunks) lets it
# propagate to SendResponseBackStage. Use a zero-width space
# so neither Lark nor Telegram renders visible noise — the
# adapter substitutes its own card text from _form_data.
if not workflow_contents:
workflow_contents = ''
is_final = True
yield_this_iteration = True
break
if chunk['event'] == 'text_chunk':
messsage_idx += 1
if remove_think:
if '<think>' in chunk['data']['text'] and not think_start:
think_start = True
continue
if '</think>' in chunk['data']['text'] and not think_end:
import re
content = re.sub(r'^\n</think>', '', chunk['data']['text'])
workflow_contents += content
think_end = True
elif think_end:
workflow_contents += chunk['data']['text']
if think_start:
continue
else:
workflow_contents += chunk['data']['text']
if messsage_idx % 8 == 0:
yield_this_iteration = True
if yield_this_iteration:
msg = provider_message.MessageChunk(
role='assistant',
content=workflow_contents,
is_final=is_final,
)
msg._resume_from_form = True
if action_title:
msg._resume_action_title = action_title
if node_title:
msg._resume_node_title = node_title
if is_final and repause_form_data:
msg._form_data = repause_form_data
msg._open_new_card = True
yield msg
if is_final:
return
async def _workflow_messages_chunk(
self, query: pipeline_query.Query
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
"""调用工作流"""
# Check if this is a form action resume (button click or text match)
form_action_raw = query.variables.get('_dify_form_action')
session_key = _session_key_from_query(query)
if form_action_raw:
form_action = self._merge_pending_form_action(session_key, form_action_raw)
else:
form_action = self._match_pending_form_action(session_key, str(query.message_chain))
if form_action:
_clear_pending_form(session_key, form_action.get('form_token') or None)
# Resume paused workflow via submit endpoint
async for msg in self._submit_workflow_form(form_action):
yield msg
return
if not query.session.using_conversation.uuid:
query.session.using_conversation.uuid = str(uuid.uuid4())
@@ -1121,13 +672,6 @@ class DifyServiceAPIRunner(runner.RequestRunner):
think_start = False
think_end = False
workflow_contents = ''
workflow_run_id = ''
human_input_yielded = False
# Saved form data to attach to the final MessageChunk so the adapter
# can detect it when is_final=True and render buttons.
pending_form_data = None
display_text = ''
remove_think = self.pipeline_config['output'].get('misc', '').get('remove-think')
async for chunk in self.dify_client.workflow_run(
@@ -1138,61 +682,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
):
self.ap.logger.debug('dify-workflow-chunk: ' + str(chunk))
if chunk['event'] in ignored_events:
if chunk['event'] == 'workflow_started':
workflow_run_id = chunk['data'].get('workflow_run_id', '')
continue
if chunk['event'] == 'workflow_paused':
reasons = chunk['data'].get('reasons', [])
workflow_run_id = chunk['data'].get('workflow_run_id', workflow_run_id)
for reason in reasons:
if reason.get('TYPE') == 'human_input_required':
form_content = reason.get('form_content', '')
actions = reason.get('actions', [])
node_title = reason.get('node_title', '')
# Persist form state in module-level store keyed by session
raw_inputs = reason.get('inputs', {})
_set_pending_form(
_session_key_from_query(query),
{
'workflow_run_id': workflow_run_id,
'form_id': reason.get('form_id'),
'form_token': reason.get('form_token'),
'node_id': reason.get('node_id'),
'node_title': node_title,
'form_content': form_content,
'inputs': raw_inputs if isinstance(raw_inputs, dict) else {},
'actions': actions,
'expiration_time': reason.get('expiration_time'),
'user': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
},
)
# Pass form render metadata to downstream stages
query.variables['_dify_form_render'] = {
'form_content': form_content,
'actions': actions,
'node_title': node_title,
}
display_text = _format_human_input_text(node_title, form_content, actions)
workflow_contents += display_text + '\n'
# Save form data to attach to the final chunk later.
# We do NOT yield here — the form content will be sent
# as the final MessageChunk (with is_final=True and
# _form_data) so the adapter can update the card and
# add buttons in one pass.
pending_form_data = {
'form_content': form_content,
'actions': actions,
'node_title': node_title,
'workflow_run_id': workflow_run_id,
'form_token': reason.get('form_token', ''),
}
human_input_yielded = True
if chunk['event'] == 'workflow_finished':
is_final = True
if chunk['data']['error']:
@@ -1240,29 +730,11 @@ class DifyServiceAPIRunner(runner.RequestRunner):
yield msg
if messsage_idx % 8 == 0 or is_final:
final_content = workflow_contents if workflow_contents.strip() else ''
msg = provider_message.MessageChunk(
yield provider_message.MessageChunk(
role='assistant',
content=final_content,
content=workflow_contents,
is_final=is_final,
)
# Attach form data to the final chunk for the adapter
if is_final and pending_form_data:
msg._form_data = pending_form_data
pending_form_data = None
yield msg
# If the stream ended after workflow_paused without a
# workflow_finished event, yield a final chunk so the adapter
# can update the card and add buttons.
if human_input_yielded and not is_final:
msg = provider_message.MessageChunk(
role='assistant',
content=workflow_contents or display_text,
is_final=True,
)
msg._form_data = pending_form_data
yield msg
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
"""运行请求"""

View File

@@ -1,12 +1,8 @@
from __future__ import annotations
import posixpath
import re
from typing import TYPE_CHECKING, Any
from urllib.parse import unquote
if TYPE_CHECKING:
from langbot.pkg.core import app
from typing import Any
from langbot.pkg.core import app
class RAGRuntimeService:
@@ -113,17 +109,8 @@ class RAGRuntimeService:
regardless of the underlying storage provider.
"""
# Validate storage_path to prevent path traversal
decoded_path = unquote(storage_path).replace('\\', '/')
decoded_segments = decoded_path.split('/')
normalized = posixpath.normpath(decoded_path)
if (
not storage_path
or '\x00' in decoded_path
or normalized.startswith('/')
or '..' in decoded_segments
or '..' in normalized.split('/')
or re.match(r'^[A-Za-z]:/', normalized)
):
normalized = posixpath.normpath(storage_path)
if normalized.startswith('/') or '..' in normalized.split('/'):
raise ValueError('Invalid storage path')
content_bytes = await self.ap.storage_mgr.storage_provider.load(normalized)
return content_bytes if content_bytes else b''

View File

@@ -13,11 +13,12 @@ class TelemetryManager:
await telemetry.send({ ... })
"""
send_tasks: list[asyncio.Task] = []
def __init__(self, ap: core_app.Application):
self.ap = ap
self.telemetry_config = {}
self.send_tasks: list[asyncio.Task] = []
async def initialize(self):
self.telemetry_config = self.ap.instance_config.data.get('space', {})

View File

@@ -83,7 +83,7 @@ def get_func_schema(function: typing.Callable) -> dict:
parameters['properties'][param.name] = {
'type': param_type,
'description': args_doc.get(param.name, ''),
'description': args_doc[param.name],
}
# add schema for array

View File

@@ -145,8 +145,7 @@ def get_qq_image_downloadable_url(image_url: str) -> tuple[str, dict]:
"""获取QQ图片的下载链接"""
parsed = urlparse(image_url)
query = parse_qs(parsed.query)
scheme = parsed.scheme or 'http'
return f'{scheme}://{parsed.netloc}{parsed.path}', query
return f'http://{parsed.netloc}{parsed.path}', query
async def get_qq_image_bytes(image_url: str, query: dict = {}) -> tuple[bytes, str]:

View File

@@ -23,10 +23,7 @@ def run_pip(params: list):
pipmain(params)
def install_requirements(file, extra_params: list | None = None):
if extra_params is None:
extra_params = []
def install_requirements(file, extra_params: list = []):
pipmain(
[
'install',

View File

@@ -1,7 +1,5 @@
from __future__ import annotations
import ipaddress
import re
from urllib.parse import urlparse
@@ -46,40 +44,6 @@ LOCAL_PATTERNS = [
'172.31.',
]
HOST_LABEL_PATTERN = re.compile(r'^[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?$')
def _is_valid_hostname(host: str) -> bool:
if host == 'localhost':
return True
try:
ipaddress.ip_address(host)
return True
except ValueError:
pass
if not host or len(host) > 253 or any(char.isspace() for char in host):
return False
host = host.rstrip('.')
if not host:
return False
return all(HOST_LABEL_PATTERN.match(label) for label in host.split('.'))
def _is_local_host(host: str) -> bool:
if host == 'localhost':
return True
try:
ip_address = ipaddress.ip_address(host)
except ValueError:
return False
return ip_address.is_private or ip_address.is_loopback or ip_address.is_unspecified
def get_runner_category(runner_name: str, runner_url: str) -> str:
if not runner_url:
@@ -88,15 +52,12 @@ def get_runner_category(runner_name: str, runner_url: str) -> str:
try:
parsed_url = urlparse(runner_url)
host = parsed_url.hostname.lower() if parsed_url.hostname else ''
_ = parsed_url.port
except Exception:
return RunnerCategory.UNKNOWN
if not parsed_url.scheme or not host or not _is_valid_hostname(host):
return RunnerCategory.UNKNOWN
if _is_local_host(host):
return RunnerCategory.LOCAL
for pattern in LOCAL_PATTERNS:
if host.startswith(pattern):
return RunnerCategory.LOCAL
for domain in CLOUD_DOMAINS:
if host.endswith(domain):

File diff suppressed because one or more lines are too long

View File

@@ -2,48 +2,6 @@
This directory contains the test suite for LangBot, with a focus on comprehensive unit testing of pipeline stages.
## Quality Gate Layers
LangBot uses a layered quality gate system for developers and CI:
| Layer | Command | What it runs | When to use |
|-------|---------|--------------|-------------|
| **Quick** | `make test-quick` or `bash scripts/test-quick.sh` | Ruff lint + Unit tests + Smoke tests | Before every commit |
| **Fast Integration** | `make test-integration-fast` or `bash scripts/test-integration-fast.sh` | SQLite/API/Pipeline integration (no external services) | Before PR, weekly |
| **Coverage Gate** | `make test-coverage` or `bash scripts/test-coverage.sh` | All tests with coverage, threshold: 18% | Before merge, CI |
| **Full Local** | `make test-all-local` | Quick + Integration + Coverage | Before major changes |
**Note**: PostgreSQL migration tests and slow tests are NOT in local default gates. They run in separate CI workflows.
### Developer Workflow
```bash
# Daily: Quick self-test
bash scripts/test-quick.sh
# Before PR: Full local gate
make test-all-local
# Or run each layer separately:
bash scripts/test-quick.sh # ~2 min
bash scripts/test-integration-fast.sh # ~3 min
bash scripts/test-coverage.sh # ~8 min
```
### Coverage Baseline
Current coverage threshold: **18%**
Actual coverage: **30%**
This is a conservative baseline to prevent coverage regression. It does NOT represent the final quality target. Key modules have higher coverage:
- `pipeline.preproc.preproc`: 53%
- `pipeline.process.process`: 96%
- `pipeline.respback.respback`: 88%
- `telemetry.telemetry`: 87%
- `provider.session.sessionmgr`: 100%
- `provider.tools.toolmgr`: 83%
- `storage.providers.s3storage`: 80%
## Important Note
Due to circular import dependencies in the pipeline module structure, the test files use **lazy imports** via `importlib.import_module()` instead of direct imports. This ensures tests can run without triggering circular import errors.
@@ -52,81 +10,19 @@ Due to circular import dependencies in the pipeline module structure, the test f
```
tests/
├── __init__.py
├── factories/ # Shared test factories
│ ├── __init__.py # Factory exports
│ ├── app.py # FakeApp factory
│ ├── message.py # Message/query factories
│ ├── provider.py # FakeProvider factory
── platform.py # FakePlatform factory
├── integration/ # Integration tests (real resources)
│ ├── __init__.py
── api/ # HTTP API tests
├── __init__.py
│ │ └── test_smoke.py # API smoke tests
│ ├── pipeline/ # Pipeline stage-chain tests
│ │ ├── __init__.py
│ │ └── test_full_flow.py # Full flow integration
│ └── persistence/ # Database/persistence tests
│ ├── __init__.py
│ └── test_migrations.py # Alembic migration tests
├── smoke/ # Smoke tests (quick validation)
│ └── test_fake_message_flow.py
├── unit_tests/ # Unit tests
│ ├── box/ # Box module tests
│ ├── config/ # Configuration tests
│ ├── pipeline/ # Pipeline stage tests
│ │ └── conftest.py # Shared fixtures and test infrastructure
│ ├── platform/ # Platform adapter tests
│ ├── plugin/ # Plugin system tests
│ │ └── test_handler_actions.py # Action handler tests
│ ├── provider/ # Provider tests
│ │ ├── test_session_manager.py # SessionManager tests
│ │ └── test_tool_manager.py # ToolManager tests
│ ├── rag/ # RAG tests
│ │ └── test_file_storage.py # File/ZIP storage tests
│ ├── storage/ # Storage tests
│ │ └── test_s3storage.py # S3StorageProvider tests
│ ├── vector/ # Vector tests
│ │ └── test_vdb_filter_conversion.py # VDB filter tests
│ └── telemetry/ # Telemetry tests (rewritten)
├── utils/ # Test utilities
│ ├── __init__.py
│ └── import_isolation.py # sys.modules isolation for circular imports
└── README.md # This file
├── pipeline/ # Pipeline stage tests
│ ├── conftest.py # Shared fixtures and test infrastructure
│ ├── test_simple.py # Basic infrastructure tests (always pass)
│ ├── test_bansess.py # BanSessionCheckStage tests
│ ├── test_ratelimit.py # RateLimit stage tests
│ ├── test_preproc.py # PreProcessor stage tests
── test_respback.py # SendResponseBackStage tests
│ ├── test_resprule.py # GroupRespondRuleCheckStage tests
│ ├── test_pipelinemgr.py # PipelineManager tests
── test_stages_integration.py # Integration tests
└── README.md # This file
```
## Test Factories
The `tests/factories/` package provides reusable test factories:
```python
from tests.factories import (
FakeApp, # Mock application
FakeProvider, # Fake LLM provider
FakePlatform, # Fake platform adapter
text_query, # Create text query
group_text_query, # Create group query
command_query, # Create command query
)
# Create fake app
app = FakeApp()
# Create query with text
query = text_query("hello world")
# Create fake provider that returns specific response
provider = FakeProvider().returns("test response")
# Create fake platform for outbound capture
platform = FakePlatform()
await platform.reply_message(query.message_event, reply_chain)
outbound = platform.get_outbound_messages()
```
See `tests/factories/__init__.py` for all available factories.
## Test Architecture
### Fixtures (`conftest.py`)
@@ -147,28 +43,7 @@ The test suite uses a centralized fixture system that provides:
## Running Tests
### Quick self-test for developers
For local branch validation without real provider keys:
```bash
make test-quick
```
or
```bash
bash scripts/test-quick.sh
```
This runs:
1. Ruff lint check
2. Unit tests
3. Smoke tests
Suitable for quick validation before committing.
### Using the test runner script (recommended for full coverage)
### Using the test runner script (recommended)
```bash
bash run_tests.sh
```
@@ -181,135 +56,38 @@ This script automatically:
### Manual test execution
#### Run all unit tests
#### Run all tests
```bash
uv run pytest tests/unit_tests/ --cov=langbot --cov-report=xml --cov-report=term
pytest tests/pipeline/
```
#### Run specific test module
#### Run only simple tests (no imports, always pass)
```bash
uv run pytest tests/unit_tests/pipeline/ -v
pytest tests/pipeline/test_simple.py -v
```
#### Run specific test file
```bash
uv run pytest tests/unit_tests/pipeline/test_bansess.py -v
pytest tests/pipeline/test_bansess.py -v
```
#### Run with coverage
```bash
uv run pytest tests/unit_tests/pipeline/ --cov=langbot --cov-report=html
pytest tests/pipeline/ --cov=pkg/pipeline --cov-report=html
```
#### Run specific test
```bash
uv run pytest tests/unit_tests/pipeline/test_bansess.py::test_bansess_whitelist_allow -v
pytest tests/pipeline/test_bansess.py::test_bansess_whitelist_allow -v
```
### Using markers
```bash
# Run only unit tests
uv run pytest tests/unit_tests/ -m unit
# Run only integration tests
uv run pytest tests/integration/ -m integration
# Run integration tests excluding slow ones
uv run pytest tests/integration/ -m "not slow" -q
# Skip slow tests
uv run pytest tests/unit_tests/ -m "not slow"
```
### Running integration tests
Integration tests validate real system behavior with actual database/network resources.
```bash
# Run all integration tests (excluding slow ones)
uv run pytest tests/integration/ -m "not slow" -q
# Run SQLite migration integration tests
uv run pytest tests/integration/persistence/test_migrations.py -q --tb=short
# Run API smoke integration tests
uv run pytest tests/integration/api/test_smoke.py -q
# Run pipeline full-flow integration tests
uv run pytest tests/integration/pipeline/test_full_flow.py -q
# Run with verbose output
uv run pytest tests/integration/ -v
```
Note: Integration tests use:
- Temporary databases (tmp_path) for persistence tests
- Fake app/services for API tests (no real provider/platform)
- Fake runner/provider for pipeline tests (no real LLM API)
- Do not require external services
### Running migration tests locally
SQLite migration tests can be run locally without any external dependencies:
```bash
# SQLite migration tests (uses tmp_path, no external DB needed)
uv run pytest tests/integration/persistence/test_migrations.py -q --tb=short
```
PostgreSQL migration tests require an external PostgreSQL database:
```bash
# PostgreSQL migration tests (requires PostgreSQL service)
# Tests are marked as slow and skipped if TEST_POSTGRES_URL is not set
TEST_POSTGRES_URL=postgresql+asyncpg://user:pass@localhost:5432/test_db \
uv run pytest tests/integration/persistence/test_migrations_postgres.py -q --tb=short
# Or skip by default (no PostgreSQL available)
uv run pytest tests/integration/persistence/test_migrations_postgres.py -q --tb=short
# Output: SKIPPED (TEST_POSTGRES_URL not set)
```
Note: PostgreSQL tests are **not** included in fast integration gate because they:
- Require external PostgreSQL service
- Are marked with `@pytest.mark.slow`
- Need `TEST_POSTGRES_URL` environment variable
CI workflow `.github/workflows/test-migrations.yml` runs:
- SQLite tests in `test-migrations-sqlite` job (fast, no external services)
- PostgreSQL tests in `test-migrations-postgres` job (uses PostgreSQL service container)
### Running pipeline integration tests locally
Pipeline full-flow integration tests validate real stage interactions:
```bash
# Run pipeline integration tests (uses fake runner, no real LLM API)
uv run pytest tests/integration/pipeline/test_full_flow.py -q --tb=short
# Run with coverage for pipeline modules
uv run pytest tests/integration/pipeline \
--cov=langbot.pkg.pipeline.preproc.preproc \
--cov=langbot.pkg.pipeline.process.process \
--cov=langbot.pkg.pipeline.respback.respback \
--cov-report=term -q
```
These tests:
- Use `FakeRunner` class to simulate LLM responses without real API calls
- Import real `PreProcessor`, `MessageProcessor`, `SendResponseBackStage` stages
- Validate stage chain: PreProcessor → Processor → SendResponseBackStage
- Test prevent_default, exception handling, and full message flow
- Do not require real LLM provider keys
### Known Issues
Some tests may encounter circular import errors. This is a known issue with the current module structure. The test infrastructure is designed to work around this using lazy imports, but if you encounter issues:
1. Make sure you're running from the project root directory
2. Ensure dependencies are installed: `uv sync --dev`
3. Try running a simple test first to verify the test infrastructure works
2. Ensure the virtual environment is activated
3. Try running `test_simple.py` first to verify the test infrastructure works
## CI/CD Integration
@@ -319,7 +97,7 @@ Tests are automatically run on:
- Push to PR branch
- Push to master/develop branches
The workflow runs tests on Python 3.11, 3.12, and 3.13 to ensure compatibility.
The workflow runs tests on Python 3.10, 3.11, and 3.12 to ensure compatibility.
## Adding New Tests
@@ -333,8 +111,8 @@ Create a new test file `test_<stage_name>.py`:
"""
import pytest
from langbot.pkg.pipeline.<module>.<stage> import <StageClass>
from langbot.pkg.pipeline import entities as pipeline_entities
from pkg.pipeline.<module>.<stage> import <StageClass>
from pkg.pipeline import entities as pipeline_entities
@pytest.mark.asyncio
@@ -350,7 +128,7 @@ async def test_stage_basic_flow(mock_app, sample_query):
### 2. For additional fixtures
Add new fixtures to the appropriate `conftest.py`:
Add new fixtures to `conftest.py`:
```python
@pytest.fixture
@@ -364,7 +142,7 @@ def my_custom_fixture():
Use the helper functions in `conftest.py`:
```python
from tests.unit_tests.pipeline.conftest import create_stage_result, assert_result_continue
from tests.pipeline.conftest import create_stage_result, assert_result_continue
result = create_stage_result(
result_type=pipeline_entities.ResultType.CONTINUE,
@@ -388,7 +166,7 @@ assert_result_continue(result)
### Import errors
Make sure you've installed the package in development mode:
```bash
uv sync --dev
uv pip install -e .
```
### Async test failures
@@ -399,11 +177,7 @@ Check that you're mocking at the right level and using `AsyncMock` for async fun
## Future Enhancements
- [x] Add integration tests for database migrations (SQLite)
- [x] Add PostgreSQL migration integration tests (G-003)
- [x] Add integration tests for full pipeline execution
- [x] Add API smoke integration tests
- [ ] Add E2E tests
- [ ] Add integration tests for full pipeline execution
- [ ] Add performance benchmarks
- [ ] Add mutation testing for better coverage quality
- [ ] Add property-based testing with Hypothesis
- [ ] Add property-based testing with Hypothesis

View File

@@ -1,102 +0,0 @@
"""E2E test fixtures.
Provides fixtures for starting real LangBot process with minimal configuration.
"""
from __future__ import annotations
import pytest
import tempfile
import shutil
import logging
from pathlib import Path
from tests.e2e.utils.config_factory import create_minimal_config, create_test_directories
from tests.e2e.utils.process_manager import LangBotProcess, find_project_root
logger = logging.getLogger(__name__)
pytestmark = pytest.mark.e2e
@pytest.fixture(scope='session')
def e2e_port():
"""Port for E2E testing (non-default to avoid conflicts)."""
return 15300
@pytest.fixture(scope='session')
def e2e_tmpdir():
"""Create temporary directory for E2E testing."""
tmpdir = Path(tempfile.mkdtemp(prefix='langbot_e2e_'))
logger.info(f'E2E tmpdir: {tmpdir}')
yield tmpdir
# Cleanup
logger.info(f'Cleaning up E2E tmpdir: {tmpdir}')
shutil.rmtree(tmpdir, ignore_errors=True)
@pytest.fixture(scope='session')
def e2e_config_path(e2e_tmpdir, e2e_port):
"""Create minimal config.yaml for E2E testing."""
config_path = create_minimal_config(e2e_tmpdir, port=e2e_port)
create_test_directories(e2e_tmpdir)
logger.info(f'E2E config: {config_path}')
return config_path
@pytest.fixture(scope='session')
def langbot_process(e2e_config_path, e2e_port, e2e_tmpdir):
"""Start real LangBot process for E2E testing.
This fixture starts LangBot once per session and reuses it for all tests.
Coverage data is collected from the subprocess.
"""
project_root = find_project_root()
collect_coverage = True
proc = LangBotProcess(
project_root=project_root,
work_dir=e2e_tmpdir, # Run in tmpdir where data/config.yaml exists
port=e2e_port,
timeout=60, # Longer timeout for first startup
collect_coverage=collect_coverage,
)
success = proc.start()
if not success:
stdout, stderr = proc.get_logs()
pytest.fail(f'LangBot failed to start:\nstdout: {stdout}\nstderr: {stderr}')
yield proc
# Cleanup
proc.stop()
# Combine coverage data if collected
if collect_coverage and proc.get_coverage_file():
coverage_file = proc.get_coverage_file()
if coverage_file.exists():
# Copy coverage data to project root for combining
target = project_root / '.coverage.e2e'
shutil.copy(coverage_file, target)
logger.info(f'Coverage data saved to: {target}')
@pytest.fixture
def e2e_client(e2e_port, langbot_process):
"""HTTP client for E2E testing."""
import httpx
base_url = f'http://127.0.0.1:{e2e_port}'
with httpx.Client(base_url=base_url, timeout=10.0) as client:
yield client
@pytest.fixture(scope='session')
def e2e_db_path(e2e_tmpdir):
"""Path to SQLite database file."""
return e2e_tmpdir / 'data' / 'langbot.db'

View File

@@ -1,142 +0,0 @@
"""E2E tests for LangBot startup flow.
Tests the complete startup process including:
- boot.py startup orchestration
- stages/ (build_app, load_config, migrate, etc.)
- database initialization
- API availability
Run: uv run pytest tests/e2e/test_startup.py -v -m e2e
"""
from __future__ import annotations
import pytest
pytestmark = pytest.mark.e2e
class TestStartupFlow:
"""Tests for LangBot startup process."""
def test_process_is_running(self, langbot_process):
"""Verify LangBot process is running."""
assert langbot_process.is_running()
def test_health_check(self, langbot_process, e2e_port):
"""Verify LangBot API is responding."""
assert langbot_process.health_check()
def test_system_info_endpoint(self, e2e_client):
"""Test /api/v1/system/info endpoint."""
response = e2e_client.get('/api/v1/system/info')
assert response.status_code == 200
data = response.json()
assert data['code'] == 0
assert 'data' in data
# System info should contain version info
assert 'version' in data['data'] or 'edition' in data['data']
def test_database_initialized(self, e2e_db_path):
"""Verify SQLite database was created and initialized."""
assert e2e_db_path.exists()
# Database should have some tables after migration
import sqlite3
conn = sqlite3.connect(str(e2e_db_path))
cursor = conn.cursor()
# Check that core tables exist
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = [row[0] for row in cursor.fetchall()]
# Core tables should be created by Alembic migrations
# Note: table names may differ (legacy_pipelines instead of pipelines)
expected_tables = ['legacy_pipelines', 'bots', 'model_providers', 'llm_models']
for table in expected_tables:
assert table in tables, f'Table {table} should exist. Available: {tables}'
conn.close()
def test_chroma_directory_created(self, e2e_tmpdir):
"""Verify Chroma vector database directory was created."""
chroma_path = e2e_tmpdir / 'chroma'
# Created by the E2E config factory before startup.
assert chroma_path.exists()
def test_pipelines_endpoint(self, e2e_client):
"""Test /api/v1/pipelines endpoint (requires auth)."""
# Without auth, should return 401
response = e2e_client.get('/api/v1/pipelines')
assert response.status_code == 401
def test_auth_endpoint(self, e2e_client, e2e_tmpdir):
"""Test auth endpoint."""
# First startup may allow initial setup
response = e2e_client.post('/api/v1/user/auth', json={
'username': 'admin',
'password': 'admin',
})
# Response could be:
# - 200 if auth succeeds
# - 400 if credentials wrong
# - 401 if user not initialized
assert response.status_code in [200, 400, 401]
class TestStartupStages:
"""Tests that verify individual startup stages worked correctly."""
def test_config_loaded(self, e2e_client):
"""Verify config was loaded correctly by checking API port."""
# If API responds on e2e_port, config was loaded
assert e2e_client.get('/api/v1/system/info').status_code == 200
def test_migrations_applied(self, e2e_db_path):
"""Verify database migrations were applied."""
import sqlite3
conn = sqlite3.connect(str(e2e_db_path))
cursor = conn.cursor()
# Check alembic_version table exists and has version
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='alembic_version';")
result = cursor.fetchone()
assert result is not None, 'alembic_version table should exist'
cursor.execute('SELECT version_num FROM alembic_version;')
version = cursor.fetchone()
assert version is not None, 'Migration version should be set'
conn.close()
def test_http_controller_initialized(self, e2e_client):
"""Verify HTTP controller was initialized."""
# Multiple endpoints should be available
endpoints = [
'/api/v1/system/info',
'/api/v1/pipelines',
'/api/v1/provider/providers',
'/api/v1/platform/bots',
]
for endpoint in endpoints:
response = e2e_client.get(endpoint)
# Should get a real route response, even if auth is required.
assert response.status_code in [200, 401, 403], f'{endpoint} should be registered'
class TestMinimalStartupNoLLM:
"""Tests verifying LangBot can start without LLM providers."""
def test_api_available_without_llm(self, e2e_client):
"""API should be available even without LLM providers configured."""
response = e2e_client.get('/api/v1/system/info')
assert response.status_code == 200
def test_pipeline_metadata_available(self, e2e_client):
"""Pipeline metadata endpoint should work without LLM."""
# Requires auth, but endpoint should exist
response = e2e_client.get('/api/v1/pipelines/_/metadata')
assert response.status_code in [200, 401] # Not 404 or 500

View File

@@ -1,179 +0,0 @@
"""E2E test configuration factory.
Generates minimal config.yaml for testing LangBot startup without external dependencies.
"""
from __future__ import annotations
import yaml
from pathlib import Path
def create_minimal_config(tmpdir: Path, port: int = 15300) -> Path:
"""Create minimal config.yaml for E2E testing.
Uses embedded databases (SQLite, Chroma) to avoid external dependencies.
Config is created at tmpdir/data/config.yaml (LangBot expects this location).
"""
# LangBot expects config at data/config.yaml
data_dir = tmpdir / 'data'
data_dir.mkdir(parents=True, exist_ok=True)
config = {
'admins': [],
'api': {
'port': port,
'webhook_prefix': f'http://127.0.0.1:{port}',
'extra_webhook_prefix': '',
},
'command': {
'enable': True,
'prefix': ['!', '!'],
'privilege': {},
},
'concurrency': {
'pipeline': 20,
'session': 1,
},
'proxy': {
'http': '',
'https': '',
},
'system': {
'instance_id': '',
'edition': 'community',
'recovery_key': '',
'allow_modify_login_info': True,
'disabled_adapters': [],
'limitation': {
'max_bots': -1,
'max_pipelines': -1,
'max_extensions': -1,
},
'task_retention': {
'completed_limit': 200,
},
'jwt': {
'expire': 604800,
'secret': 'e2e-test-secret-key',
},
},
'database': {
'use': 'sqlite',
'sqlite': {
'path': str(tmpdir / 'data' / 'langbot.db'),
},
'postgresql': {
'host': '127.0.0.1',
'port': 5432,
'user': 'postgres',
'password': 'postgres',
'database': 'postgres',
},
},
'vdb': {
'use': 'chroma', # Chroma is embedded, no external dependency
'chroma': {
'path': str(tmpdir / 'chroma'),
},
'qdrant': {
'url': '',
'host': 'localhost',
'port': 6333,
'api_key': '',
},
'seekdb': {
'mode': 'embedded',
'path': str(tmpdir / 'seekdb'),
'database': 'langbot',
'host': 'localhost',
'port': 2881,
'user': 'root',
'password': '',
'tenant': '',
},
'milvus': {
'uri': 'http://127.0.0.1:19530',
'token': '',
'db_name': '',
},
'pgvector': {
'host': '127.0.0.1',
'port': 5433,
'database': 'langbot',
'user': 'postgres',
'password': 'postgres',
},
},
'storage': {
'use': 'local',
'cleanup': {
'enabled': False, # Disable cleanup for tests
'check_interval_hours': 1,
'uploaded_file_retention_days': 7,
'log_retention_days': 3,
},
'local': {
'path': str(tmpdir / 'storage'),
},
's3': {
'endpoint_url': '',
'access_key_id': '',
'secret_access_key': '',
'region': 'us-east-1',
'bucket': 'langbot-storage',
},
},
'plugin': {
'enable': False, # Disable plugin system for minimal startup
'runtime_ws_url': '',
'enable_marketplace': False,
'display_plugin_debug_url': '',
'binary_storage': {
'max_value_bytes': 10485760,
},
},
'monitoring': {
'auto_cleanup': {
'enabled': False, # Disable cleanup for tests
'retention_days': 30,
'check_interval_hours': 1,
'delete_batch_size': 1000,
},
},
'space': {
'url': 'https://space.langbot.app',
'models_gateway_api_url': 'https://api.langbot.cloud/v1',
'oauth_authorize_url': 'https://space.langbot.app/auth/authorize',
'disable_models_service': True, # Disable external services
'disable_telemetry': True, # Disable telemetry for tests
},
'provider': {}, # Empty providers - minimal startup
'llm': [], # Empty LLM models
}
# Ensure data directory exists (LangBot expects config at data/config.yaml)
data_dir = tmpdir / 'data'
data_dir.mkdir(parents=True, exist_ok=True)
# Write config to data/config.yaml (LangBot's expected location)
config_path = data_dir / 'config.yaml'
with open(config_path, 'w', encoding='utf-8') as f:
yaml.dump(config, f, default_flow_style=False)
return config_path
def create_test_directories(tmpdir: Path) -> dict[str, Path]:
"""Create necessary directories for LangBot testing."""
directories = {
'data': tmpdir / 'data',
'logs': tmpdir / 'logs',
'storage': tmpdir / 'storage',
'chroma': tmpdir / 'chroma',
}
for path in directories.values():
path.mkdir(parents=True, exist_ok=True)
return directories

View File

@@ -1,204 +0,0 @@
"""E2E test process manager.
Manages LangBot subprocess lifecycle for E2E testing.
"""
from __future__ import annotations
import subprocess
import time
import signal
import os
from pathlib import Path
from typing import Optional
import logging
logger = logging.getLogger(__name__)
class LangBotProcess:
"""Manages a LangBot subprocess for E2E testing."""
def __init__(
self,
project_root: Path,
work_dir: Path,
port: int = 15300,
timeout: int = 30,
collect_coverage: bool = True,
):
self.project_root = project_root
self.work_dir = work_dir # Directory containing data/config.yaml
self.port = port
self.timeout = timeout
self.collect_coverage = collect_coverage
self.process: Optional[subprocess.Popen] = None
self._stdout_data: bytes = b''
self._stderr_data: bytes = b''
self._coverage_file: Optional[Path] = None
def start(self) -> bool:
"""Start LangBot process and wait for it to be ready."""
import httpx
# Prepare environment
env = os.environ.copy()
env['PYTHONPATH'] = str(self.project_root / 'src')
# Set API port via environment variable
env['API__PORT'] = str(self.port)
env['API__WEBHOOK_PREFIX'] = f'http://127.0.0.1:{self.port}'
# Disable telemetry
env['SPACE__DISABLE_TELEMETRY'] = 'true'
env['SPACE__DISABLE_MODELS_SERVICE'] = 'true'
# Build command
if self.collect_coverage:
# Use coverage.py to collect coverage data
# Set COVERAGE_PROCESS_START to enable coverage in subprocess
self._coverage_file = self.work_dir / '.coverage.e2e'
env['COVERAGE_PROCESS_START'] = str(self.project_root / '.coveragerc')
env['COVERAGE_FILE'] = str(self._coverage_file)
# Create .coveragerc for subprocess
coveragerc_content = """
[run]
source = langbot.pkg
parallel = True
data_file = {}
omit =
*/tests/*
*/test_*.py
[report]
precision = 2
""".format(str(self._coverage_file))
coveragerc_path = self.work_dir / '.coveragerc'
with open(coveragerc_path, 'w') as f:
f.write(coveragerc_content)
cmd = [
'coverage', 'run',
'--rcfile=' + str(coveragerc_path),
'-m', 'langbot',
]
else:
cmd = ['uv', 'run', 'python', '-m', 'langbot']
logger.info(f'Starting LangBot in: {self.work_dir}')
logger.info(f'Command: {cmd}')
# Start process (run in work_dir so it finds data/config.yaml)
self.process = subprocess.Popen(
cmd,
cwd=self.work_dir,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
preexec_fn=os.setsid if os.name != 'nt' else None,
)
# Wait for startup
start_time = time.time()
while time.time() - start_time < self.timeout:
# Check if process died
if self.process.poll() is not None:
self._stdout_data, self._stderr_data = self.process.communicate()
logger.error(f'LangBot process died: {self._stderr_data.decode()}')
return False
# Try to connect
try:
r = httpx.get(
f'http://127.0.0.1:{self.port}/api/v1/system/info',
timeout=2.0,
)
if r.status_code == 200:
logger.info(f'LangBot started successfully on port {self.port}')
return True
except (httpx.ConnectError, httpx.TimeoutException):
pass
time.sleep(1)
# Timeout
logger.error(f'LangBot startup timeout after {self.timeout}s')
self.stop()
return False
def stop(self) -> None:
"""Stop LangBot process gracefully."""
if self.process is None:
return
logger.info('Stopping LangBot process...')
# Try graceful shutdown first
if os.name != 'nt':
# Send SIGTERM to process group
os.killpg(os.getpgid(self.process.pid), signal.SIGTERM)
else:
self.process.terminate()
# Wait for graceful shutdown
try:
self.process.wait(timeout=5)
logger.info('LangBot stopped gracefully')
except subprocess.TimeoutExpired:
# Force kill
logger.warning('Force killing LangBot process')
if os.name != 'nt':
os.killpg(os.getpgid(self.process.pid), signal.SIGKILL)
else:
self.process.kill()
self.process.wait()
# Collect output for debugging
if self.process.stdout or self.process.stderr:
self._stdout_data, self._stderr_data = self.process.communicate()
self.process = None
def is_running(self) -> bool:
"""Check if process is still running."""
return self.process is not None and self.process.poll() is None
def get_logs(self) -> tuple[str, str]:
"""Get stdout and stderr logs."""
stdout = self._stdout_data.decode('utf-8', errors='replace')
stderr = self._stderr_data.decode('utf-8', errors='replace')
return stdout, stderr
def get_coverage_file(self) -> Optional[Path]:
"""Get coverage data file path."""
return self._coverage_file
def health_check(self) -> bool:
"""Check if LangBot API is responding."""
import httpx
if not self.is_running():
return False
try:
r = httpx.get(
f'http://127.0.0.1:{self.port}/api/v1/system/info',
timeout=5.0,
)
return r.status_code == 200
except Exception:
return False
def find_project_root() -> Path:
"""Find LangBot project root directory."""
current = Path(__file__).resolve()
# Walk up until we find src/langbot
for parent in current.parents:
if (parent / 'src' / 'langbot').exists():
return parent
# Fallback to LangBot-test-build directory
return Path('/home/glwuy/langbot-app/LangBot-test-build')

View File

@@ -1,102 +0,0 @@
"""
Shared test factories for LangBot tests.
Provides reusable factories for:
- Fake application (app.py)
- Messages and queries (message.py)
- Fake providers (provider.py)
- Fake platforms (platform.py)
Usage:
from tests.factories import FakeApp, text_query, FakeProvider
app = FakeApp()
query = text_query("hello")
provider = FakeProvider.returns("response")
"""
from tests.factories.app import FakeApp, fake_app
from tests.factories.message import (
text_chain,
group_text_chain,
mention_chain,
image_chain,
text_query,
group_text_query,
private_text_query,
command_query,
mention_query,
empty_query,
image_query,
file_query,
unsupported_query,
voice_query,
at_all_query,
query_with_session,
query_with_config,
friend_message_event,
group_message_event,
mock_adapter,
)
from tests.factories.provider import (
FakeProvider,
fake_provider,
fake_provider_pong,
fake_provider_timeout,
fake_provider_auth_error,
fake_provider_rate_limit,
fake_provider_malformed,
fake_model,
)
from tests.factories.platform import (
FakePlatform,
fake_platform,
fake_platform_with_streaming,
fake_platform_with_failure,
mock_platform_adapter,
)
__all__ = [
# App
"FakeApp",
"fake_app",
# Message chains
"text_chain",
"group_text_chain",
"mention_chain",
"image_chain",
# Message events
"friend_message_event",
"group_message_event",
# Mock adapters
"mock_adapter",
# Queries
"text_query",
"group_text_query",
"private_text_query",
"command_query",
"mention_query",
"empty_query",
"image_query",
"file_query",
"unsupported_query",
"voice_query",
"at_all_query",
"query_with_session",
"query_with_config",
# Provider
"FakeProvider",
"fake_provider",
"fake_provider_pong",
"fake_provider_timeout",
"fake_provider_auth_error",
"fake_provider_rate_limit",
"fake_provider_malformed",
"fake_model",
# Platform
"FakePlatform",
"fake_platform",
"fake_platform_with_streaming",
"fake_platform_with_failure",
"mock_platform_adapter",
]

View File

@@ -1,137 +0,0 @@
"""
Fake application factory for tests.
Provides a mock Application object with all dependencies needed by pipeline stages.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, Mock
class FakeApp:
"""Mock Application object providing all basic dependencies needed by stages."""
def __init__(
self,
*,
command_prefix: list[str] = ["/", "!"],
command_enable: bool = True,
pipeline_concurrency: int = 10,
admins: list[str] | None = None,
**extra_attrs,
):
self.logger = self._create_mock_logger()
self.sess_mgr = self._create_mock_session_manager()
self.model_mgr = self._create_mock_model_manager()
self.tool_mgr = self._create_mock_tool_manager()
self.plugin_connector = self._create_mock_plugin_connector()
self.persistence_mgr = self._create_mock_persistence_manager()
self.query_pool = self._create_mock_query_pool()
self.instance_config = self._create_mock_instance_config(
command_prefix=command_prefix,
command_enable=command_enable,
pipeline_concurrency=pipeline_concurrency,
admins=admins or [],
)
self.task_mgr = self._create_mock_task_manager()
# Handler-specific optional attributes
self.telemetry = self._create_mock_telemetry()
self.survey = None
self.cmd_mgr = self._create_mock_cmd_mgr()
# Apply any extra attributes for specific test scenarios
for name, value in extra_attrs.items():
setattr(self, name, value)
# Captured outbound messages (for assertions)
self._outbound_messages: list = []
def _create_mock_logger(self):
logger = Mock()
logger.debug = Mock()
logger.info = Mock()
logger.error = Mock()
logger.warning = Mock()
return logger
def _create_mock_session_manager(self):
sess_mgr = AsyncMock()
sess_mgr.get_session = AsyncMock()
sess_mgr.get_conversation = AsyncMock()
return sess_mgr
def _create_mock_model_manager(self):
model_mgr = AsyncMock()
model_mgr.get_model_by_uuid = AsyncMock()
return model_mgr
def _create_mock_tool_manager(self):
tool_mgr = AsyncMock()
tool_mgr.get_all_tools = AsyncMock(return_value=[])
return tool_mgr
def _create_mock_plugin_connector(self):
plugin_connector = AsyncMock()
plugin_connector.emit_event = AsyncMock()
return plugin_connector
def _create_mock_persistence_manager(self):
persistence_mgr = AsyncMock()
persistence_mgr.execute_async = AsyncMock()
return persistence_mgr
def _create_mock_query_pool(self):
query_pool = Mock()
query_pool.cached_queries = {}
query_pool.queries = []
query_pool.condition = AsyncMock()
return query_pool
def _create_mock_instance_config(
self,
command_prefix: list[str],
command_enable: bool,
pipeline_concurrency: int,
admins: list[str],
):
instance_config = Mock()
instance_config.data = {
"command": {"prefix": command_prefix, "enable": command_enable},
"concurrency": {"pipeline": pipeline_concurrency},
"admins": admins,
}
return instance_config
def _create_mock_task_manager(self):
task_mgr = Mock()
task_mgr.create_task = Mock()
return task_mgr
def _create_mock_telemetry(self):
telemetry = AsyncMock()
telemetry.start_send_task = AsyncMock()
return telemetry
def _create_mock_cmd_mgr(self):
cmd_mgr = AsyncMock()
cmd_mgr.execute = AsyncMock()
return cmd_mgr
def capture_message(self, message):
"""Capture an outbound message for test assertions."""
self._outbound_messages.append(message)
def get_outbound_messages(self) -> list:
"""Get all captured outbound messages."""
return self._outbound_messages.copy()
def clear_outbound_messages(self):
"""Clear captured outbound messages."""
self._outbound_messages.clear()
def fake_app(**kwargs) -> FakeApp:
"""Create a FakeApp instance with optional overrides."""
return FakeApp(**kwargs)

View File

@@ -1,472 +0,0 @@
"""
Message and query factories for tests.
Provides reusable factories for creating message chains, events, and query objects.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, Mock
import typing
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
import langbot_plugin.api.entities.builtin.provider.session as provider_session
# Counter for generating unique IDs
_query_counter = 0
def _next_query_id() -> int:
"""Generate a unique query ID."""
global _query_counter
_query_counter += 1
return _query_counter
# ============== Message Chain Factories ==============
def text_chain(text: str = "hello") -> platform_message.MessageChain:
"""Create a simple text message chain."""
return platform_message.MessageChain([
platform_message.Plain(text=text),
])
def group_text_chain(text: str = "hello") -> platform_message.MessageChain:
"""Create a group text message chain (same as text_chain, context provided by event)."""
return text_chain(text)
def mention_chain(
text: str = "hello",
target: typing.Union[int, str] = 12345,
) -> platform_message.MessageChain:
"""Create a message chain with @mention."""
return platform_message.MessageChain([
platform_message.At(target=target),
platform_message.Plain(text=f" {text}"),
])
def image_chain(
text: str = "",
url: str = "https://example.com/image.png",
) -> platform_message.MessageChain:
"""Create a message chain with an image."""
components = []
if text:
components.append(platform_message.Plain(text=text))
components.append(platform_message.Image(url=url))
return platform_message.MessageChain(components)
def command_chain(
command: str = "help",
prefix: str = "/",
) -> platform_message.MessageChain:
"""Create a command message chain."""
return platform_message.MessageChain([
platform_message.Plain(text=f"{prefix}{command}"),
])
# ============== Message Event Factories ==============
def friend_message_event(
message_chain: platform_message.MessageChain,
sender_id: typing.Union[int, str] = 12345,
nickname: str = "TestUser",
) -> platform_events.FriendMessage:
"""Create a friend (private) message event."""
sender = platform_entities.Friend(
id=sender_id,
nickname=nickname,
remark=None,
)
return platform_events.FriendMessage(
type="FriendMessage",
sender=sender,
message_chain=message_chain,
time=1609459200,
)
def group_message_event(
message_chain: platform_message.MessageChain,
sender_id: typing.Union[int, str] = 12345,
sender_name: str = "TestUser",
group_id: typing.Union[int, str] = 99999,
group_name: str = "TestGroup",
) -> platform_events.GroupMessage:
"""Create a group message event."""
group = platform_entities.Group(
id=group_id,
name=group_name,
permission=platform_entities.Permission.Member,
)
sender = platform_entities.GroupMember(
id=sender_id,
member_name=sender_name,
permission=platform_entities.Permission.Member,
group=group,
)
return platform_events.GroupMessage(
type="GroupMessage",
sender=sender,
message_chain=message_chain,
time=1609459200,
)
# ============== Mock Adapter Factory ==============
def mock_adapter() -> Mock:
"""Create a mock platform adapter."""
adapter = AsyncMock()
adapter.is_stream_output_supported = AsyncMock(return_value=False)
adapter.reply_message = AsyncMock()
adapter.reply_message_chunk = AsyncMock()
return adapter
# ============== Query Factories ==============
def _base_query(
message_chain: platform_message.MessageChain,
message_event: platform_events.MessageEvent,
launcher_type: provider_session.LauncherTypes,
launcher_id: typing.Union[int, str],
sender_id: typing.Union[int, str],
adapter: Mock,
**overrides,
) -> pipeline_query.Query:
"""Create a base query with model_construct to bypass validation."""
query_id = _next_query_id()
base_data = {
"query_id": query_id,
"launcher_type": launcher_type,
"launcher_id": launcher_id,
"sender_id": sender_id,
"message_chain": message_chain,
"message_event": message_event,
"adapter": adapter,
"pipeline_uuid": "test-pipeline-uuid",
"bot_uuid": "test-bot-uuid",
"pipeline_config": {
"ai": {
"runner": {"runner": "local-agent"},
"local-agent": {
"model": {"primary": "test-model-uuid", "fallbacks": []},
"prompt": "test-prompt",
},
},
"output": {"misc": {"at-sender": False, "quote-origin": False}},
"trigger": {"misc": {"combine-quote-message": False}},
},
"session": None,
"prompt": None,
"messages": [],
"user_message": None,
"use_funcs": [],
"use_llm_model_uuid": None,
"variables": {},
"resp_messages": [],
"resp_message_chain": None,
"current_stage_name": None,
}
# Apply overrides
for key, value in overrides.items():
base_data[key] = value
return pipeline_query.Query.model_construct(**base_data)
def text_query(
text: str = "hello",
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
"""Create a basic text query (private chat)."""
chain = text_chain(text)
event = friend_message_event(chain, sender_id)
adapter = mock_adapter()
return _base_query(
message_chain=chain,
message_event=event,
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=sender_id,
sender_id=sender_id,
adapter=adapter,
**overrides,
)
def private_text_query(
text: str = "hello",
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
"""Create a private text query (alias for text_query)."""
return text_query(text, sender_id, **overrides)
def group_text_query(
text: str = "hello",
sender_id: typing.Union[int, str] = 12345,
group_id: typing.Union[int, str] = 99999,
**overrides,
) -> pipeline_query.Query:
"""Create a group text query."""
chain = text_chain(text)
event = group_message_event(chain, sender_id, group_id=group_id)
adapter = mock_adapter()
return _base_query(
message_chain=chain,
message_event=event,
launcher_type=provider_session.LauncherTypes.GROUP,
launcher_id=group_id,
sender_id=sender_id,
adapter=adapter,
**overrides,
)
def command_query(
command: str = "help",
prefix: str = "/",
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
"""Create a command-like query."""
chain = command_chain(command, prefix)
event = friend_message_event(chain, sender_id)
adapter = mock_adapter()
return _base_query(
message_chain=chain,
message_event=event,
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=sender_id,
sender_id=sender_id,
adapter=adapter,
**overrides,
)
def mention_query(
text: str = "hello",
target: typing.Union[int, str] = 12345,
sender_id: typing.Union[int, str] = 12345,
group_id: typing.Union[int, str] = 99999,
**overrides,
) -> pipeline_query.Query:
"""Create a mention-bot query (group chat with @mention)."""
chain = mention_chain(text, target)
event = group_message_event(chain, sender_id, group_id=group_id)
adapter = mock_adapter()
return _base_query(
message_chain=chain,
message_event=event,
launcher_type=provider_session.LauncherTypes.GROUP,
launcher_id=group_id,
sender_id=sender_id,
adapter=adapter,
**overrides,
)
def empty_query(**overrides) -> pipeline_query.Query:
"""Create an empty message query."""
chain = platform_message.MessageChain([])
event = friend_message_event(chain)
adapter = mock_adapter()
return _base_query(
message_chain=chain,
message_event=event,
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=12345,
sender_id=12345,
adapter=adapter,
**overrides,
)
def image_query(
text: str = "",
url: str = "https://example.com/image.png",
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
"""Create an image query."""
chain = image_chain(text, url)
event = friend_message_event(chain, sender_id)
adapter = mock_adapter()
return _base_query(
message_chain=chain,
message_event=event,
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=sender_id,
sender_id=sender_id,
adapter=adapter,
**overrides,
)
def file_query(
url: str = "https://example.com/document.pdf",
name: str = "document.pdf",
text: str = "",
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
"""Create a file attachment query."""
components = []
if text:
components.append(platform_message.Plain(text=text))
components.append(platform_message.File(url=url, name=name))
chain = platform_message.MessageChain(components)
event = friend_message_event(chain, sender_id)
adapter = mock_adapter()
return _base_query(
message_chain=chain,
message_event=event,
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=sender_id,
sender_id=sender_id,
adapter=adapter,
**overrides,
)
def unsupported_query(
unsupported_type: str = "CustomComponent",
text: str = "",
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
"""Create a query with unsupported/unknown message segment."""
components = []
if text:
components.append(platform_message.Plain(text=text))
# Use Unknown component for unsupported types
components.append(platform_message.Unknown(text=f"Unsupported: {unsupported_type}"))
chain = platform_message.MessageChain(components)
event = friend_message_event(chain, sender_id)
adapter = mock_adapter()
return _base_query(
message_chain=chain,
message_event=event,
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=sender_id,
sender_id=sender_id,
adapter=adapter,
**overrides,
)
def query_with_session(
text: str = "hello",
sender_id: typing.Union[int, str] = 12345,
session: provider_session.Session = None,
**overrides,
) -> pipeline_query.Query:
"""Create a query with a session object.
If session is None, creates a default session with empty conversation.
"""
if session is None:
# Create a default session
session = provider_session.Session(
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=sender_id,
sender_id=sender_id,
use_prompt_name="default",
using_conversation=None,
conversations=[],
)
return text_query(text, sender_id, session=session, **overrides)
def query_with_config(
text: str = "hello",
sender_id: typing.Union[int, str] = 12345,
pipeline_config: dict = None,
**overrides,
) -> pipeline_query.Query:
"""Create a query with custom pipeline configuration.
If pipeline_config is None, uses default config.
Useful for testing specific stage behaviors.
"""
if pipeline_config is None:
pipeline_config = {
"ai": {
"runner": {"runner": "local-agent"},
"local-agent": {
"model": {"primary": "test-model-uuid", "fallbacks": []},
"prompt": "test-prompt",
},
},
"output": {"misc": {"at-sender": False, "quote-origin": False}},
"trigger": {"misc": {"combine-quote-message": False}},
}
return text_query(text, sender_id, pipeline_config=pipeline_config, **overrides)
def voice_query(
url: str = "https://example.com/audio.mp3",
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
"""Create a voice/audio query."""
components = [
platform_message.Voice(url=url),
]
chain = platform_message.MessageChain(components)
event = friend_message_event(chain, sender_id)
adapter = mock_adapter()
return _base_query(
message_chain=chain,
message_event=event,
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=sender_id,
sender_id=sender_id,
adapter=adapter,
**overrides,
)
def at_all_query(
text: str = "hello",
sender_id: typing.Union[int, str] = 12345,
group_id: typing.Union[int, str] = 99999,
**overrides,
) -> pipeline_query.Query:
"""Create a group query with @All mention."""
components = [
platform_message.AtAll(),
platform_message.Plain(text=f" {text}"),
]
chain = platform_message.MessageChain(components)
event = group_message_event(chain, sender_id, group_id=group_id)
adapter = mock_adapter()
return _base_query(
message_chain=chain,
message_event=event,
launcher_type=provider_session.LauncherTypes.GROUP,
launcher_id=group_id,
sender_id=sender_id,
adapter=adapter,
**overrides,
)

View File

@@ -1,336 +0,0 @@
"""
Fake platform factory for tests.
Provides a fake platform adapter for tests that need inbound message injection
and outbound message capture.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, Mock
import typing
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
class FakePlatform:
"""Fake platform adapter for unit and integration tests.
Simulates platform behavior without real network calls:
- Inbound text message construction
- Group and private conversation identities
- Mention-bot flag
- Outbound text capture
- Outbound file/image capture
- Send failure simulation
Does not start real platform adapters.
Does not call IM platform SDKs.
"""
def __init__(
self,
*,
bot_account_id: str = "test-bot",
stream_output_supported: bool = False,
raise_error: Exception = None,
):
self.bot_account_id = bot_account_id
self._stream_output_supported = stream_output_supported
self._raise_error = raise_error
# Captured outbound messages
self._outbound_messages: list[dict] = []
self._outbound_chunks: list[dict] = []
# Registered listeners
self._listeners: dict = {}
def raises(self, error: Exception) -> "FakePlatform":
"""Configure platform to raise an error on send."""
self._raise_error = error
return self
def send_failure(self) -> "FakePlatform":
"""Configure platform to simulate send failure."""
return self.raises(Exception("Platform send failure"))
def supports_streaming(self, supported: bool = True) -> "FakePlatform":
"""Configure whether streaming output is supported."""
self._stream_output_supported = supported
return self
def get_outbound_messages(self) -> list[dict]:
"""Get all captured outbound messages for assertions."""
return self._outbound_messages.copy()
def get_outbound_chunks(self) -> list[dict]:
"""Get all captured outbound streaming chunks for assertions."""
return self._outbound_chunks.copy()
def clear_outbound(self):
"""Clear captured outbound messages."""
self._outbound_messages.clear()
self._outbound_chunks.clear()
def last_message(self) -> dict | None:
"""Get the last captured outbound message."""
return self._outbound_messages[-1] if self._outbound_messages else None
def last_chunk(self) -> dict | None:
"""Get the last captured streaming chunk."""
return self._outbound_chunks[-1] if self._outbound_chunks else None
# ============== Inbound Message Construction ==============
def create_friend_message(
self,
text: str,
sender_id: typing.Union[int, str] = 12345,
nickname: str = "TestUser",
) -> platform_events.FriendMessage:
"""Create an inbound friend (private) message event."""
sender = platform_entities.Friend(
id=sender_id,
nickname=nickname,
remark=None,
)
chain = platform_message.MessageChain([
platform_message.Plain(text=text),
])
return platform_events.FriendMessage(
type="FriendMessage",
sender=sender,
message_chain=chain,
time=1609459200,
)
def create_group_message(
self,
text: str,
sender_id: typing.Union[int, str] = 12345,
sender_name: str = "TestUser",
group_id: typing.Union[int, str] = 99999,
group_name: str = "TestGroup",
mention_bot: bool = False,
) -> platform_events.GroupMessage:
"""Create an inbound group message event.
Args:
text: Message text content
sender_id: Sender user ID
sender_name: Sender display name
group_id: Group ID
group_name: Group name
mention_bot: If True, prepend @mention of bot account
"""
group = platform_entities.Group(
id=group_id,
name=group_name,
permission=platform_entities.Permission.Member,
)
sender = platform_entities.GroupMember(
id=sender_id,
member_name=sender_name,
permission=platform_entities.Permission.Member,
group=group,
)
# Build message chain with optional mention
components = []
if mention_bot:
components.append(platform_message.At(target=self.bot_account_id))
components.append(platform_message.Plain(text=" "))
components.append(platform_message.Plain(text=text))
chain = platform_message.MessageChain(components)
return platform_events.GroupMessage(
type="GroupMessage",
sender=sender,
message_chain=chain,
time=1609459200,
)
def create_image_message(
self,
url: str = "https://example.com/image.png",
text: str = "",
sender_id: typing.Union[int, str] = 12345,
is_group: bool = False,
group_id: typing.Union[int, str] = 99999,
) -> platform_events.MessageEvent:
"""Create an inbound image message event."""
components = []
if text:
components.append(platform_message.Plain(text=text))
components.append(platform_message.Image(url=url))
chain = platform_message.MessageChain(components)
if is_group:
return self.create_group_message("", sender_id, group_id=group_id)
# Replace chain
else:
sender = platform_entities.Friend(id=sender_id, nickname="TestUser", remark=None)
return platform_events.FriendMessage(
type="FriendMessage",
sender=sender,
message_chain=chain,
time=1609459200,
)
# ============== Adapter Methods (Simulated) ==============
async def send_message(
self,
target_type: str,
target_id: str,
message: platform_message.MessageChain,
):
"""Simulate sending a message (captures for assertions)."""
if self._raise_error:
raise self._raise_error
self._outbound_messages.append({
"type": "send",
"target_type": target_type,
"target_id": target_id,
"message": message,
})
async def reply_message(
self,
message_source: platform_events.MessageEvent,
message: platform_message.MessageChain,
quote_origin: bool = False,
):
"""Simulate replying to a message (captures for assertions)."""
if self._raise_error:
raise self._raise_error
self._outbound_messages.append({
"type": "reply",
"source_type": message_source.type,
"source": message_source,
"message": message,
"quote_origin": quote_origin,
})
async def reply_message_chunk(
self,
message_source: platform_events.MessageEvent,
bot_message: dict,
message: platform_message.MessageChain,
quote_origin: bool = False,
is_final: bool = False,
):
"""Simulate streaming reply (captures for assertions)."""
if self._raise_error:
raise self._raise_error
self._outbound_chunks.append({
"type": "reply_chunk",
"source_type": message_source.type,
"source": message_source,
"bot_message": bot_message,
"message": message,
"quote_origin": quote_origin,
"is_final": is_final,
})
async def is_stream_output_supported(self) -> bool:
"""Return whether streaming output is supported."""
return self._stream_output_supported
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable,
):
"""Register an event listener (stores for simulation)."""
if event_type not in self._listeners:
self._listeners[event_type] = []
self._listeners[event_type].append(callback)
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable,
):
"""Unregister an event listener."""
if event_type in self._listeners:
self._listeners[event_type].remove(callback)
async def run_async(self):
"""Simulate running the adapter (does nothing)."""
pass
async def kill(self) -> bool:
"""Simulate killing the adapter."""
return True
async def is_muted(self, group_id: int) -> bool:
"""Simulate checking mute status."""
return False
async def create_message_card(
self,
message_id: typing.Type[str, int],
event: platform_events.MessageEvent,
) -> bool:
"""Simulate creating a message card."""
return False
# ============== Simulation Helpers ==============
async def simulate_inbound_event(
self,
event: platform_events.Event,
):
"""Simulate receiving an inbound event by calling registered listeners."""
listeners = self._listeners.get(type(event), [])
for callback in listeners:
await callback(event, self)
def fake_platform(
bot_account_id: str = "test-bot",
stream_output_supported: bool = False,
) -> FakePlatform:
"""Create a FakePlatform instance."""
return FakePlatform(
bot_account_id=bot_account_id,
stream_output_supported=stream_output_supported,
)
def fake_platform_with_streaming() -> FakePlatform:
"""Create a FakePlatform that supports streaming output."""
return FakePlatform(stream_output_supported=True)
def fake_platform_with_failure() -> FakePlatform:
"""Create a FakePlatform that simulates send failure."""
return FakePlatform().send_failure()
# ============== Mock Adapter (for Query) ==============
def mock_platform_adapter(platform: FakePlatform = None) -> Mock:
"""Create a mock platform adapter using FakePlatform or a simple mock."""
if platform is None:
platform = FakePlatform()
adapter = Mock()
adapter.bot_account_id = platform.bot_account_id
adapter.reply_message = AsyncMock(side_effect=platform.reply_message)
adapter.reply_message_chunk = AsyncMock(side_effect=platform.reply_message_chunk)
adapter.send_message = AsyncMock(side_effect=platform.send_message)
adapter.is_stream_output_supported = AsyncMock(
return_value=platform._stream_output_supported
)
adapter._fake_platform = platform # Store for assertions
return adapter

View File

@@ -1,224 +0,0 @@
"""
Fake provider factory for tests.
Provides a deterministic fake provider that simulates LLM responses without real API calls.
"""
from __future__ import annotations
from unittest.mock import Mock
import typing
import langbot_plugin.api.entities.builtin.provider.message as provider_message
class FakeProvider:
"""Deterministic fake provider for unit and integration tests.
Simulates various provider behaviors:
- Normal text response
- Streaming response
- Timeout error
- Auth error
- Rate-limit error
- Malformed response
Does not call real LLM vendors.
Does not require API keys.
"""
PONG_RESPONSE = "LANGBOT_FAKE_PONG"
def __init__(
self,
*,
default_response: str = "fake response",
streaming_chunks: list[str] = None,
raise_error: Exception = None,
captured_requests: list = None,
):
self._default_response = default_response
self._streaming_chunks = streaming_chunks or ["fake ", "response"]
self._raise_error = raise_error
self._captured_requests = captured_requests if captured_requests is not None else []
def returns(self, text: str) -> "FakeProvider":
"""Configure provider to return a specific text response."""
self._default_response = text
self._streaming_chunks = [text]
return self
def returns_streaming(self, chunks: list[str]) -> "FakeProvider":
"""Configure provider to return streaming chunks."""
self._streaming_chunks = chunks
self._default_response = "".join(chunks)
return self
def raises(self, error: Exception) -> "FakeProvider":
"""Configure provider to raise an error."""
self._raise_error = error
return self
def timeout(self) -> "FakeProvider":
"""Configure provider to simulate timeout."""
return self.raises(TimeoutError("Provider timeout"))
def auth_error(self) -> "FakeProvider":
"""Configure provider to simulate auth error."""
return self.raises(Exception("Invalid API key"))
def rate_limit(self) -> "FakeProvider":
"""Configure provider to simulate rate limit."""
return self.raises(Exception("Rate limit exceeded"))
def malformed(self) -> "FakeProvider":
"""Configure provider to simulate malformed response."""
self._default_response = None
return self
def get_captured_requests(self) -> list:
"""Get all captured request arguments for assertions."""
return self._captured_requests.copy()
def clear_captured_requests(self):
"""Clear captured requests."""
self._captured_requests.clear()
def _create_message(self, content: str) -> provider_message.Message:
"""Create a provider message from text content."""
return provider_message.Message(
role="assistant",
content=content,
)
def _create_chunk(
self,
content: str,
is_final: bool = False,
msg_sequence: int = 0,
) -> provider_message.MessageChunk:
"""Create a provider message chunk."""
return provider_message.MessageChunk(
role="assistant",
content=content,
is_final=is_final,
msg_sequence=msg_sequence,
)
async def invoke_llm(
self,
query,
model,
messages: list,
funcs: list,
extra_args: dict,
remove_think: bool = False,
) -> provider_message.Message:
"""Simulate non-streaming LLM invocation."""
# Capture request for assertions
self._captured_requests.append({
"query_id": query.query_id if query else None,
"model": model.model_entity.name if model and hasattr(model, 'model_entity') else None,
"messages": messages,
"funcs": funcs,
"extra_args": extra_args,
})
# Simulate error if configured
if self._raise_error:
raise self._raise_error
# Return response
if self._default_response is None:
# Malformed response
return provider_message.Message(role="assistant", content=None)
return self._create_message(self._default_response)
async def invoke_llm_stream(
self,
query,
model,
messages: list,
funcs: list,
extra_args: dict,
remove_think: bool = False,
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
"""Simulate streaming LLM invocation."""
# Capture request for assertions
self._captured_requests.append({
"query_id": query.query_id if query else None,
"model": model.model_entity.name if model and hasattr(model, 'model_entity') else None,
"messages": messages,
"funcs": funcs,
"extra_args": extra_args,
"streaming": True,
})
# Simulate error if configured
if self._raise_error:
raise self._raise_error
# Yield chunks
for i, chunk in enumerate(self._streaming_chunks):
is_final = (i == len(self._streaming_chunks) - 1)
yield self._create_chunk(chunk, is_final=is_final, msg_sequence=i)
def fake_provider(
default_response: str = "fake response",
) -> FakeProvider:
"""Create a FakeProvider with optional default response."""
return FakeProvider(default_response=default_response)
def fake_provider_pong() -> FakeProvider:
"""Create a FakeProvider that returns the pong response."""
return FakeProvider(default_response=FakeProvider.PONG_RESPONSE)
def fake_provider_timeout() -> FakeProvider:
"""Create a FakeProvider that simulates timeout."""
return FakeProvider().timeout()
def fake_provider_auth_error() -> FakeProvider:
"""Create a FakeProvider that simulates auth error."""
return FakeProvider().auth_error()
def fake_provider_rate_limit() -> FakeProvider:
"""Create a FakeProvider that simulates rate limit."""
return FakeProvider().rate_limit()
def fake_provider_malformed() -> FakeProvider:
"""Create a FakeProvider that simulates malformed response."""
return FakeProvider().malformed()
# ============== Mock Model Factory ==============
def fake_model(
*,
uuid: str = "test-model-uuid",
name: str = "test-model",
abilities: list[str] = None,
provider: FakeProvider = None,
) -> Mock:
"""Create a mock model with a fake provider."""
model = Mock()
model.model_entity = Mock()
model.model_entity.uuid = uuid
model.model_entity.name = name
model.model_entity.abilities = abilities or ["func_call", "vision"]
model.model_entity.extra_args = {}
# Attach fake provider
if provider is None:
provider = FakeProvider()
model.provider = provider
return model

View File

@@ -1,6 +0,0 @@
"""
Integration tests package.
These tests validate real system behavior with actual database/network resources.
Run with: uv run pytest tests/integration/ -m "not slow" -q
"""

View File

@@ -1,5 +0,0 @@
"""
API integration tests package.
Tests for HTTP API endpoints using Quart test client.
"""

View File

@@ -1,28 +0,0 @@
from __future__ import annotations
import pytest
def dedupe_preregistered_groups() -> None:
"""Keep API integration route registration isolated across test modules."""
from langbot.pkg.api.http.controller import group
seen: set[tuple[str, str]] = set()
unique_groups = []
for group_cls in group.preregistered_groups:
key = (group_cls.name, group_cls.path)
if key in seen:
continue
seen.add(key)
unique_groups.append(group_cls)
group.preregistered_groups[:] = unique_groups
@pytest.fixture(scope='module')
def http_controller_cls(mock_circular_import_chain):
"""Import HTTPController under each module's circular-import isolation."""
from langbot.pkg.api.http.controller.main import HTTPController
dedupe_preregistered_groups()
return HTTPController

View File

@@ -1,253 +0,0 @@
"""
API integration tests for bot endpoints.
Tests real HTTP API behavior for bot management.
Run: uv run pytest tests/integration/api/test_bots.py -q
"""
from __future__ import annotations
import pytest
from unittest.mock import MagicMock, AsyncMock, Mock
from tests.factories import FakeApp
pytestmark = pytest.mark.integration
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""Break circular import chain for API controller."""
from tests.utils.import_isolation import isolated_sys_modules, MockLifecycleControlScope
class FakeMinimalApplication:
pass
mock_app = MagicMock()
mock_app.Application = FakeMinimalApplication
mock_entities = MagicMock()
mock_entities.LifecycleControlScope = MockLifecycleControlScope
clear = [
'langbot.pkg.api.http.controller.group',
'langbot.pkg.api.http.controller.groups',
'langbot.pkg.api.http.controller.groups.platform',
'langbot.pkg.api.http.controller.groups.platform.bots',
'langbot.pkg.api.http.controller.groups.platform.adapters',
'langbot.pkg.api.http.controller.main',
]
with isolated_sys_modules(
mocks={
'langbot.pkg.core.app': mock_app,
'langbot.pkg.core.entities': mock_entities,
},
clear=clear,
):
import langbot.pkg.api.http.controller.groups.platform.bots as _bots # noqa: E402, F401
yield
@pytest.fixture(scope='module')
def fake_bot_app():
"""Create FakeApp with bot services (module scope for reuse)."""
app = FakeApp()
app.instance_config.data.update({
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
# Auth services
app.user_service = Mock()
app.user_service.is_initialized = AsyncMock(return_value=True)
app.user_service.verify_jwt_token = AsyncMock(return_value='test@example.com')
app.user_service.get_user_by_email = AsyncMock(return_value=Mock(email='test@example.com'))
app.apikey_service = Mock()
app.apikey_service.verify_api_key = AsyncMock(return_value=True)
# Bot service
app.bot_service = Mock()
app.bot_service.get_bots = AsyncMock(return_value=[
{
'uuid': 'test-bot-uuid',
'name': 'Test Bot',
'platform': 'telegram',
'pipeline_uuid': 'test-pipeline-uuid',
}
])
app.bot_service.get_runtime_bot_info = AsyncMock(return_value={
'uuid': 'test-bot-uuid',
'name': 'Test Bot',
'platform': 'telegram',
'pipeline_uuid': 'test-pipeline-uuid',
'webhook_url': 'https://example.com/webhook/test-bot-uuid',
})
app.bot_service.create_bot = AsyncMock(return_value={'uuid': 'new-bot-uuid'})
app.bot_service.update_bot = AsyncMock(return_value={})
app.bot_service.delete_bot = AsyncMock()
app.bot_service.list_event_logs = AsyncMock(return_value=(
[{'uuid': 'log-1', 'message': 'test log'}],
1
))
app.bot_service.send_message = AsyncMock()
# Platform manager
app.platform_mgr = Mock()
return app
@pytest.fixture(scope='module')
async def quart_test_client(fake_bot_app, http_controller_cls):
"""Create Quart test client (module scope to avoid route re-registration)."""
controller = http_controller_cls(fake_bot_app)
await controller.initialize()
client = controller.quart_app.test_client()
yield client
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestBotEndpoints:
"""Tests for /api/v1/platform/bots endpoints."""
@pytest.mark.asyncio
async def test_get_bots_success(self, quart_test_client):
"""GET /api/v1/platform/bots returns bot list."""
response = await quart_test_client.get(
'/api/v1/platform/bots',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'data' in data
assert 'bots' in data['data']
@pytest.mark.asyncio
async def test_create_bot_success(self, quart_test_client):
"""POST /api/v1/platform/bots creates new bot."""
response = await quart_test_client.post(
'/api/v1/platform/bots',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Bot', 'platform': 'telegram', 'pipeline_uuid': 'test-pipeline'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'uuid' in data['data']
@pytest.mark.asyncio
async def test_get_single_bot_success(self, quart_test_client):
"""GET /api/v1/platform/bots/{uuid} returns bot with runtime info."""
response = await quart_test_client.get(
'/api/v1/platform/bots/test-bot-uuid',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'bot' in data['data']
@pytest.mark.asyncio
async def test_update_bot_success(self, quart_test_client):
"""PUT /api/v1/platform/bots/{uuid} updates bot."""
response = await quart_test_client.put(
'/api/v1/platform/bots/test-bot-uuid',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'Updated Bot'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
@pytest.mark.asyncio
async def test_delete_bot_success(self, quart_test_client):
"""DELETE /api/v1/platform/bots/{uuid} deletes bot."""
response = await quart_test_client.delete(
'/api/v1/platform/bots/test-bot-uuid',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestBotLogsEndpoint:
"""Tests for bot logs endpoint."""
@pytest.mark.asyncio
async def test_get_bot_logs_success(self, quart_test_client):
"""POST /api/v1/platform/bots/{uuid}/logs returns logs."""
response = await quart_test_client.post(
'/api/v1/platform/bots/test-bot-uuid/logs',
headers={'Authorization': 'Bearer test_token'},
json={'from_index': -1, 'max_count': 10}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'logs' in data['data']
assert 'total_count' in data['data']
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestBotSendMessageEndpoint:
"""Tests for bot send message endpoint."""
@pytest.mark.asyncio
async def test_send_message_success(self, quart_test_client):
"""POST /api/v1/platform/bots/{uuid}/send_message sends message."""
response = await quart_test_client.post(
'/api/v1/platform/bots/test-bot-uuid/send_message',
headers={'Authorization': 'Bearer test_api_key'},
json={
'target_type': 'person',
'target_id': 'user123',
'message_chain': [{'type': 'text', 'text': 'Hello'}]
}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert data['data']['sent'] is True
@pytest.mark.asyncio
async def test_send_message_missing_target_type(self, quart_test_client):
"""POST send_message without target_type returns 400."""
response = await quart_test_client.post(
'/api/v1/platform/bots/test-bot-uuid/send_message',
headers={'Authorization': 'Bearer test_api_key'},
json={'target_id': 'user123', 'message_chain': [{'type': 'text', 'text': 'Hello'}]}
)
assert response.status_code == 400
data = await response.get_json()
assert data['code'] == -1
@pytest.mark.asyncio
async def test_send_message_invalid_target_type(self, quart_test_client):
"""POST send_message with invalid target_type returns 400."""
response = await quart_test_client.post(
'/api/v1/platform/bots/test-bot-uuid/send_message',
headers={'Authorization': 'Bearer test_api_key'},
json={
'target_type': 'invalid',
'target_id': 'user123',
'message_chain': [{'type': 'text', 'text': 'Hello'}]
}
)
assert response.status_code == 400
data = await response.get_json()
assert data['code'] == -1

View File

@@ -1,300 +0,0 @@
"""
API integration tests for embed widget endpoints.
Tests real HTTP API behavior for embed widget functionality.
Run: uv run pytest tests/integration/api/test_embed.py -q
"""
from __future__ import annotations
import pytest
from unittest.mock import MagicMock, AsyncMock, Mock
from tests.factories import FakeApp
pytestmark = pytest.mark.integration
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""Break circular import chain for API controller."""
from tests.utils.import_isolation import isolated_sys_modules, MockLifecycleControlScope
class FakeMinimalApplication:
pass
mock_app = MagicMock()
mock_app.Application = FakeMinimalApplication
mock_entities = MagicMock()
mock_entities.LifecycleControlScope = MockLifecycleControlScope
clear = [
'langbot.pkg.api.http.controller.group',
'langbot.pkg.api.http.controller.groups',
'langbot.pkg.api.http.controller.groups.pipelines',
'langbot.pkg.api.http.controller.groups.pipelines.embed',
'langbot.pkg.api.http.controller.main',
]
with isolated_sys_modules(
mocks={
'langbot.pkg.core.app': mock_app,
'langbot.pkg.core.entities': mock_entities,
},
clear=clear,
):
import langbot.pkg.api.http.controller.groups.pipelines.embed as _embed # noqa: E402, F401
yield
@pytest.fixture(scope='module')
def fake_embed_app():
"""Create FakeApp with embed widget services (module scope)."""
app = FakeApp()
app.instance_config.data.update({
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
# Create mock web_page_bot with valid UUID format
mock_bot_entity = Mock()
mock_bot_entity.uuid = 'a1b2c3d4-5678-90ab-cdef-123456789abc'
mock_bot_entity.adapter = 'web_page_bot'
mock_bot_entity.enable = True
mock_bot_entity.use_pipeline_uuid = 'test-pipeline-uuid'
mock_bot_entity.name = 'Test Web Bot'
mock_bot_entity.adapter_config = {
'turnstile_secret_key': '',
'turnstile_site_key': '',
'language': 'en_US',
'bubble_icon': 'logo',
}
mock_runtime_bot = Mock()
mock_runtime_bot.bot_entity = mock_bot_entity
# Platform manager with bots
app.platform_mgr = Mock()
app.platform_mgr.bots = [mock_runtime_bot]
# WebSocket proxy bot with adapter
mock_websocket_adapter = Mock()
mock_websocket_adapter.get_websocket_messages = Mock(return_value=[
{'id': 'msg-1', 'content': 'test message'}
])
mock_websocket_adapter.reset_session = Mock()
mock_websocket_adapter.handle_websocket_message = AsyncMock()
mock_ws_proxy_bot = Mock()
mock_ws_proxy_bot.adapter = mock_websocket_adapter
app.platform_mgr.websocket_proxy_bot = mock_ws_proxy_bot
# Monitoring service for feedback
app.monitoring_service = Mock()
app.monitoring_service.record_feedback = AsyncMock()
return app
@pytest.fixture(scope='module')
async def quart_test_client(fake_embed_app, http_controller_cls):
"""Create Quart test client (module scope)."""
controller = http_controller_cls(fake_embed_app)
await controller.initialize()
client = controller.quart_app.test_client()
yield client
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestEmbedWidgetEndpoint:
"""Tests for widget.js endpoint."""
@pytest.mark.asyncio
async def test_get_widget_js_success(self, quart_test_client):
"""GET /api/v1/embed/{bot_uuid}/widget.js returns JS."""
response = await quart_test_client.get(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/widget.js'
)
assert response.status_code == 200
assert 'javascript' in response.content_type
@pytest.mark.asyncio
async def test_get_widget_js_invalid_uuid(self, quart_test_client):
"""GET widget.js with invalid UUID returns 400."""
response = await quart_test_client.get(
'/api/v1/embed/invalid-uuid/widget.js'
)
assert response.status_code == 400
@pytest.mark.asyncio
async def test_get_widget_js_bot_not_found(self, quart_test_client):
"""GET widget.js for non-existent bot returns 404."""
response = await quart_test_client.get(
'/api/v1/embed/00000000-0000-0000-0000-000000000000/widget.js'
)
assert response.status_code == 404
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestEmbedLogoEndpoint:
"""Tests for logo endpoint."""
@pytest.mark.asyncio
async def test_get_logo_success(self, quart_test_client):
"""GET /api/v1/embed/logo returns image."""
response = await quart_test_client.get('/api/v1/embed/logo')
assert response.status_code == 200
assert 'image/webp' in response.content_type
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestEmbedTurnstileVerifyEndpoint:
"""Tests for Turnstile verification endpoint."""
@pytest.mark.asyncio
async def test_turnstile_verify_no_secret(self, quart_test_client):
"""POST turnstile verify without secret returns dummy token."""
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify',
json={'token': 'test-token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'token' in data['data']
@pytest.mark.asyncio
async def test_turnstile_verify_invalid_uuid(self, quart_test_client):
"""POST turnstile verify with invalid UUID returns 400."""
response = await quart_test_client.post(
'/api/v1/embed/invalid-uuid/turnstile/verify',
json={'token': 'test-token'}
)
assert response.status_code == 400
@pytest.mark.asyncio
async def test_turnstile_verify_missing_token(self, quart_test_client):
"""POST turnstile verify without token returns 400."""
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify',
json={}
)
assert response.status_code == 400
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestEmbedMessagesEndpoint:
"""Tests for messages endpoint."""
@pytest.mark.asyncio
async def test_get_messages_person_success(self, quart_test_client):
"""GET messages/person returns messages."""
response = await quart_test_client.get(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/person',
headers={'Authorization': 'Bearer 1234567890.dummy'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'messages' in data['data']
@pytest.mark.asyncio
async def test_get_messages_group_success(self, quart_test_client):
"""GET messages/group returns messages."""
response = await quart_test_client.get(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/group',
headers={'Authorization': 'Bearer 1234567890.dummy'}
)
assert response.status_code == 200
@pytest.mark.asyncio
async def test_get_messages_invalid_session_type(self, quart_test_client):
"""GET messages with invalid session_type returns 400."""
response = await quart_test_client.get(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/invalid',
headers={'Authorization': 'Bearer 1234567890.dummy'}
)
assert response.status_code == 400
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestEmbedResetEndpoint:
"""Tests for session reset endpoint."""
@pytest.mark.asyncio
async def test_reset_session_person_success(self, quart_test_client):
"""POST reset/person resets session."""
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/reset/person',
headers={'Authorization': 'Bearer 1234567890.dummy'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
@pytest.mark.asyncio
async def test_reset_session_invalid_uuid(self, quart_test_client):
"""POST reset with invalid UUID returns 400."""
response = await quart_test_client.post(
'/api/v1/embed/invalid-uuid/reset/person',
headers={'Authorization': 'Bearer 1234567890.dummy'}
)
assert response.status_code == 400
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestEmbedFeedbackEndpoint:
"""Tests for feedback submission endpoint."""
@pytest.mark.asyncio
async def test_submit_feedback_like(self, quart_test_client):
"""POST feedback with type=1 (like) succeeds."""
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
headers={'Authorization': 'Bearer 1234567890.dummy'},
json={'message_id': 'msg-123', 'feedback_type': 1}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'feedback_id' in data['data']
@pytest.mark.asyncio
async def test_submit_feedback_dislike(self, quart_test_client):
"""POST feedback with type=2 (dislike) succeeds."""
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
headers={'Authorization': 'Bearer 1234567890.dummy'},
json={'message_id': 'msg-123', 'feedback_type': 2}
)
assert response.status_code == 200
@pytest.mark.asyncio
async def test_submit_feedback_invalid_type(self, quart_test_client):
"""POST feedback with invalid type returns 400."""
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
headers={'Authorization': 'Bearer 1234567890.dummy'},
json={'message_id': 'msg-123', 'feedback_type': 99}
)
assert response.status_code == 400

View File

@@ -1,259 +0,0 @@
"""
API integration tests for knowledge base endpoints.
Tests real HTTP API behavior for knowledge base management.
Run: uv run pytest tests/integration/api/test_knowledge.py -q
"""
from __future__ import annotations
import pytest
from unittest.mock import MagicMock, AsyncMock, Mock
from tests.factories import FakeApp
pytestmark = pytest.mark.integration
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""Break circular import chain for API controller."""
from tests.utils.import_isolation import isolated_sys_modules, MockLifecycleControlScope
class FakeMinimalApplication:
pass
mock_app = MagicMock()
mock_app.Application = FakeMinimalApplication
mock_entities = MagicMock()
mock_entities.LifecycleControlScope = MockLifecycleControlScope
clear = [
'langbot.pkg.api.http.controller.group',
'langbot.pkg.api.http.controller.groups',
'langbot.pkg.api.http.controller.groups.knowledge',
'langbot.pkg.api.http.controller.groups.knowledge.base',
'langbot.pkg.api.http.controller.groups.knowledge.engines',
'langbot.pkg.api.http.controller.groups.knowledge.parsers',
'langbot.pkg.api.http.controller.main',
]
with isolated_sys_modules(
mocks={
'langbot.pkg.core.app': mock_app,
'langbot.pkg.core.entities': mock_entities,
},
clear=clear,
):
import langbot.pkg.api.http.controller.groups.knowledge.base as _knowledge # noqa: E402, F401
yield
@pytest.fixture(scope='module')
def fake_knowledge_app():
"""Create FakeApp with knowledge services (module scope for reuse)."""
app = FakeApp()
app.instance_config.data.update({
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
# Auth services
app.user_service = Mock()
app.user_service.is_initialized = AsyncMock(return_value=True)
app.user_service.verify_jwt_token = AsyncMock(return_value='test@example.com')
app.user_service.get_user_by_email = AsyncMock(return_value=Mock(email='test@example.com'))
app.apikey_service = Mock()
app.apikey_service.verify_api_key = AsyncMock(return_value=True)
# Knowledge service
app.knowledge_service = Mock()
app.knowledge_service.get_knowledge_bases = AsyncMock(return_value=[
{
'uuid': 'test-kb-uuid',
'name': 'Test Knowledge Base',
'description': 'Test KB description',
'engine_plugin_id': 'test/engine',
'created_at': '2024-01-01T00:00:00',
'updated_at': '2024-01-01T00:00:00',
}
])
app.knowledge_service.get_knowledge_base = AsyncMock(return_value={
'uuid': 'test-kb-uuid',
'name': 'Test Knowledge Base',
'description': 'Test KB description',
'engine_plugin_id': 'test/engine',
})
app.knowledge_service.create_knowledge_base = AsyncMock(return_value={'uuid': 'new-kb-uuid'})
app.knowledge_service.update_knowledge_base = AsyncMock(return_value={})
app.knowledge_service.delete_knowledge_base = AsyncMock()
app.knowledge_service.get_files_by_knowledge_base = AsyncMock(return_value=[
{'uuid': 'test-file-uuid', 'filename': 'test.pdf'}
])
app.knowledge_service.store_file = AsyncMock(return_value={'task_id': 'test-task-id'})
app.knowledge_service.delete_file = AsyncMock()
app.knowledge_service.retrieve_knowledge_base = AsyncMock(return_value=[
{'content': 'test result', 'score': 0.95}
])
# RAG manager
app.rag_mgr = Mock()
return app
@pytest.fixture(scope='module')
async def quart_test_client(fake_knowledge_app, http_controller_cls):
"""Create Quart test client (module scope to avoid route re-registration)."""
controller = http_controller_cls(fake_knowledge_app)
await controller.initialize()
client = controller.quart_app.test_client()
yield client
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestKnowledgeBaseEndpoints:
"""Tests for /api/v1/knowledge/bases endpoints."""
@pytest.mark.asyncio
async def test_get_knowledge_bases_success(self, quart_test_client):
"""GET /api/v1/knowledge/bases returns knowledge base list."""
response = await quart_test_client.get(
'/api/v1/knowledge/bases',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'data' in data
assert 'bases' in data['data']
@pytest.mark.asyncio
async def test_create_knowledge_base_success(self, quart_test_client):
"""POST /api/v1/knowledge/bases creates new knowledge base."""
response = await quart_test_client.post(
'/api/v1/knowledge/bases',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New KB', 'engine_plugin_id': 'test/engine'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'uuid' in data['data']
@pytest.mark.asyncio
async def test_get_single_knowledge_base_success(self, quart_test_client):
"""GET /api/v1/knowledge/bases/{uuid} returns knowledge base."""
response = await quart_test_client.get(
'/api/v1/knowledge/bases/test-kb-uuid',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'base' in data['data']
@pytest.mark.asyncio
async def test_update_knowledge_base_success(self, quart_test_client):
"""PUT /api/v1/knowledge/bases/{uuid} updates knowledge base."""
response = await quart_test_client.put(
'/api/v1/knowledge/bases/test-kb-uuid',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'Updated KB'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
@pytest.mark.asyncio
async def test_delete_knowledge_base_success(self, quart_test_client):
"""DELETE /api/v1/knowledge/bases/{uuid} deletes knowledge base."""
response = await quart_test_client.delete(
'/api/v1/knowledge/bases/test-kb-uuid',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestKnowledgeBaseFilesEndpoints:
"""Tests for knowledge base files endpoints."""
@pytest.mark.asyncio
async def test_get_files_success(self, quart_test_client):
"""GET /api/v1/knowledge/bases/{uuid}/files returns files."""
response = await quart_test_client.get(
'/api/v1/knowledge/bases/test-kb-uuid/files',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'files' in data['data']
@pytest.mark.asyncio
async def test_add_file_to_knowledge_base(self, quart_test_client):
"""POST /api/v1/knowledge/bases/{uuid}/files adds file."""
response = await quart_test_client.post(
'/api/v1/knowledge/bases/test-kb-uuid/files',
headers={'Authorization': 'Bearer test_token'},
json={'file_id': 'test-file-id', 'parser_plugin_id': 'test/parser'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'task_id' in data['data']
@pytest.mark.asyncio
async def test_delete_file_from_knowledge_base(self, quart_test_client):
"""DELETE /api/v1/knowledge/bases/{uuid}/files/{file_id}."""
response = await quart_test_client.delete(
'/api/v1/knowledge/bases/test-kb-uuid/files/test-file-uuid',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestKnowledgeBaseRetrieveEndpoint:
"""Tests for knowledge base retrieval endpoint."""
@pytest.mark.asyncio
async def test_retrieve_knowledge_success(self, quart_test_client):
"""POST /api/v1/knowledge/bases/{uuid}/retrieve."""
response = await quart_test_client.post(
'/api/v1/knowledge/bases/test-kb-uuid/retrieve',
headers={'Authorization': 'Bearer test_token'},
json={'query': 'test query', 'retrieval_settings': {'top_k': 5}}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'results' in data['data']
@pytest.mark.asyncio
async def test_retrieve_without_query_returns_error(self, quart_test_client):
"""POST retrieve without query returns 400."""
response = await quart_test_client.post(
'/api/v1/knowledge/bases/test-kb-uuid/retrieve',
headers={'Authorization': 'Bearer test_token'},
json={}
)
assert response.status_code == 400
data = await response.get_json()
assert data['code'] == -1

View File

@@ -1,330 +0,0 @@
"""
API integration tests for monitoring endpoints.
Tests real HTTP API behavior for monitoring data retrieval.
Run: uv run pytest tests/integration/api/test_monitoring.py -q
"""
from __future__ import annotations
import pytest
from unittest.mock import MagicMock, AsyncMock, Mock
from tests.factories import FakeApp
pytestmark = pytest.mark.integration
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""Break circular import chain for API controller."""
from tests.utils.import_isolation import isolated_sys_modules, MockLifecycleControlScope
class FakeMinimalApplication:
pass
mock_app = MagicMock()
mock_app.Application = FakeMinimalApplication
mock_entities = MagicMock()
mock_entities.LifecycleControlScope = MockLifecycleControlScope
clear = [
'langbot.pkg.api.http.controller.group',
'langbot.pkg.api.http.controller.groups',
'langbot.pkg.api.http.controller.groups.monitoring',
'langbot.pkg.api.http.controller.main',
]
with isolated_sys_modules(
mocks={
'langbot.pkg.core.app': mock_app,
'langbot.pkg.core.entities': mock_entities,
},
clear=clear,
):
import langbot.pkg.api.http.controller.groups.monitoring as _monitoring # noqa: E402, F401
yield
@pytest.fixture(scope='module')
def fake_monitoring_app():
"""Create FakeApp with monitoring services (module scope)."""
app = FakeApp()
app.instance_config.data.update({
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
# Auth services - USER_TOKEN auth requires jwt verification AND get_user_by_email
app.user_service = Mock()
app.user_service.is_initialized = AsyncMock(return_value=True)
app.user_service.verify_jwt_token = AsyncMock(return_value='test@example.com')
app.user_service.get_user_by_email = AsyncMock(return_value=Mock(email='test@example.com'))
# Monitoring service
app.monitoring_service = Mock()
app.monitoring_service.get_overview_metrics = AsyncMock(return_value={
'total_messages': 100,
'total_llm_calls': 50,
'total_sessions': 20,
'active_sessions': 5,
'total_errors': 2,
})
app.monitoring_service.get_messages = AsyncMock(return_value=(
[{'id': 'msg-1', 'content': 'test'}], 100
))
app.monitoring_service.get_llm_calls = AsyncMock(return_value=(
[{'id': 'llm-1'}], 50
))
app.monitoring_service.get_embedding_calls = AsyncMock(return_value=(
[{'id': 'emb-1'}], 10
))
app.monitoring_service.get_sessions = AsyncMock(return_value=(
[{'session_id': 'sess-1'}], 20
))
app.monitoring_service.get_errors = AsyncMock(return_value=(
[{'id': 'err-1'}], 2
))
app.monitoring_service.get_session_analysis = AsyncMock(return_value={
'found': True,
'session_id': 'sess-1',
})
app.monitoring_service.get_message_details = AsyncMock(return_value={
'found': True,
'message_id': 'msg-1',
})
app.monitoring_service.get_feedback_stats = AsyncMock(return_value={'like_count': 10})
app.monitoring_service.get_feedback_list = AsyncMock(return_value=(
[{'feedback_id': 'fb-1'}], 12
))
app.monitoring_service.export_messages = AsyncMock(return_value=[{'id': 'msg-1'}])
app.monitoring_service.export_llm_calls = AsyncMock(return_value=[{'id': 'llm-1'}])
app.monitoring_service.export_errors = AsyncMock(return_value=[{'id': 'err-1'}])
app.monitoring_service.export_sessions = AsyncMock(return_value=[{'session_id': 'sess-1'}])
app.monitoring_service.export_feedback = AsyncMock(return_value=[{'id': 'fb-1'}])
app.monitoring_service.export_embedding_calls = AsyncMock(return_value=[{'id': 'emb-1'}])
app.monitoring_service._escape_csv_field = Mock(return_value='escaped')
return app
@pytest.fixture(scope='module')
async def quart_test_client(fake_monitoring_app, http_controller_cls):
"""Create Quart test client (module scope)."""
controller = http_controller_cls(fake_monitoring_app)
await controller.initialize()
client = controller.quart_app.test_client()
yield client
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestMonitoringOverviewEndpoint:
"""Tests for /api/v1/monitoring/overview endpoint."""
@pytest.mark.asyncio
async def test_get_overview_success(self, quart_test_client):
"""GET /api/v1/monitoring/overview returns metrics."""
response = await quart_test_client.get(
'/api/v1/monitoring/overview',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestMonitoringMessagesEndpoint:
"""Tests for /api/v1/monitoring/messages endpoint."""
@pytest.mark.asyncio
async def test_get_messages_success(self, quart_test_client):
"""GET /api/v1/monitoring/messages returns message list."""
response = await quart_test_client.get(
'/api/v1/monitoring/messages',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'messages' in data['data']
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestMonitoringLLMCallsEndpoint:
"""Tests for /api/v1/monitoring/llm-calls endpoint."""
@pytest.mark.asyncio
async def test_get_llm_calls_success(self, quart_test_client):
"""GET /api/v1/monitoring/llm-calls."""
response = await quart_test_client.get(
'/api/v1/monitoring/llm-calls',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestMonitoringEmbeddingCallsEndpoint:
"""Tests for /api/v1/monitoring/embedding-calls endpoint."""
@pytest.mark.asyncio
async def test_get_embedding_calls_success(self, quart_test_client):
"""GET /api/v1/monitoring/embedding-calls."""
response = await quart_test_client.get(
'/api/v1/monitoring/embedding-calls',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestMonitoringSessionsEndpoint:
"""Tests for /api/v1/monitoring/sessions endpoint."""
@pytest.mark.asyncio
async def test_get_sessions_success(self, quart_test_client):
"""GET /api/v1/monitoring/sessions."""
response = await quart_test_client.get(
'/api/v1/monitoring/sessions',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestMonitoringErrorsEndpoint:
"""Tests for /api/v1/monitoring/errors endpoint."""
@pytest.mark.asyncio
async def test_get_errors_success(self, quart_test_client):
"""GET /api/v1/monitoring/errors."""
response = await quart_test_client.get(
'/api/v1/monitoring/errors',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestMonitoringAllDataEndpoint:
"""Tests for /api/v1/monitoring/data endpoint."""
@pytest.mark.asyncio
async def test_get_all_data_success(self, quart_test_client):
"""GET /api/v1/monitoring/data returns all data."""
response = await quart_test_client.get(
'/api/v1/monitoring/data',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert 'overview' in data['data']
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestMonitoringDetailsEndpoints:
"""Tests for detail endpoints."""
@pytest.mark.asyncio
async def test_get_session_analysis(self, quart_test_client):
"""GET /api/v1/monitoring/sessions/{id}/analysis."""
response = await quart_test_client.get(
'/api/v1/monitoring/sessions/sess-1/analysis',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.asyncio
async def test_get_message_details(self, quart_test_client):
"""GET /api/v1/monitoring/messages/{id}/details."""
response = await quart_test_client.get(
'/api/v1/monitoring/messages/msg-1/details',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestMonitoringFeedbackEndpoints:
"""Tests for feedback endpoints."""
@pytest.mark.asyncio
async def test_get_feedback_stats(self, quart_test_client):
"""GET /api/v1/monitoring/feedback/stats."""
response = await quart_test_client.get(
'/api/v1/monitoring/feedback/stats',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.asyncio
async def test_get_feedback_list(self, quart_test_client):
"""GET /api/v1/monitoring/feedback."""
response = await quart_test_client.get(
'/api/v1/monitoring/feedback',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestMonitoringExportEndpoint:
"""Tests for /api/v1/monitoring/export endpoint."""
@pytest.mark.asyncio
async def test_export_messages(self, quart_test_client):
"""GET export?type=messages returns CSV."""
response = await quart_test_client.get(
'/api/v1/monitoring/export?type=messages',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
assert 'text/csv' in response.content_type
@pytest.mark.asyncio
async def test_export_llm_calls(self, quart_test_client):
"""GET export?type=llm-calls returns CSV."""
response = await quart_test_client.get(
'/api/v1/monitoring/export?type=llm-calls',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.asyncio
async def test_export_sessions(self, quart_test_client):
"""GET export?type=sessions returns CSV."""
response = await quart_test_client.get(
'/api/v1/monitoring/export?type=sessions',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.asyncio
async def test_export_feedback(self, quart_test_client):
"""GET export?type=feedback returns CSV."""
response = await quart_test_client.get(
'/api/v1/monitoring/export?type=feedback',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200

View File

@@ -1,273 +0,0 @@
"""
API integration tests for pipeline endpoints.
Tests real HTTP API behavior using Quart test client with mocked services.
Extends test_smoke.py coverage for pipeline-related endpoints.
Run: uv run pytest tests/integration/api/test_pipelines.py -q
"""
from __future__ import annotations
import pytest
from unittest.mock import MagicMock, AsyncMock, Mock
from tests.factories import FakeApp
pytestmark = pytest.mark.integration
# ============== FIXTURE FOR SYS.MODULES ISOLATION ==============
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""Break circular import chain for API controller."""
from tests.utils.import_isolation import isolated_sys_modules, MockLifecycleControlScope
class FakeMinimalApplication:
pass
mock_app = MagicMock()
mock_app.Application = FakeMinimalApplication
mock_entities = MagicMock()
mock_entities.LifecycleControlScope = MockLifecycleControlScope
clear = [
'langbot.pkg.api.http.controller.group',
'langbot.pkg.api.http.controller.groups',
'langbot.pkg.api.http.controller.groups.pipelines',
'langbot.pkg.api.http.controller.groups.pipelines.pipelines',
'langbot.pkg.api.http.controller.groups.pipelines.embed',
'langbot.pkg.api.http.controller.groups.pipelines.websocket_chat',
'langbot.pkg.api.http.controller.main',
]
with isolated_sys_modules(
mocks={
'langbot.pkg.core.app': mock_app,
'langbot.pkg.core.entities': mock_entities,
},
clear=clear,
):
# Import groups after mocking to populate preregistered_groups
import langbot.pkg.api.http.controller.groups.pipelines.pipelines as _pipelines # noqa: E402, F401
yield
# ============== FAKE APPLICATION WITH PIPELINE SERVICES ==============
@pytest.fixture(scope='module')
def fake_pipeline_app():
"""Create FakeApp with pipeline-specific services (module scope for reuse)."""
app = FakeApp()
# Pipeline config
app.instance_config.data.update({
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
# Auth services
app.user_service = Mock()
app.user_service.is_initialized = AsyncMock(return_value=True)
app.user_service.verify_jwt_token = AsyncMock(return_value='test@example.com')
app.user_service.get_user_by_email = AsyncMock(return_value=Mock(email='test@example.com'))
app.apikey_service = Mock()
app.apikey_service.verify_api_key = AsyncMock(return_value=True)
# Pipeline service
app.pipeline_service = Mock()
app.pipeline_service.get_pipeline_metadata = AsyncMock(return_value=[
{'name': 'trigger', 'stages': []},
{'name': 'ai', 'stages': []},
])
app.pipeline_service.get_pipelines = AsyncMock(return_value=[
{
'uuid': 'test-pipeline-uuid',
'name': 'Test Pipeline',
'description': 'Test description',
'created_at': '2024-01-01T00:00:00',
'updated_at': '2024-01-01T00:00:00',
'is_default': False,
}
])
app.pipeline_service.get_pipeline = AsyncMock(return_value={
'uuid': 'test-pipeline-uuid',
'name': 'Test Pipeline',
'config': {},
})
app.pipeline_service.create_pipeline = AsyncMock(return_value={'uuid': 'new-pipeline-uuid'})
app.pipeline_service.update_pipeline = AsyncMock(return_value={})
app.pipeline_service.delete_pipeline = AsyncMock()
app.pipeline_service.copy_pipeline = AsyncMock(return_value={'uuid': 'copied-pipeline-uuid'})
# Bot service
app.bot_service = Mock()
app.bot_service.get_bots = AsyncMock(return_value=[])
app.bot_service.create_bot = AsyncMock(return_value={'uuid': 'new-bot-uuid'})
# MCP service (for extensions endpoint)
app.mcp_service = Mock()
app.mcp_service.get_mcp_servers = AsyncMock(return_value=[])
# Plugin connector (for extensions endpoint)
app.plugin_connector.list_plugins = AsyncMock(return_value=[])
return app
@pytest.fixture(scope='module')
async def quart_test_client(fake_pipeline_app, http_controller_cls):
"""Create Quart test client (module scope to avoid route re-registration)."""
controller = http_controller_cls(fake_pipeline_app)
await controller.initialize()
client = controller.quart_app.test_client()
yield client
# ============== PIPELINE ENDPOINT TESTS ==============
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestPipelineMetadataEndpoint:
"""Tests for /api/v1/pipelines/_/metadata endpoint."""
@pytest.mark.asyncio
async def test_get_pipeline_metadata_success(self, quart_test_client):
"""GET /api/v1/pipelines/_/metadata returns metadata list."""
response = await quart_test_client.get(
'/api/v1/pipelines/_/metadata',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'data' in data
assert isinstance(data['data'], dict)
@pytest.mark.asyncio
async def test_get_pipeline_metadata_requires_auth(self, quart_test_client):
"""Pipeline metadata endpoint requires authentication."""
response = await quart_test_client.get('/api/v1/pipelines/_/metadata')
assert response.status_code == 401
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestPipelinesListEndpoint:
"""Tests for /api/v1/pipelines endpoint."""
@pytest.mark.asyncio
async def test_get_pipelines_success(self, quart_test_client):
"""GET /api/v1/pipelines returns pipeline list."""
response = await quart_test_client.get(
'/api/v1/pipelines',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'data' in data
@pytest.mark.asyncio
async def test_get_pipelines_with_sort_param(self, quart_test_client):
"""GET pipelines with sort parameter."""
response = await quart_test_client.get(
'/api/v1/pipelines?sort_by=created_at&sort_order=DESC',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestPipelinesCRUDEndpoints:
"""Tests for pipeline CRUD operations."""
@pytest.mark.asyncio
async def test_get_single_pipeline_success(self, quart_test_client):
"""GET /api/v1/pipelines/{uuid} returns pipeline."""
response = await quart_test_client.get(
'/api/v1/pipelines/test-pipeline-uuid',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'data' in data
@pytest.mark.asyncio
async def test_create_pipeline_success(self, quart_test_client):
"""POST /api/v1/pipelines creates new pipeline."""
response = await quart_test_client.post(
'/api/v1/pipelines',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Pipeline', 'config': {}}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'uuid' in data['data']
@pytest.mark.asyncio
async def test_update_pipeline_success(self, quart_test_client):
"""PUT /api/v1/pipelines/{uuid} updates pipeline."""
response = await quart_test_client.put(
'/api/v1/pipelines/test-pipeline-uuid',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'Updated Pipeline'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
@pytest.mark.asyncio
async def test_delete_pipeline_success(self, quart_test_client):
"""DELETE /api/v1/pipelines/{uuid} deletes pipeline."""
response = await quart_test_client.delete(
'/api/v1/pipelines/test-pipeline-uuid',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
@pytest.mark.asyncio
async def test_copy_pipeline_success(self, quart_test_client):
"""POST /api/v1/pipelines/{uuid}/copy copies pipeline."""
response = await quart_test_client.post(
'/api/v1/pipelines/test-pipeline-uuid/copy',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'uuid' in data['data']
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestPipelineExtensionsEndpoint:
"""Tests for pipeline extensions."""
@pytest.mark.asyncio
async def test_get_extensions(self, quart_test_client):
"""GET /api/v1/pipelines/{uuid}/extensions."""
response = await quart_test_client.get(
'/api/v1/pipelines/test-pipeline-uuid/extensions',
headers={'Authorization': 'Bearer test_token'}
)
# Should return 200 if pipeline found
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0

View File

@@ -1,347 +0,0 @@
"""
API integration tests for provider/model endpoints.
Tests real HTTP API behavior for provider and model management.
Run: uv run pytest tests/integration/api/test_providers.py -q
"""
from __future__ import annotations
import pytest
from unittest.mock import MagicMock, AsyncMock, Mock
from tests.factories import FakeApp
pytestmark = pytest.mark.integration
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""Break circular import chain for API controller."""
from tests.utils.import_isolation import isolated_sys_modules, MockLifecycleControlScope
class FakeMinimalApplication:
pass
mock_app = MagicMock()
mock_app.Application = FakeMinimalApplication
mock_entities = MagicMock()
mock_entities.LifecycleControlScope = MockLifecycleControlScope
clear = [
'langbot.pkg.api.http.controller.group',
'langbot.pkg.api.http.controller.groups',
'langbot.pkg.api.http.controller.groups.provider',
'langbot.pkg.api.http.controller.groups.provider.providers',
'langbot.pkg.api.http.controller.groups.provider.models',
'langbot.pkg.api.http.controller.main',
]
with isolated_sys_modules(
mocks={
'langbot.pkg.core.app': mock_app,
'langbot.pkg.core.entities': mock_entities,
},
clear=clear,
):
import langbot.pkg.api.http.controller.groups.provider.providers as _providers # noqa: E402, F401
import langbot.pkg.api.http.controller.groups.provider.models as _models # noqa: E402, F401
yield
@pytest.fixture(scope='module')
def fake_provider_app():
"""Create FakeApp with provider/model services (module scope for reuse)."""
app = FakeApp()
app.instance_config.data.update({
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
# Auth services
app.user_service = Mock()
app.user_service.is_initialized = AsyncMock(return_value=True)
app.user_service.verify_jwt_token = AsyncMock(return_value='test@example.com')
app.user_service.get_user_by_email = AsyncMock(return_value=Mock(email='test@example.com'))
app.apikey_service = Mock()
app.apikey_service.verify_api_key = AsyncMock(return_value=True)
# Provider service
app.provider_service = Mock()
app.provider_service.get_providers = AsyncMock(return_value=[
{'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'}
])
app.provider_service.get_provider = AsyncMock(return_value={
'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'
})
app.provider_service.create_provider = AsyncMock(return_value='new-provider-uuid')
app.provider_service.update_provider = AsyncMock(return_value={})
app.provider_service.delete_provider = AsyncMock()
app.provider_service.get_provider_model_counts = AsyncMock(return_value={
'llm_count': 2, 'embedding_count': 1, 'rerank_count': 0
})
# LLM model service
app.llm_model_service = Mock()
app.llm_model_service.get_llm_models = AsyncMock(return_value=[
{'uuid': 'test-model-uuid', 'name': 'gpt-4'}
])
app.llm_model_service.get_llm_model = AsyncMock(return_value={
'uuid': 'test-model-uuid', 'name': 'gpt-4'
})
app.llm_model_service.create_llm_model = AsyncMock(return_value={'uuid': 'new-model-uuid'})
app.llm_model_service.update_llm_model = AsyncMock(return_value={})
app.llm_model_service.delete_llm_model = AsyncMock()
# Embedding model service
app.embedding_models_service = Mock()
app.embedding_models_service.get_embedding_models = AsyncMock(return_value=[])
app.embedding_models_service.create_embedding_model = AsyncMock(return_value={'uuid': 'new-embedding-uuid'})
# Rerank model service
app.rerank_models_service = Mock()
app.rerank_models_service.get_rerank_models = AsyncMock(return_value=[])
app.rerank_models_service.create_rerank_model = AsyncMock(return_value={'uuid': 'new-rerank-uuid'})
# Model manager
app.model_mgr = Mock()
app.model_mgr.load_provider = AsyncMock()
app.model_mgr.unload_provider = AsyncMock()
return app
@pytest.fixture(scope='module')
async def quart_test_client(fake_provider_app, http_controller_cls):
"""Create Quart test client (module scope to avoid route re-registration)."""
controller = http_controller_cls(fake_provider_app)
await controller.initialize()
client = controller.quart_app.test_client()
yield client
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestProviderEndpoints:
"""Tests for /api/v1/provider endpoints."""
@pytest.mark.asyncio
async def test_get_providers_success(self, quart_test_client):
"""GET /api/v1/provider/providers returns provider list with complete structure."""
response = await quart_test_client.get(
'/api/v1/provider/providers',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'data' in data
# Verify response structure completeness
providers = data['data']['providers']
assert isinstance(providers, list)
assert len(providers) == 1
# Verify required fields in provider object
provider = providers[0]
assert 'uuid' in provider
assert 'name' in provider
assert 'requester' in provider
assert provider['uuid'] == 'test-provider-uuid'
assert provider['name'] == 'OpenAI'
@pytest.mark.asyncio
async def test_get_single_provider_success(self, quart_test_client):
"""GET /api/v1/provider/providers/{uuid} returns complete provider structure."""
response = await quart_test_client.get(
'/api/v1/provider/providers/test-provider-uuid',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
# Verify response structure
provider = data['data']['provider']
assert 'uuid' in provider
assert 'name' in provider
assert 'requester' in provider
assert provider['uuid'] == 'test-provider-uuid'
@pytest.mark.asyncio
async def test_create_provider_success(self, quart_test_client):
"""POST /api/v1/provider/providers creates new provider with uuid returned."""
response = await quart_test_client.post(
'/api/v1/provider/providers',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Provider', 'requester': 'chatcmpl'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
# Verify uuid is present and matches expected
assert 'data' in data
assert 'uuid' in data['data']
assert data['data']['uuid'] == 'new-provider-uuid'
@pytest.mark.asyncio
async def test_update_provider_success(self, quart_test_client):
"""PUT /api/v1/provider/providers/{uuid} updates provider."""
response = await quart_test_client.put(
'/api/v1/provider/providers/test-provider-uuid',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'Updated Provider'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
@pytest.mark.asyncio
async def test_delete_provider_success(self, quart_test_client):
"""DELETE /api/v1/provider/providers/{uuid} deletes provider."""
response = await quart_test_client.delete(
'/api/v1/provider/providers/test-provider-uuid',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.asyncio
async def test_get_provider_includes_model_counts(self, quart_test_client):
"""GET provider response includes model counts."""
response = await quart_test_client.get(
'/api/v1/provider/providers/test-provider-uuid',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
# Model counts are embedded in provider response
provider_data = data['data']['provider']
assert 'llm_count' in provider_data
assert 'embedding_count' in provider_data
assert 'rerank_count' in provider_data
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestModelEndpoints:
"""Tests for /api/v1/provider/models endpoints."""
@pytest.mark.asyncio
async def test_get_llm_models_success(self, quart_test_client):
"""GET /api/v1/provider/models/llm returns model list."""
response = await quart_test_client.get(
'/api/v1/provider/models/llm',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'data' in data
@pytest.mark.asyncio
async def test_get_single_llm_model_success(self, quart_test_client):
"""GET /api/v1/provider/models/llm/{uuid} returns model."""
response = await quart_test_client.get(
'/api/v1/provider/models/llm/test-model-uuid',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
@pytest.mark.asyncio
async def test_create_llm_model_success(self, quart_test_client):
"""POST /api/v1/provider/models/llm creates new model."""
response = await quart_test_client.post(
'/api/v1/provider/models/llm',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Model', 'provider_uuid': 'test-provider-uuid'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'uuid' in data['data']
@pytest.mark.asyncio
async def test_delete_llm_model_success(self, quart_test_client):
"""DELETE /api/v1/provider/models/llm/{uuid} deletes model."""
response = await quart_test_client.delete(
'/api/v1/provider/models/llm/test-model-uuid',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestEmbeddingModelEndpoints:
"""Tests for /api/v1/provider/models/embedding endpoints."""
@pytest.mark.asyncio
async def test_get_embedding_models_success(self, quart_test_client):
"""GET /api/v1/provider/models/embedding returns model list."""
response = await quart_test_client.get(
'/api/v1/provider/models/embedding',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'models' in data['data']
@pytest.mark.asyncio
async def test_create_embedding_model_success(self, quart_test_client):
"""POST /api/v1/provider/models/embedding creates new model."""
response = await quart_test_client.post(
'/api/v1/provider/models/embedding',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Embedding Model', 'provider_uuid': 'test-provider-uuid'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'uuid' in data['data']
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestRerankModelEndpoints:
"""Tests for /api/v1/provider/models/rerank endpoints."""
@pytest.mark.asyncio
async def test_get_rerank_models_success(self, quart_test_client):
"""GET /api/v1/provider/models/rerank returns model list."""
response = await quart_test_client.get(
'/api/v1/provider/models/rerank',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'models' in data['data']
@pytest.mark.asyncio
async def test_create_rerank_model_success(self, quart_test_client):
"""POST /api/v1/provider/models/rerank creates new model."""
response = await quart_test_client.post(
'/api/v1/provider/models/rerank',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Rerank Model', 'provider_uuid': 'test-provider-uuid'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'uuid' in data['data']

View File

@@ -1,345 +0,0 @@
"""
API smoke integration tests.
Tests real HTTP API behavior using Quart test client.
Validates controller/service/routing wiring without real provider/platform.
Run: uv run pytest tests/integration/api/test_smoke.py -q
"""
from __future__ import annotations
import pytest
from unittest.mock import MagicMock, AsyncMock, Mock
from tests.factories import FakeApp
pytestmark = pytest.mark.integration
# ============== FIXTURE FOR SYS.MODULES ISOLATION ==============
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""
Break circular import chain for API controller using isolated_sys_modules.
Chain: http_controller → groups/plugins → core.app → pipeline entities
We need to mock core.app to prevent the circular chain when importing HTTPController.
But we must allow groups to be imported to populate preregistered_groups.
"""
from tests.utils.import_isolation import isolated_sys_modules, MockLifecycleControlScope
# Mock core.app with minimal Application that groups can reference
class FakeMinimalApplication:
pass
mock_app = MagicMock()
mock_app.Application = FakeMinimalApplication
# Mock core.entities with proper Enum
mock_entities = MagicMock()
mock_entities.LifecycleControlScope = MockLifecycleControlScope
# Modules to clear (force re-import after mocking)
clear = [
'langbot.pkg.api.http.controller.group',
'langbot.pkg.api.http.controller.groups',
'langbot.pkg.api.http.controller.groups.system',
'langbot.pkg.api.http.controller.groups.user',
'langbot.pkg.api.http.controller.main',
]
with isolated_sys_modules(
mocks={
'langbot.pkg.core.app': mock_app,
'langbot.pkg.core.entities': mock_entities,
},
clear=clear,
):
# Import groups after mocking core.app/core.entities
import langbot.pkg.api.http.controller.group as _group_module # noqa: E402, F401
import langbot.pkg.api.http.controller.groups.system as _system_group # noqa: E402, F401
import langbot.pkg.api.http.controller.groups.user as _user_group # noqa: E402, F401
yield
# ============== FAKE APPLICATION FOR API TESTS ==============
@pytest.fixture
def fake_api_app():
"""
Create minimal FakeApp for API smoke tests with all required services.
Uses tests.factories.FakeApp as base and adds API-specific services.
"""
app = FakeApp()
# API-specific config
app.instance_config.data.update({
'api': {'port': 5300},
'plugin': {'enable_marketplace': True},
'space': {'url': 'https://space.langbot.app'},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
# API-specific services
app.user_service = Mock()
app.user_service.is_initialized = AsyncMock(return_value=False)
app.user_service.authenticate = AsyncMock(return_value='fake_token')
app.user_service.create_user = AsyncMock()
app.user_service.verify_jwt_token = AsyncMock(side_effect=ValueError('Invalid token'))
app.user_service.get_user_by_email = AsyncMock(return_value=Mock())
app.user_service.generate_jwt_token = AsyncMock(return_value='fake_token')
app.apikey_service = Mock()
app.apikey_service.verify_api_key = AsyncMock(return_value=True)
app.maintenance_service = Mock()
app.maintenance_service.get_storage_analysis = AsyncMock(return_value={})
app.plugin_connector.is_enable_plugin = False
app.plugin_connector.ping_plugin_runtime = AsyncMock()
app.task_mgr.get_tasks_dict = Mock(return_value={'tasks': []})
app.task_mgr.get_task_by_id = Mock(return_value=None)
# Required by controller groups
app.model_mgr = Mock()
app.platform_mgr = Mock()
app.pipeline_pool = Mock()
app.pipeline_mgr = Mock()
return app
# ============== QUART TEST CLIENT FIXTURE ==============
@pytest.fixture
async def quart_test_client(fake_api_app, http_controller_cls):
"""
Create Quart test client with real HTTPController and route registration.
Requires mock_circular_import_chain fixture to run first (usefixtures).
"""
controller = http_controller_cls(fake_api_app)
await controller.initialize()
client = controller.quart_app.test_client()
yield client
# ============== API SMOKE TESTS ==============
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestHealthEndpoint:
"""Tests for /healthz endpoint - simplest smoke test."""
@pytest.mark.asyncio
async def test_healthz_returns_ok(self, quart_test_client):
"""
/healthz endpoint returns {'code': 0, 'msg': 'ok'}.
This tests:
- HTTPController instantiation
- Quart app creation
- Route registration
- Basic response handling
"""
response = await quart_test_client.get('/healthz')
assert response.status_code == 200
data = await response.get_json()
assert data == {'code': 0, 'msg': 'ok'}
@pytest.mark.asyncio
async def test_healthz_no_auth_required(self, quart_test_client):
"""
/healthz doesn't require authentication.
Tests that AuthType.NONE endpoints work without headers.
"""
response = await quart_test_client.get('/healthz')
assert response.status_code == 200
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestSystemEndpoint:
"""Tests for /api/v1/system endpoints."""
@pytest.mark.asyncio
async def test_system_info_no_auth(self, quart_test_client):
"""
/api/v1/system/info returns system information without auth.
AuthType.NONE endpoint.
"""
response = await quart_test_client.get('/api/v1/system/info')
assert response.status_code == 200
data = await response.get_json()
# Verify response structure
assert data['code'] == 0
assert data['msg'] == 'ok'
assert 'data' in data
# Verify expected fields
system_data = data['data']
assert 'version' in system_data
assert 'debug' in system_data
assert 'edition' in system_data
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestProtectedEndpoints:
"""Tests for authentication/authorization behavior."""
@pytest.mark.asyncio
async def test_protected_endpoint_rejects_no_token(self, quart_test_client):
"""
Protected endpoint (USER_TOKEN) returns 401 without auth.
Tests that AuthType.USER_TOKEN properly rejects unauthorized requests.
"""
# /api/v1/user/check-token requires USER_TOKEN
response = await quart_test_client.get('/api/v1/user/check-token')
assert response.status_code == 401
data = await response.get_json()
# Verify error response structure
assert data['code'] == -1
assert 'msg' in data
@pytest.mark.asyncio
async def test_protected_endpoint_with_invalid_token(self, quart_test_client):
"""
Protected endpoint returns 401 with invalid token.
"""
response = await quart_test_client.get(
'/api/v1/user/check-token',
headers={'Authorization': 'Bearer invalid_token'}
)
assert response.status_code == 401
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestInvalidPayload:
"""Tests for error handling with invalid payloads."""
@pytest.mark.asyncio
async def test_missing_json_body(self, quart_test_client):
"""
POST endpoint without JSON body handles gracefully.
"""
# /api/v1/user/auth expects JSON with 'user' and 'password'
response = await quart_test_client.post('/api/v1/user/auth')
# Should return error (500, 400, or 401) with stable JSON structure
assert response.status_code in (400, 500, 401)
data = await response.get_json()
# Verify error response has expected structure
assert 'code' in data
assert 'msg' in data
@pytest.mark.asyncio
async def test_invalid_json_structure(self, quart_test_client):
"""
POST with wrong JSON structure returns stable error.
"""
response = await quart_test_client.post(
'/api/v1/user/auth',
json={'wrong_field': 'value'}
)
# Should return error with stable JSON structure
assert response.status_code in (400, 500, 401)
data = await response.get_json()
assert 'code' in data
assert 'msg' in data
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestUserInitEndpoint:
"""Tests for /api/v1/user/init endpoint."""
@pytest.mark.asyncio
async def test_user_init_get_returns_not_initialized(self, quart_test_client):
"""
GET /api/v1/user/init returns initialized status.
Uses fake user_service.is_initialized() = False.
"""
response = await quart_test_client.get('/api/v1/user/init')
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert data['msg'] == 'ok'
assert data['data']['initialized'] is False
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestRealImports:
"""Tests that verify real production code is imported."""
def test_http_controller_real_import(self):
"""
Verify HTTPController is real production class, not mock.
"""
from langbot.pkg.api.http.controller.main import HTTPController
assert HTTPController.__name__ == 'HTTPController'
assert hasattr(HTTPController, 'initialize')
assert hasattr(HTTPController, 'register_routes')
def test_group_real_import(self):
"""
Verify RouterGroup and AuthType are real production classes.
"""
from langbot.pkg.api.http.controller.group import RouterGroup, AuthType, preregistered_groups
assert RouterGroup.__name__ == 'RouterGroup'
assert hasattr(AuthType, 'NONE')
assert hasattr(AuthType, 'USER_TOKEN')
assert isinstance(preregistered_groups, list)
def test_system_group_registered(self):
"""
Verify SystemRouterGroup is registered in preregistered_groups.
"""
from langbot.pkg.api.http.controller.group import preregistered_groups
# Find system group
system_group = None
for g in preregistered_groups:
if g.name == 'system':
system_group = g
break
assert system_group is not None
assert system_group.path == '/api/v1/system'
def test_user_group_registered(self):
"""
Verify UserRouterGroup is registered in preregistered_groups.
"""
from langbot.pkg.api.http.controller.group import preregistered_groups
# Find user group
user_group = None
for g in preregistered_groups:
if g.name == 'user':
user_group = g
break
assert user_group is not None
assert user_group.path == '/api/v1/user'

View File

@@ -1,5 +0,0 @@
"""
Persistence integration tests package.
Tests for database migrations and storage behavior.
"""

View File

@@ -1,251 +0,0 @@
"""
SQLite migration integration tests.
Tests real Alembic migration behavior using temporary SQLite databases.
Validates the migration workflow from .github/workflows/test-migrations.yml.
Run: uv run pytest tests/integration/persistence/test_migrations.py -q
"""
from __future__ import annotations
import pytest
from sqlalchemy.ext.asyncio import create_async_engine
from langbot.pkg.entity.persistence.base import Base
from langbot.pkg.persistence.alembic_runner import (
run_alembic_upgrade,
run_alembic_stamp,
get_alembic_current,
)
pytestmark = pytest.mark.integration
@pytest.fixture
def sqlite_db_url(tmp_path):
"""Create SQLite URL with temporary database file."""
db_file = tmp_path / "test_migrations.db"
return f"sqlite+aiosqlite:///{db_file}"
@pytest.fixture
async def sqlite_engine(sqlite_db_url):
"""Create async SQLite engine."""
engine = create_async_engine(sqlite_db_url)
yield engine
await engine.dispose()
class TestSQLiteMigrationBaseline:
"""Tests for baseline stamp workflow."""
@pytest.mark.asyncio
async def test_baseline_stamp_sets_revision(self, sqlite_engine):
"""
Stamp baseline on existing tables sets correct revision.
Workflow:
1. Create tables via Base.metadata.create_all
2. Stamp with '0001_baseline'
3. Verify current revision is '0001_baseline'
"""
# Create all tables (simulates existing DB created by ORM)
async with sqlite_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Stamp baseline
await run_alembic_stamp(sqlite_engine, '0001_baseline')
# Verify revision
rev = await get_alembic_current(sqlite_engine)
assert rev == '0001_baseline', f"Expected '0001_baseline', got {rev}"
@pytest.mark.asyncio
async def test_baseline_stamp_on_empty_db(self, sqlite_engine):
"""
Stamp on empty database (no tables) still sets revision.
This is an edge case - stamping without tables.
"""
# Don't create tables - stamp directly
await run_alembic_stamp(sqlite_engine, '0001_baseline')
rev = await get_alembic_current(sqlite_engine)
assert rev == '0001_baseline'
class TestSQLiteMigrationUpgrade:
"""Tests for upgrade to head workflow."""
@pytest.mark.asyncio
async def test_upgrade_from_baseline_to_head(self, sqlite_engine):
"""
Upgrade from baseline to head applies all migrations.
Workflow:
1. Create tables
2. Stamp baseline
3. Upgrade to head
4. Verify current revision is head
"""
# Create tables
async with sqlite_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Stamp baseline
await run_alembic_stamp(sqlite_engine, '0001_baseline')
# Upgrade to head
await run_alembic_upgrade(sqlite_engine, 'head')
# Verify revision
rev = await get_alembic_current(sqlite_engine)
assert rev is not None, "Expected a revision after upgrade"
# Head should be the latest migration
assert rev.startswith('0003'), f"Expected head to be 0003_*, got {rev}"
@pytest.mark.asyncio
async def test_upgrade_idempotent(self, sqlite_engine):
"""
Running upgrade to head multiple times is idempotent.
Workflow:
1. Upgrade to head
2. Get revision
3. Upgrade to head again
4. Verify same revision
"""
# Create tables
async with sqlite_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Stamp and upgrade
await run_alembic_stamp(sqlite_engine, '0001_baseline')
await run_alembic_upgrade(sqlite_engine, 'head')
rev1 = await get_alembic_current(sqlite_engine)
# Upgrade again - should be idempotent
await run_alembic_upgrade(sqlite_engine, 'head')
rev2 = await get_alembic_current(sqlite_engine)
assert rev2 == rev1, f"Expected {rev1}, got {rev2}"
class TestSQLiteMigrationFreshDatabase:
"""Tests for fresh database workflow."""
@pytest.mark.asyncio
async def test_fresh_db_upgrade_from_scratch(self, tmp_path):
"""
Fresh database (no tables) can be upgraded directly to head.
Workflow:
1. Create fresh engine with new DB file
2. Create tables
3. Upgrade to head
4. Verify revision
"""
# Use different DB file for fresh test
fresh_db_file = tmp_path / "test_migrations_fresh.db"
fresh_url = f"sqlite+aiosqlite:///{fresh_db_file}"
fresh_engine = create_async_engine(fresh_url)
# Create tables on fresh DB
async with fresh_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Upgrade to head directly (no baseline stamp)
await run_alembic_upgrade(fresh_engine, 'head')
# Verify revision
rev = await get_alembic_current(fresh_engine)
assert rev is not None, "Expected a revision on fresh DB"
await fresh_engine.dispose()
@pytest.mark.asyncio
async def test_fresh_db_without_create_all_behavior(self, tmp_path):
"""
Fresh database without create_all - test actual behavior.
This tests what happens when migrations run on truly empty DB.
The behavior is determined by Alembic and migration scripts.
EXPECTED: Either:
1. Migration succeeds (if scripts handle empty DB)
2. Migration fails with specific error (if scripts require tables)
IMPORTANT: This test verifies the ACTUAL behavior, not accepting
any arbitrary failure with try-except pass.
"""
fresh_db_file = tmp_path / "test_empty_migrations.db"
fresh_url = f"sqlite+aiosqlite:///{fresh_db_file}"
fresh_engine = create_async_engine(fresh_url)
# Capture the actual behavior
actual_result = None
actual_error = None
try:
await run_alembic_upgrade(fresh_engine, 'head')
rev = await get_alembic_current(fresh_engine)
actual_result = rev
except Exception as e:
actual_error = e
await fresh_engine.dispose()
# Verify specific behavior - one of two outcomes is expected
if actual_result is not None:
# Migration succeeded - verify revision exists
assert actual_result is not None, "Revision should exist after successful migration"
else:
# Migration failed - verify the error type is known
# Alembic typically raises specific errors for missing tables
assert actual_error is not None, "Error should be captured if migration failed"
# Log the error type for documentation (don't silently pass)
error_type = type(actual_error).__name__
# Acceptable error types for empty DB scenarios
acceptable_errors = [
'OperationalError', # SQLite table not found
'ProgrammingError', # SQLAlchemy errors
'CommandError', # Alembic command errors
]
assert error_type in acceptable_errors, (
f"Unexpected error type: {error_type}. "
f"This may indicate a regression in migration behavior. "
f"Error: {actual_error}"
)
class TestSQLiteMigrationGetCurrent:
"""Tests for get_alembic_current behavior."""
@pytest.mark.asyncio
async def test_get_current_on_unstamped_db_returns_none(self, sqlite_engine):
"""
get_alembic_current returns None for unstamped database.
"""
# Create tables but don't stamp
async with sqlite_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# No stamp - should return None
rev = await get_alembic_current(sqlite_engine)
assert rev is None, f"Expected None for unstamped DB, got {rev}"
@pytest.mark.asyncio
async def test_get_current_after_stamp_returns_revision(self, sqlite_engine):
"""
get_alembic_current returns correct revision after stamp.
"""
async with sqlite_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
await run_alembic_stamp(sqlite_engine, '0001_baseline')
rev = await get_alembic_current(sqlite_engine)
assert rev == '0001_baseline'

View File

@@ -1,217 +0,0 @@
"""
PostgreSQL migration integration tests.
Tests real Alembic migration behavior using PostgreSQL database.
Marked as slow - requires external PostgreSQL service.
Run locally (requires PostgreSQL):
TEST_POSTGRES_URL=postgresql+asyncpg://user:pass@localhost:5432/test_db \
uv run pytest tests/integration/persistence/test_migrations_postgres.py -q
CI runs automatically with PostgreSQL service container.
"""
from __future__ import annotations
import os
import pytest
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy import text
from langbot.pkg.entity.persistence.base import Base
from langbot.pkg.persistence.alembic_runner import (
run_alembic_upgrade,
run_alembic_stamp,
get_alembic_current,
)
pytestmark = [pytest.mark.integration, pytest.mark.slow]
@pytest.fixture
def postgres_url():
"""Get PostgreSQL URL from environment."""
url = os.environ.get('TEST_POSTGRES_URL')
if not url:
pytest.skip("TEST_POSTGRES_URL not set")
return url
@pytest.fixture
async def postgres_engine(postgres_url):
"""Create async PostgreSQL engine."""
engine = create_async_engine(postgres_url, isolation_level="AUTOCOMMIT")
yield engine
await engine.dispose()
@pytest.fixture
async def clean_tables(postgres_engine):
"""Drop all tables before and after each test for isolation."""
# Drop all tables before test
async with postgres_engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
yield
# Drop all tables after test
async with postgres_engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
@pytest.fixture
async def clean_alembic_version(postgres_engine):
"""Drop alembic_version table before and after each test."""
async with postgres_engine.begin() as conn:
# Drop alembic_version table if exists
try:
await conn.execute(text("DROP TABLE IF EXISTS alembic_version"))
except Exception:
pass
yield
async with postgres_engine.begin() as conn:
try:
await conn.execute(text("DROP TABLE IF EXISTS alembic_version"))
except Exception:
pass
class TestPostgreSQLMigrationBaseline:
"""Tests for baseline stamp workflow on PostgreSQL."""
@pytest.mark.asyncio
async def test_postgres_baseline_stamp_sets_revision(
self, postgres_engine, clean_tables, clean_alembic_version
):
"""
Stamp baseline on existing tables sets correct revision.
Workflow:
1. Create tables via Base.metadata.create_all
2. Stamp with '0001_baseline'
3. Verify current revision is '0001_baseline'
"""
# Create all tables (simulates existing DB created by ORM)
async with postgres_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Stamp baseline
await run_alembic_stamp(postgres_engine, '0001_baseline')
# Verify revision
rev = await get_alembic_current(postgres_engine)
assert rev == '0001_baseline', f"Expected '0001_baseline', got {rev}"
@pytest.mark.asyncio
async def test_postgres_baseline_stamp_on_empty_db(
self, postgres_engine, clean_tables, clean_alembic_version
):
"""
Stamp on empty database (no tables) still sets revision.
This is an edge case - stamping without tables.
"""
# Don't create tables - stamp directly
await run_alembic_stamp(postgres_engine, '0001_baseline')
rev = await get_alembic_current(postgres_engine)
assert rev == '0001_baseline'
class TestPostgreSQLMigrationUpgrade:
"""Tests for upgrade to head workflow on PostgreSQL."""
@pytest.mark.asyncio
async def test_postgres_upgrade_from_baseline_to_head(
self, postgres_engine, clean_tables, clean_alembic_version
):
"""
Upgrade from baseline to head applies all migrations.
Workflow:
1. Create tables
2. Stamp baseline
3. Upgrade to head
4. Verify current revision is head
"""
# Create tables
async with postgres_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Stamp baseline
await run_alembic_stamp(postgres_engine, '0001_baseline')
# Upgrade to head
await run_alembic_upgrade(postgres_engine, 'head')
# Verify revision
rev = await get_alembic_current(postgres_engine)
assert rev is not None, "Expected a revision after upgrade"
# Head should be the latest migration (0003 for current state)
assert rev.startswith('0003'), f"Expected head to be 0003_*, got {rev}"
@pytest.mark.asyncio
async def test_postgres_upgrade_idempotent(
self, postgres_engine, clean_tables, clean_alembic_version
):
"""
Running upgrade to head multiple times is idempotent.
Workflow:
1. Upgrade to head
2. Get revision
3. Upgrade to head again
4. Verify same revision
"""
# Create tables
async with postgres_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Stamp and upgrade
await run_alembic_stamp(postgres_engine, '0001_baseline')
await run_alembic_upgrade(postgres_engine, 'head')
rev1 = await get_alembic_current(postgres_engine)
# Upgrade again - should be idempotent
await run_alembic_upgrade(postgres_engine, 'head')
rev2 = await get_alembic_current(postgres_engine)
assert rev2 == rev1, f"Expected {rev1}, got {rev2}"
class TestPostgreSQLMigrationGetCurrent:
"""Tests for get_alembic_current behavior on PostgreSQL."""
@pytest.mark.asyncio
async def test_postgres_get_current_on_unstamped_db_returns_none(
self, postgres_engine, clean_tables, clean_alembic_version
):
"""
get_alembic_current returns None for unstamped database.
"""
# Create tables but don't stamp
async with postgres_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# No stamp - should return None
rev = await get_alembic_current(postgres_engine)
assert rev is None, f"Expected None for unstamped DB, got {rev}"
@pytest.mark.asyncio
async def test_postgres_get_current_after_stamp_returns_revision(
self, postgres_engine, clean_tables, clean_alembic_version
):
"""
get_alembic_current returns correct revision after stamp.
"""
async with postgres_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
await run_alembic_stamp(postgres_engine, '0001_baseline')
rev = await get_alembic_current(postgres_engine)
assert rev == '0001_baseline'

View File

@@ -1,5 +0,0 @@
"""
Pipeline integration tests package.
Tests for full pipeline flow using fake provider/runner.
"""

View File

@@ -1,778 +0,0 @@
"""
Pipeline full-flow integration tests.
Tests real pipeline stages with fake runner/provider.
Validates message processing through PreProcessor, Processor, and SendResponseBackStage.
Uses RuntimePipeline directly (not PipelineManager) to avoid DB dependency.
Run: uv run pytest tests/integration/pipeline -q --tb=short
"""
from __future__ import annotations
import pytest
import asyncio
from unittest.mock import AsyncMock, Mock
import sys
from tests.factories import FakeApp, text_query, mock_platform_adapter
from tests.factories.provider import FakeProvider
from tests.factories.platform import FakePlatform
pytestmark = pytest.mark.integration
# ============== FIXTURE FOR SYS.MODULES ISOLATION ==============
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""
Break circular import chain for pipeline modules using isolated_sys_modules.
Chain: pipeline → core.app → provider.runner → http_controller → groups/plugins
We mock minimal modules to allow importing RuntimePipeline, StageInstContainer,
and stage classes without triggering full application initialization.
After mocking, we import the stage modules so decorators register them.
"""
from tests.utils.import_isolation import isolated_sys_modules, MockLifecycleControlScope
# Mock core.entities with LifecycleControlScope enum
mock_core_entities = Mock()
mock_core_entities.LifecycleControlScope = MockLifecycleControlScope
# Mock core.app - Application class is referenced but not instantiated
mock_core_app = Mock()
# Mock provider.runner with preregistered_runners list
mock_runner = Mock()
mock_runner.preregistered_runners = [] # Will be populated in tests
# Mock utils.importutil - prevents auto-import of runners
mock_importutil = Mock()
mock_importutil.import_modules_in_pkg = lambda pkg: None
mock_importutil.import_modules_in_pkgs = lambda pkgs: None
# Modules to clear (force re-import after mocking)
clear = [
'langbot.pkg.pipeline.stage',
'langbot.pkg.pipeline.entities',
'langbot.pkg.pipeline.pipelinemgr',
'langbot.pkg.pipeline.preproc.preproc',
'langbot.pkg.pipeline.process.process',
'langbot.pkg.pipeline.process.handler',
'langbot.pkg.pipeline.process.handlers.chat',
'langbot.pkg.pipeline.process.handlers.command',
'langbot.pkg.pipeline.respback.respback',
'langbot.pkg.provider.runner',
]
with isolated_sys_modules(
mocks={
'langbot.pkg.core.entities': mock_core_entities,
'langbot.pkg.core.app': mock_core_app,
'langbot.pkg.provider.runner': mock_runner,
'langbot.pkg.utils.importutil': mock_importutil,
'langbot.pkg.pipeline.controller': Mock(),
'langbot.pkg.pipeline.pipelinemgr': Mock(),
},
clear=clear,
):
# Import stage modules AFTER clearing so decorators register them
from importlib import import_module
# Import stage base first
import_module('langbot.pkg.pipeline.stage')
# Import entities
import_module('langbot.pkg.pipeline.entities')
# Import specific stages to register them
import_module('langbot.pkg.pipeline.preproc.preproc')
import_module('langbot.pkg.pipeline.process.process')
import_module('langbot.pkg.pipeline.respback.respback')
# Import pipelinemgr for RuntimePipeline
import_module('langbot.pkg.pipeline.pipelinemgr')
yield
# ============== FAKE RUNNER ==============
class FakeRunner:
"""Minimal fake runner class for pipeline integration tests.
Note: preregistered_runners expects a CLASS, not an instance.
The handler calls runner_cls(self.ap, query.pipeline_config) to instantiate.
"""
name = 'local-agent'
def __init__(self, app=None, config=None):
self.app = app
self.config = config or {}
self._provider = FakeProvider()
# Instance-level configuration set via class attribute
self._response_text = "fake response"
self._raise_error = None
@classmethod
def returns(cls, text: str):
"""Create a runner class configured to return specific text."""
# We create a subclass with configured response
class ConfiguredRunner(cls):
name = cls.name
_response_text = text
_raise_error = None
def __init__(self, app=None, config=None):
super().__init__(app, config)
self._response_text = text
return ConfiguredRunner
@classmethod
def raises(cls, error: Exception):
"""Create a runner class configured to raise an error."""
class ConfiguredRunner(cls):
name = cls.name
_response_text = None
_raise_error = error
def __init__(self, app=None, config=None):
super().__init__(app, config)
self._raise_error = error
return ConfiguredRunner
async def run(self, query):
"""Run the fake provider and yield messages."""
from langbot_plugin.api.entities.builtin.provider.message import Message
# Use the configured response/error
if self._raise_error:
raise self._raise_error
# Yield a simple message
yield Message(role='assistant', content=self._response_text)
# ============== PIPELINE APP FIXTURE ==============
@pytest.fixture
def pipeline_app():
"""
Create FakeApp with all dependencies required by pipeline stages.
PreProcessor needs: sess_mgr, model_mgr, tool_mgr, plugin_connector
Processor needs: instance_config, plugin_connector
SendResponseBackStage needs: logger
ChatMessageHandler needs: telemetry, survey
"""
app = FakeApp()
# Session/conversation mocks for PreProcessor
mock_session = Mock()
mock_session.launcher_type = Mock()
mock_session.launcher_type.value = 'person'
mock_session.launcher_id = 12345
mock_session.sender_id = 12345
mock_session.use_prompt_name = 'default'
mock_session.using_conversation = None
# Create a simple class to mimic Prompt behavior
class MockPrompt:
def __init__(self, name, messages):
self.name = name
self.messages = messages
def copy(self):
return MockPrompt(self.name, list(self.messages))
# Create real lists for messages
prompt_messages_list = []
messages_list = []
mock_prompt = MockPrompt('default', prompt_messages_list)
mock_conversation = Mock()
mock_conversation.prompt = mock_prompt
mock_conversation.messages = messages_list
mock_conversation.uuid = 'test-conversation-uuid'
mock_conversation.update_time = None
mock_conversation.create_time = None
app.sess_mgr.get_session = AsyncMock(return_value=mock_session)
app.sess_mgr.get_conversation = AsyncMock(return_value=mock_conversation)
# Model mock for PreProcessor
mock_model = Mock()
mock_model.model_entity = Mock()
mock_model.model_entity.uuid = 'test-model-uuid'
mock_model.model_entity.name = 'test-model'
mock_model.model_entity.abilities = ['func_call', 'vision']
app.model_mgr.get_model_by_uuid = AsyncMock(return_value=mock_model)
# Tool manager mock
app.tool_mgr.get_all_tools = AsyncMock(return_value=[])
# Telemetry mock (required by ChatMessageHandler)
app.telemetry = Mock()
app.telemetry.start_send_task = AsyncMock()
# Survey mock
app.survey = None
return app
@pytest.fixture
def fake_platform_adapter():
"""Create a fake platform adapter for outbound capture."""
platform = FakePlatform(stream_output_supported=False)
adapter = mock_platform_adapter(platform)
return adapter, platform
@pytest.fixture
def set_fake_runner():
"""Factory fixture to set a fake runner CLASS in preregistered_runners."""
def _set_runner(runner_cls):
# preregistered_runners expects a list of runner classes
sys.modules['langbot.pkg.provider.runner'].preregistered_runners = [runner_cls]
return _set_runner
# ============== PIPELINE CONFIGURATION ==============
def create_minimal_pipeline_config():
"""Create minimal pipeline configuration for tests."""
return {
'ai': {
'runner': {'runner': 'local-agent', 'expire-time': None},
'local-agent': {
'model': {'primary': 'test-model-uuid', 'fallbacks': []},
'prompt': 'default',
'knowledge-bases': [],
},
},
'output': {
'force-delay': {'min': 0.0, 'max': 0.0},
'misc': {
'at-sender': False,
'quote-origin': False,
'exception-handling': 'show-hint',
'failure-hint': 'Request failed.',
},
},
'trigger': {
'misc': {'combine-quote-message': False},
},
}
# ============== HELPER TO PROCESS COROUTINE/GENERATOR ==============
async def collect_processor_results(processor, query, stage_name):
"""
Helper to handle the coroutine -> async_generator pattern.
Processor.process() returns a coroutine that yields an async_generator.
This helper handles both cases like RuntimePipeline does.
"""
result = processor.process(query, stage_name)
# Handle coroutine (await it to get async_generator)
if asyncio.iscoroutine(result):
result = await result
# Now iterate over async_generator
results = []
async for item in result:
results.append(item)
return results
# ============== TESTS ==============
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestPipelineStageChainReal:
"""Tests for real pipeline stage chain."""
@pytest.mark.asyncio
async def test_import_pipeline_modules(self):
"""Verify we can import real pipeline modules."""
from langbot.pkg.pipeline import stage, entities
from langbot.pkg.pipeline import pipelinemgr
assert hasattr(stage, 'PipelineStage')
assert hasattr(stage, 'preregistered_stages')
assert hasattr(entities, 'ResultType')
assert hasattr(entities, 'StageProcessResult')
assert hasattr(pipelinemgr, 'RuntimePipeline')
assert hasattr(pipelinemgr, 'StageInstContainer')
@pytest.mark.asyncio
async def test_stage_preregistration(self):
"""Verify stages are preregistered after fixture imports them."""
from langbot.pkg.pipeline import stage
# Check that our target stages are registered
assert 'PreProcessor' in stage.preregistered_stages
assert 'MessageProcessor' in stage.preregistered_stages
assert 'SendResponseBackStage' in stage.preregistered_stages
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestPreProcessorStage:
"""Tests for PreProcessor stage alone."""
@pytest.mark.asyncio
async def test_preproc_continues_on_valid_query(self, pipeline_app, fake_platform_adapter):
"""PreProcessor should return CONTINUE for valid text query."""
from langbot.pkg.pipeline import entities
from langbot.pkg.pipeline.preproc import preproc
adapter, platform = fake_platform_adapter
# Create query with adapter
query = text_query("hello")
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
# Mock plugin_connector for PromptPreProcessing event
mock_event_ctx = Mock()
mock_event_ctx.event = Mock()
mock_event_ctx.event.default_prompt = [] # Real list
mock_event_ctx.event.prompt = [] # Real list
pipeline_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
# Create PreProcessor stage
preproc_stage = preproc.PreProcessor(pipeline_app)
result = await preproc_stage.process(query, 'PreProcessor')
assert result.result_type == entities.ResultType.CONTINUE
assert result.new_query.session is not None
assert result.new_query.user_message is not None
@pytest.mark.asyncio
async def test_preproc_sets_user_message(self, pipeline_app, fake_platform_adapter):
"""PreProcessor should set user_message from message_chain."""
from langbot.pkg.pipeline import entities
from langbot.pkg.pipeline.preproc import preproc
adapter, platform = fake_platform_adapter
query = text_query("test message content")
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
# Mock plugin_connector for PromptPreProcessing event
mock_event_ctx = Mock()
mock_event_ctx.event = Mock()
mock_event_ctx.event.default_prompt = []
mock_event_ctx.event.prompt = []
pipeline_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
preproc_stage = preproc.PreProcessor(pipeline_app)
result = await preproc_stage.process(query, 'PreProcessor')
assert result.result_type == entities.ResultType.CONTINUE
# Check user_message content
assert result.new_query.user_message is not None
assert result.new_query.user_message.role == 'user'
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestProcessorStage:
"""Tests for MessageProcessor stage."""
@pytest.mark.asyncio
async def test_processor_calls_chat_handler(self, pipeline_app, fake_platform_adapter, set_fake_runner):
"""Processor should route to ChatMessageHandler for non-command messages."""
adapter, platform = fake_platform_adapter
# Set fake runner that returns pong
fake_runner = FakeRunner().returns("LANGBOT_FAKE_PONG")
set_fake_runner(fake_runner)
# Create query
query = text_query("hello")
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
query.resp_messages = []
# Mock plugin_connector to not prevent default
mock_event_ctx = Mock()
mock_event_ctx.is_prevented_default = Mock(return_value=False)
mock_event_ctx.event = Mock()
mock_event_ctx.event.user_message_alter = None
pipeline_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
# Collect results using helper
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
assert len(results) >= 1
# Check that resp_messages was populated
assert len(query.resp_messages) >= 1
@pytest.mark.asyncio
async def test_processor_prevent_default_without_reply_interrupts(self, pipeline_app, fake_platform_adapter):
"""Processor should INTERRUPT when plugin prevents default without reply."""
from langbot.pkg.pipeline import entities
adapter, platform = fake_platform_adapter
# Create query
query = text_query("hello")
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
# Mock plugin_connector to prevent default without reply
mock_event_ctx = Mock()
mock_event_ctx.is_prevented_default = Mock(return_value=True)
mock_event_ctx.event = Mock()
mock_event_ctx.event.reply_message_chain = None
pipeline_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
assert len(results) == 1
assert results[0].result_type == entities.ResultType.INTERRUPT
@pytest.mark.asyncio
async def test_processor_prevent_default_with_reply_continues(self, pipeline_app, fake_platform_adapter):
"""Processor should CONTINUE when plugin prevents default with reply."""
from langbot.pkg.pipeline import entities
from tests.factories.message import text_chain
adapter, platform = fake_platform_adapter
# Create query
query = text_query("hello")
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
query.resp_messages = []
# Create reply chain
reply_chain = text_chain("plugin response")
# Mock plugin_connector to prevent default with reply
mock_event_ctx = Mock()
mock_event_ctx.is_prevented_default = Mock(return_value=True)
mock_event_ctx.event = Mock()
mock_event_ctx.event.reply_message_chain = reply_chain
pipeline_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
assert len(results) == 1
assert results[0].result_type == entities.ResultType.CONTINUE
assert len(query.resp_messages) == 1
assert query.resp_messages[0] == reply_chain
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestRunnerExceptionFlow:
"""Tests for runner exception handling."""
@pytest.mark.asyncio
async def test_runner_exception_yields_interrupt(self, pipeline_app, fake_platform_adapter, set_fake_runner):
"""Runner exception should yield INTERRUPT with error notices."""
from langbot.pkg.pipeline import entities
adapter, platform = fake_platform_adapter
# Set fake runner that raises exception
fake_runner = FakeRunner().raises(ValueError("API Error: rate limit exceeded"))
set_fake_runner(fake_runner)
# Create query with exception handling config
config = create_minimal_pipeline_config()
config['output']['misc']['exception-handling'] = 'show-hint'
config['output']['misc']['failure-hint'] = 'Request failed.'
query = text_query("hello")
query.adapter = adapter
query.pipeline_config = config
# Mock plugin_connector to not prevent default
mock_event_ctx = Mock()
mock_event_ctx.is_prevented_default = Mock(return_value=False)
mock_event_ctx.event = Mock()
mock_event_ctx.event.user_message_alter = None
pipeline_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
assert len(results) == 1
assert results[0].result_type == entities.ResultType.INTERRUPT
assert results[0].user_notice == 'Request failed.'
assert results[0].error_notice is not None
@pytest.mark.asyncio
async def test_runner_exception_show_error_mode(self, pipeline_app, fake_platform_adapter, set_fake_runner):
"""show-error mode should show actual exception message."""
from langbot.pkg.pipeline import entities
adapter, platform = fake_platform_adapter
# Set fake runner that raises specific exception
fake_runner = FakeRunner().raises(RuntimeError("Custom runtime error"))
set_fake_runner(fake_runner)
# Create query with show-error mode
config = create_minimal_pipeline_config()
config['output']['misc']['exception-handling'] = 'show-error'
query = text_query("hello")
query.adapter = adapter
query.pipeline_config = config
# Mock plugin_connector to not prevent default
mock_event_ctx = Mock()
mock_event_ctx.is_prevented_default = Mock(return_value=False)
mock_event_ctx.event = Mock()
mock_event_ctx.event.user_message_alter = None
pipeline_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
assert len(results) == 1
assert results[0].result_type == entities.ResultType.INTERRUPT
assert 'Custom runtime error' in results[0].user_notice
@pytest.mark.asyncio
async def test_runner_exception_hide_mode(self, pipeline_app, fake_platform_adapter, set_fake_runner):
"""hide mode should not show user notice."""
from langbot.pkg.pipeline import entities
adapter, platform = fake_platform_adapter
# Set fake runner that raises exception
fake_runner = FakeRunner().raises(Exception("Hidden error"))
set_fake_runner(fake_runner)
# Create query with hide mode
config = create_minimal_pipeline_config()
config['output']['misc']['exception-handling'] = 'hide'
query = text_query("hello")
query.adapter = adapter
query.pipeline_config = config
# Mock plugin_connector to not prevent default
mock_event_ctx = Mock()
mock_event_ctx.is_prevented_default = Mock(return_value=False)
mock_event_ctx.event = Mock()
mock_event_ctx.event.user_message_alter = None
pipeline_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
assert len(results) == 1
assert results[0].result_type == entities.ResultType.INTERRUPT
assert results[0].user_notice is None
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestSendResponseBackStage:
"""Tests for SendResponseBackStage."""
@pytest.mark.asyncio
async def test_send_response_calls_adapter(self, pipeline_app, fake_platform_adapter):
"""SendResponseBackStage should call adapter.reply_message."""
from langbot.pkg.pipeline import entities
from langbot.pkg.pipeline.respback import respback
from tests.factories.message import text_chain
from langbot_plugin.api.entities.builtin.provider.message import Message
adapter, platform = fake_platform_adapter
# Create query with response message
query = text_query("hello")
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
# Add response message
query.resp_messages = [Message(role='assistant', content='test response')]
query.resp_message_chain = [text_chain('test response')]
# Create SendResponseBackStage
respback_stage = respback.SendResponseBackStage(pipeline_app)
result = await respback_stage.process(query, 'SendResponseBackStage')
assert result.result_type == entities.ResultType.CONTINUE
# Check that adapter was called
outbound = platform.get_outbound_messages()
assert len(outbound) == 1
assert outbound[0]['type'] == 'reply'
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestStageChainIntegration:
"""Tests for full stage chain (PreProcessor -> Processor -> SendResponseBackStage)."""
@pytest.mark.asyncio
async def test_full_chain_text_message_flow(self, pipeline_app, fake_platform_adapter, set_fake_runner):
"""
Full chain: text message -> PreProcessor -> Processor -> SendResponseBackStage.
Validates:
- PreProcessor sets up session, user_message
- Processor calls runner and populates resp_messages
- SendResponseBackStage calls adapter.reply_message
"""
from langbot.pkg.pipeline import entities
from langbot.pkg.pipeline.preproc import preproc
from langbot.pkg.pipeline.process import process
from langbot.pkg.pipeline.respback import respback
adapter, platform = fake_platform_adapter
# Set fake runner
fake_runner = FakeRunner().returns("LANGBOT_FAKE_PONG")
set_fake_runner(fake_runner)
# Create query
config = create_minimal_pipeline_config()
query = text_query("ping")
query.adapter = adapter
query.pipeline_config = config
query.resp_messages = []
query.resp_message_chain = []
# Mock plugin_connector for PreProcessor and Processor events
mock_event_ctx_preproc = Mock()
mock_event_ctx_preproc.event = Mock()
mock_event_ctx_preproc.event.default_prompt = []
mock_event_ctx_preproc.event.prompt = []
mock_event_ctx_processor = Mock()
mock_event_ctx_processor.is_prevented_default = Mock(return_value=False)
mock_event_ctx_processor.event = Mock()
mock_event_ctx_processor.event.user_message_alter = None
pipeline_app.plugin_connector.emit_event = AsyncMock()
pipeline_app.plugin_connector.emit_event.side_effect = [
mock_event_ctx_preproc, # PreProcessor PromptPreProcessing
mock_event_ctx_processor, # Processor NormalMessageReceived
]
# Create stages
preproc_stage = preproc.PreProcessor(pipeline_app)
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(config)
respback_stage = respback.SendResponseBackStage(pipeline_app)
# Run PreProcessor
result1 = await preproc_stage.process(query, 'PreProcessor')
assert result1.result_type == entities.ResultType.CONTINUE
query = result1.new_query
# Run Processor
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
assert len(results) >= 1
# Build resp_message_chain from resp_messages
from tests.factories.message import text_chain
for resp_msg in query.resp_messages:
if resp_msg.content:
query.resp_message_chain.append(text_chain(resp_msg.content))
# Run SendResponseBackStage
result3 = await respback_stage.process(query, 'SendResponseBackStage')
assert result3.result_type == entities.ResultType.CONTINUE
# Verify adapter was called
outbound = platform.get_outbound_messages()
assert len(outbound) >= 1
@pytest.mark.asyncio
async def test_chain_stops_on_interrupt(self, pipeline_app, fake_platform_adapter):
"""
Chain should stop when a stage returns INTERRUPT.
PreProcessor returns CONTINUE, Processor returns INTERRUPT (prevent_default).
"""
from langbot.pkg.pipeline import entities
from langbot.pkg.pipeline.preproc import preproc
from langbot.pkg.pipeline.process import process
adapter, platform = fake_platform_adapter
# Create query
query = text_query("hello")
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
# Mock plugin_connector - PreProcessor continues, Processor interrupts
mock_event_ctx_preproc = Mock()
mock_event_ctx_preproc.event = Mock()
mock_event_ctx_preproc.event.default_prompt = []
mock_event_ctx_preproc.event.prompt = []
mock_event_ctx_processor = Mock()
mock_event_ctx_processor.is_prevented_default = Mock(return_value=True)
mock_event_ctx_processor.event = Mock()
mock_event_ctx_processor.event.reply_message_chain = None
pipeline_app.plugin_connector.emit_event = AsyncMock()
pipeline_app.plugin_connector.emit_event.side_effect = [
mock_event_ctx_preproc, # PreProcessor PromptPreProcessing
mock_event_ctx_processor, # Processor NormalMessageReceived
]
# Create stages
preproc_stage = preproc.PreProcessor(pipeline_app)
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
# Run PreProcessor
result1 = await preproc_stage.process(query, 'PreProcessor')
assert result1.result_type == entities.ResultType.CONTINUE
query = result1.new_query
# Run Processor - should INTERRUPT
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
assert len(results) == 1
assert results[0].result_type == entities.ResultType.INTERRUPT
# Chain stops here - no resp_messages
assert len(query.resp_messages) == 0

View File

@@ -1,6 +0,0 @@
"""
Smoke tests package.
Smoke tests verify basic functionality works without testing edge cases.
Run with: uv run pytest tests/smoke/ -q
"""

View File

@@ -1,351 +0,0 @@
"""
Minimal fake flow smoke tests for LangBot.
These tests verify basic component interactions using fake providers and platforms.
Not a full pipeline integration test - tests individual factory components.
For full pipeline tests, see tests/integration/ (planned).
"""
from __future__ import annotations
import pytest
from tests.factories import (
FakeApp,
FakeProvider,
FakePlatform,
text_query,
fake_provider_pong,
fake_model,
mock_platform_adapter,
)
class TestFakeMessageFlow:
"""Smoke tests for fake message flow through pipeline."""
@pytest.mark.asyncio
async def test_fake_app_creation(self):
"""Test FakeApp can be created with all dependencies."""
app = FakeApp()
assert app.logger is not None
assert app.sess_mgr is not None
assert app.model_mgr is not None
assert app.tool_mgr is not None
assert app.persistence_mgr is not None
assert app.query_pool is not None
assert app.instance_config is not None
# Verify default config
assert app.instance_config.data["command"]["prefix"] == ["/", "!"]
assert app.instance_config.data["command"]["enable"] is True
@pytest.mark.asyncio
async def test_fake_provider_returns_text(self):
"""Test FakeProvider returns configured response."""
provider = FakeProvider(default_response="test response")
# Create mock model with provider
model = fake_model(provider=provider)
# Create a simple query
query = text_query("hello")
# Simulate invoke
result = await provider.invoke_llm(
query=query,
model=model,
messages=[],
funcs=[],
extra_args={},
)
assert result is not None
assert result.role == "assistant"
assert result.content == "test response"
@pytest.mark.asyncio
async def test_fake_provider_pong(self):
"""Test FakeProvider returns LANGBOT_FAKE_PONG marker."""
provider = fake_provider_pong()
model = fake_model(provider=provider)
query = text_query("ping")
result = await provider.invoke_llm(
query=query,
model=model,
messages=[],
funcs=[],
extra_args={},
)
assert result.content == FakeProvider.PONG_RESPONSE
@pytest.mark.asyncio
async def test_fake_provider_streaming(self):
"""Test FakeProvider streaming response."""
provider = FakeProvider().returns_streaming(["Hello", " World"])
model = fake_model(provider=provider)
query = text_query("hello")
chunks = []
# invoke_llm_stream returns an async generator, don't await it
async for chunk in provider.invoke_llm_stream(
query=query,
model=model,
messages=[],
funcs=[],
extra_args={},
):
chunks.append(chunk)
assert len(chunks) == 2
assert chunks[0].content == "Hello"
assert chunks[1].content == " World"
assert chunks[1].is_final is True
@pytest.mark.asyncio
async def test_fake_provider_timeout(self):
"""Test FakeProvider simulates timeout error."""
provider = FakeProvider().timeout()
model = fake_model(provider=provider)
query = text_query("hello")
with pytest.raises(TimeoutError, match="Provider timeout"):
await provider.invoke_llm(
query=query,
model=model,
messages=[],
funcs=[],
extra_args={},
)
@pytest.mark.asyncio
async def test_fake_provider_rate_limit(self):
"""Test FakeProvider simulates rate limit error."""
provider = FakeProvider().rate_limit()
model = fake_model(provider=provider)
query = text_query("hello")
with pytest.raises(Exception, match="Rate limit exceeded"):
await provider.invoke_llm(
query=query,
model=model,
messages=[],
funcs=[],
extra_args={},
)
@pytest.mark.asyncio
async def test_fake_provider_captures_requests(self):
"""Test FakeProvider captures request arguments."""
provider = FakeProvider()
model = fake_model(name="gpt-4", provider=provider)
query = text_query("hello")
await provider.invoke_llm(
query=query,
model=model,
messages=[{"role": "user", "content": "hello"}],
funcs=[{"name": "test_func"}],
extra_args={"temperature": 0.7},
)
captured = provider.get_captured_requests()
assert len(captured) == 1
assert captured[0]["model"] == "gpt-4"
assert captured[0]["messages"] == [{"role": "user", "content": "hello"}]
assert captured[0]["funcs"] == [{"name": "test_func"}]
assert captured[0]["extra_args"] == {"temperature": 0.7}
@pytest.mark.asyncio
async def test_fake_platform_capture_outbound(self):
"""Test FakePlatform captures outbound messages."""
platform = FakePlatform(bot_account_id="test-bot")
query = text_query("hello")
# Simulate sending reply
from tests.factories.message import text_chain
reply_chain = text_chain("response text")
event = query.message_event
await platform.reply_message(event, reply_chain, quote_origin=False)
# Verify captured
outbound = platform.get_outbound_messages()
assert len(outbound) == 1
assert outbound[0]["type"] == "reply"
assert outbound[0]["message"] == reply_chain
@pytest.mark.asyncio
async def test_fake_platform_friend_message(self):
"""Test FakePlatform creates friend message events."""
platform = FakePlatform(bot_account_id="test-bot")
event = platform.create_friend_message(
text="hello bot",
sender_id=12345,
nickname="TestUser",
)
assert event.type == "FriendMessage"
assert event.sender.id == 12345
assert event.sender.nickname == "TestUser"
assert str(event.message_chain) == "hello bot"
@pytest.mark.asyncio
async def test_fake_platform_group_message_with_mention(self):
"""Test FakePlatform creates group message with @mention."""
platform = FakePlatform(bot_account_id="test-bot")
event = platform.create_group_message(
text="hello everyone",
sender_id=12345,
group_id=99999,
mention_bot=True,
)
assert event.type == "GroupMessage"
assert event.sender.id == 12345
assert event.group.id == 99999
# Check message chain has @mention
chain = event.message_chain
assert len(chain) >= 2 # At + Plain
@pytest.mark.asyncio
async def test_query_factories_basic(self):
"""Test basic query factory functions."""
# Text query
q1 = text_query("hello world")
assert q1.launcher_type.value == "person"
assert str(q1.message_chain) == "hello world"
# Group query
from tests.factories import group_text_query
q2 = group_text_query("hello group", group_id=88888)
assert q2.launcher_type.value == "group"
assert q2.launcher_id == 88888
# Command query
from tests.factories import command_query
q3 = command_query("help", prefix="/")
assert str(q3.message_chain) == "/help"
# Mention query
from tests.factories import mention_query
q4 = mention_query("hi", target="test-bot", group_id=77777)
assert q4.launcher_type.value == "group"
@pytest.mark.asyncio
async def test_fake_platform_send_failure(self):
"""Test FakePlatform simulates send failure."""
platform = FakePlatform().send_failure()
query = text_query("hello")
from tests.factories.message import text_chain
with pytest.raises(Exception, match="Platform send failure"):
await platform.reply_message(
query.message_event,
text_chain("response"),
)
@pytest.mark.asyncio
async def test_mock_platform_adapter(self):
"""Test mock_platform_adapter helper."""
platform = FakePlatform(bot_account_id="bot-123")
adapter = mock_platform_adapter(platform)
assert adapter.bot_account_id == "bot-123"
assert adapter._fake_platform is platform
# Test reply_message is wired
from tests.factories.message import text_chain
query = text_query("test")
await adapter.reply_message(query.message_event, text_chain("response"))
# Verify platform captured it
assert len(platform.get_outbound_messages()) == 1
class TestMessageFlowIntegration:
"""Minimal fake flow integration tests.
These tests verify component interactions but do NOT run full LangBot pipeline.
For real pipeline tests, integration tests are needed (planned).
"""
@pytest.mark.asyncio
async def test_minimal_message_flow(self):
"""Minimal fake flow test: fake query -> fake provider -> fake platform.
This test verifies:
1. Fake text query is created
2. Fake provider returns LANGBOT_FAKE_PONG
3. Fake platform captures outbound response
4. No unexpected exception
Note: This does NOT run actual LangBot pipeline stages.
"""
# Setup
platform = FakePlatform(bot_account_id="test-bot")
provider = fake_provider_pong()
model = fake_model(provider=provider)
# Create inbound message
query = text_query("ping")
# Simulate provider processing
response = await provider.invoke_llm(
query=query,
model=model,
messages=[{"role": "user", "content": "ping"}],
funcs=[],
extra_args={},
)
# Verify provider returned pong
assert response.content == FakeProvider.PONG_RESPONSE
# Simulate platform sending response
from tests.factories.message import text_chain
reply_chain = text_chain(response.content)
await platform.reply_message(query.message_event, reply_chain)
# Verify platform captured outbound
outbound = platform.get_outbound_messages()
assert len(outbound) == 1
assert outbound[0]["type"] == "reply"
assert str(outbound[0]["message"]) == FakeProvider.PONG_RESPONSE
@pytest.mark.asyncio
async def test_streaming_message_flow(self):
"""Smoke test: streaming message flow."""
platform = FakePlatform().supports_streaming()
provider = FakeProvider().returns_streaming(["Hello", " there"])
model = fake_model(provider=provider)
query = text_query("hi")
chunks = []
async for chunk in provider.invoke_llm_stream(
query=query,
model=model,
messages=[],
funcs=[],
extra_args={},
):
chunks.append(chunk)
# Verify streaming worked
assert len(chunks) == 2
full_content = "".join(c.content for c in chunks)
assert full_content == "Hello there"
# Verify platform supports streaming
assert await platform.is_stream_output_supported() is True

View File

@@ -1,66 +0,0 @@
"""
PoC test for CWE-94: Authenticated RCE via exec() on user-supplied Python code.
The /api/v1/system/debug/exec endpoint passes raw HTTP body to exec(),
allowing arbitrary code execution when debug_mode is True.
This test verifies that:
1. The exec() endpoint is removed from the codebase entirely.
2. No route matches /api/v1/system/debug/exec.
"""
import ast
import pathlib
# Resolve project root (one level up from tests/)
_PROJECT_ROOT = pathlib.Path(__file__).resolve().parent.parent
VULN_FILE = (
_PROJECT_ROOT
/ "src"
/ "langbot"
/ "pkg"
/ "api"
/ "http"
/ "controller"
/ "groups"
/ "system.py"
)
def test_no_exec_call_in_system_controller():
"""Verify there is no exec() call in system.py that takes user input."""
with open(VULN_FILE, "r") as f:
source = f.read()
tree = ast.parse(source)
exec_calls = []
for node in ast.walk(tree):
if isinstance(node, ast.Call):
func = node.func
# Match bare exec() call
if isinstance(func, ast.Name) and func.id == "exec":
exec_calls.append(node.lineno)
assert len(exec_calls) == 0, (
f"Found exec() call(s) at line(s) {exec_calls} in system.py. "
"User-supplied code must never be passed to exec()."
)
def test_no_debug_exec_route():
"""Verify the /debug/exec route is not registered."""
with open(VULN_FILE, "r") as f:
source = f.read()
assert "debug/exec" not in source, (
"The /debug/exec route still exists in system.py. "
"This endpoint allows arbitrary code execution and must be removed."
)
if __name__ == "__main__":
test_no_exec_call_in_system_controller()
test_no_debug_exec_route()
print("All tests passed!")

View File

@@ -1,179 +0,0 @@
# 单元测试覆盖率排除说明
## 排除范围
以下外部适配器模块不纳入测试覆盖目标,因为它们需要实际外部环境才能测试:
### 1. 消息平台适配器 (`platform/sources/`)
- **路径**: `src/langbot/pkg/platform/sources/`
- **模块**: aiocqhttp, dingtalk, discord, feishu, gestep, kook, lark, slack, telegram, wecom, wechatpv, wechatmp, qqbot
- **排除原因**: 需要真实消息平台账号和 webhook 连接,无法纯单元测试
- **测试方式**: 需要 mock 平台 API 或集成测试环境
- **状态**: 后续可补充 mock 测试
### 2. LLM Requester (`provider/modelmgr/requesters/`)
- **路径**: `src/langbot/pkg/provider/modelmgr/requesters/`
- **模块**: deepseek, openai, anthropic, gemini, moonshot, ollama, zhipuai 等 20+ 个 requester
- **排除原因**: 需要真实 LLM API 密钥和网络请求,涉及付费 API 调用
- **测试方式**: 需要 mock HTTP 响应或使用 fake LLM server
- **状态**: 后续可补充 mock HTTP 测试
### 3. Agent Runner (`provider/runners/`)
- **路径**: `src/langbot/pkg/provider/runners/`
- **模块**: cozeapi, difysvapi, n8nsvapi, langflowapi, dashscopeapi, localagent, tboxapi
- **排除原因**: 需要真实 Agent 平台Coze、Dify、n8n 等)的 API 连接
- **测试方式**: 需要 mock Agent 平台响应
- **状态**: 后续可补充 mock 测试
### 4. 向量数据库 (`vector/vdbs/`)
- **路径**: `src/langbot/pkg/vector/vdbs/`
- **模块**: chroma, milvus, pgvector, qdrant, seekdb
- **排除原因**: 需要真实向量数据库实例运行
- **测试方式**: 需要 Docker 启动测试数据库或 mock
- **状态**: 后续可补充 mock 测试
---
## 覆盖率计算(排除外部适配器)
### 统计方法
```bash
# 排除外部适配器后计算覆盖率
pytest tests/unit_tests/ --cov=langbot.pkg \
--cov-fail-under=0 \
-o "cov_exclude_patterns=platform/sources/*,provider/modelmgr/requesters/*,provider/runners/*,vector/vdbs/*"
```
### 当前覆盖率(排除后)
| 模块 | 覆盖率 | 状态 |
|------|--------|------|
| `command` | **99%** | ✅ 完成 |
| `entity` | **99%** | ✅ 完成 |
| `vector` | **76%** | ✅ 完成 |
| `survey` | **84%** | ✅ 完成 |
| `pipeline` | **72%** | ✅ 核心流程 |
| `rag` | **66%** | ✅ 完成 |
| `telemetry` | **87%** | ✅ 完成 |
| `storage` | **80%** | ✅ 完成 |
| `provider` | **83%** | ✅ 完成 |
| `discover` | **61%** | ✅ 完成 |
| `config` | **70%** | ✅ 完成 |
| `utils` | **48%** | 🔄 部分完成 |
| `api` | **34%** | 🔄 需补充 controller |
| `platform` | **35%** | 🔄 需补充 adapter base |
| `plugin` | **27%** | 🔄 需补充 handler |
| `core` | **28%** | 🔄 需补充 app 启动 |
| `persistence` | **24%** | 🔄 需补充 mgr |
---
## 后续计划
### 可补充的 Mock 测试(优先级排序)
1. **`provider/modelmgr/requesters/`** (优先级:中)
- 使用 `httpx` mock 测试 API 响应解析
- 测试重试逻辑、错误处理
2. **`provider/runners/`** (优先级:中)
- Mock Agent 平台响应
- 测试 session 管理、错误处理
3. **`platform/sources/`** (优先级:低)
- Mock 平台 webhook 事件
- 测试消息解析、事件处理
4. **`vector/vdbs/`** (优先级:低)
- Mock 向量数据库操作
- 测试 CRUD、查询逻辑
---
## 测试文件结构
```
tests/unit_tests/
├── api/
│ └── service/
│ ├── test_knowledge_service.py # 22 tests ✅
│ └── ...
├── core/
│ ├── test_taskmgr.py # 21 tests ✅
│ ├── test_load_config.py # 21 tests ✅ (含env override)
│ └── ...
├── plugin/
│ ├── test_connector_static.py # 8 tests ✅
│ ├── test_connector_pure.py # 7 tests ✅
│ ├── test_connector_methods.py # 24 tests ✅
│ ├── test_extract_deps.py # 7 tests ✅
│ ├── test_handler_actions.py # 15 tests ✅ (新增)
│ └── ...
├── provider/
│ ├── test_session_manager.py # 11 tests ✅ (新增)
│ ├── test_tool_manager.py # 14 tests ✅ (新增)
│ └── ...
├── rag/
│ ├── test_i18n_conversion.py # 8 tests ✅
│ ├── test_kbmgr.py # 39 tests ✅
│ ├── test_file_storage.py # 21 tests ✅ (新增)
│ └── ...
├── storage/
│ ├── test_s3storage.py # 16 tests ✅ (新增)
│ ├── test_localstorage_path_traversal.py # 11 tests ✅
│ └── ...
├── survey/
│ └── test_survey_manager.py # 22 tests ✅
├── telemetry/
│ └── test_telemetry.py # 25 tests ✅ (重写)
├── vector/
│ ├── test_filter_utils.py # 21 tests ✅
│ ├── test_vdb_filter_conversion.py # 30 tests ✅ (新增)
│ └── ...
├── utils/
│ ├── test_platform.py # 7 tests ✅
│ ├── test_funcschema.py # 9 tests ✅
│ └── ...
├── pipeline/
│ ├── test_ratelimit.py # 12 tests ✅ (新增真实算法)
│ ├── test_msgtrun.py # 9 tests ✅ (强化断言)
│ └── ...
└── persistence/
├── test_serialize_model.py # 6 tests ✅
├── test_database_decorator.py # 7 tests ✅
└── ...
```
---
## 总结
- **总测试数**: 1193 passed
- **总体覆盖率**: 30%
- **核心模块覆盖率**: **51.2%** (6549/12825 语句) - 排除外部适配器
- **外部适配器覆盖率**: 5.6% (535/9483 语句) - 不纳入目标
### 核心模块覆盖率详情
| 模块 | 覆盖率 | 语句数 | 说明 |
|------|--------|--------|------|
| `command` | **99%** | 93 | ✅ 完成 |
| `entity` | **99%** | 335 | ✅ 完成 |
| `vector` | **76%** | 139 | ✅ 完成 (新增filter转换测试) |
| `survey` | **84%** | 95 | ✅ 完成 |
| `pipeline` | **72%** | 1761 | ✅ 核心流程 (新增算法测试) |
| `rag` | **66%** | 347 | ✅ 完成 (新增ZIP处理测试) |
| `telemetry` | **87%** | 70 | ✅ 完成 (重写假测试) |
| `storage` | **80%** | 170 | ✅ 完成 (新增S3测试) |
| `provider` | **83%** | 854 | ✅ 完成 (新增Session/Tool测试) |
| `discover` | **61%** | 188 | ✅ 完成 |
| `config` | **70%** | 198 | ✅ 完成 |
| `utils` | **48%** | 478 | 🔄 部分完成 |
| `api` | **34%** | 4061 | 🔄 需补充 controller |
| `platform` | **35%** | 433 | 🔄 需补充 adapter base |
| `plugin` | **27%** | 815 | 🔄 需补充 handler (新增action测试) |
| `core` | **28%** | 1289 | 🔄 需补充 app 启动 |
| `persistence` | **24%** | 1099 | 🔄 需补充 mgr |
外部适配器测试需要 mock 环境或集成测试,不属于纯单元测试范畴。

View File

@@ -1 +0,0 @@
"""Unit tests for LangBot API HTTP service layer."""

View File

@@ -1,62 +0,0 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
from sqlalchemy.sql.dml import Update
from langbot.pkg.api.http.service.bot import BotService
class _FakeResult:
def __init__(self, value):
self.value = value
def first(self):
return self.value
class _PersistenceManager:
def __init__(self):
self.update_values = None
async def execute_async(self, statement):
if isinstance(statement, Update):
self.update_values = {
key: value for key, value in statement.compile().params.items() if not key.startswith('uuid_')
}
return None
return _FakeResult(SimpleNamespace(name='Updated Pipeline'))
async def test_update_bot_copies_input_before_filtering_and_setting_pipeline_name():
persistence_mgr = _PersistenceManager()
runtime_bot = SimpleNamespace(enable=False)
platform_mgr = SimpleNamespace(
remove_bot=AsyncMock(),
load_bot=AsyncMock(return_value=runtime_bot),
)
ap = SimpleNamespace(
persistence_mgr=persistence_mgr,
platform_mgr=platform_mgr,
sess_mgr=SimpleNamespace(session_list=[]),
)
service = BotService(ap)
service.get_bot = AsyncMock(return_value={'uuid': 'bot-1'})
payload = {
'uuid': 'caller-owned-uuid',
'name': 'Test Bot',
'use_pipeline_uuid': 'pipeline-1',
}
await service.update_bot('bot-1', payload)
assert payload == {
'uuid': 'caller-owned-uuid',
'name': 'Test Bot',
'use_pipeline_uuid': 'pipeline-1',
}
assert persistence_mgr.update_values == {
'name': 'Test Bot',
'use_pipeline_uuid': 'pipeline-1',
'use_pipeline_name': 'Updated Pipeline',
}

View File

@@ -1,16 +0,0 @@
"""Unit tests for API HTTP service layer.
Tests real service business logic with mocked dependencies:
- persistence_mgr (database operations)
- model_mgr (runtime model management)
- platform_mgr (platform management)
- plugin_connector (plugin runtime)
- adjacent services (cross-service calls)
Does NOT:
- Start real Quart server
- Access real database
- Call real provider/platform/network
Uses tests.factories.FakeApp as base mock application.
"""

View File

@@ -1,429 +0,0 @@
"""
Unit tests for ApiKeyService.
Tests API key CRUD operations with mocked persistence layer.
Source: src/langbot/pkg/api/http/service/apikey.py
"""
from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, Mock, patch
from types import SimpleNamespace
from langbot.pkg.api.http.service.apikey import ApiKeyService
from langbot.pkg.entity.persistence.apikey import ApiKey
pytestmark = pytest.mark.asyncio
class TestApiKeyServiceGetApiKeys:
"""Tests for get_api_keys method."""
async def test_get_api_keys_empty_list(self):
"""Returns empty list when no API keys exist."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = Mock()
mock_result.all = Mock(return_value=[])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity: {
'id': entity.id,
'name': entity.name,
'key': entity.key,
'description': entity.description,
}
if entity
else {}
)
service = ApiKeyService(ap)
# Execute
result = await service.get_api_keys()
# Verify
assert result == []
ap.persistence_mgr.execute_async.assert_called_once()
async def test_get_api_keys_returns_serialized_list(self):
"""Returns serialized list of API keys."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
# Create mock API key entities
key1 = Mock(spec=ApiKey)
key1.id = 1
key1.name = 'Test Key 1'
key1.key = 'lbk_test_key_1'
key1.description = 'First test key'
key2 = Mock(spec=ApiKey)
key2.id = 2
key2.name = 'Test Key 2'
key2.key = 'lbk_test_key_2'
key2.description = 'Second test key'
mock_result = Mock()
mock_result.all = Mock(return_value=[key1, key2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity: {
'id': entity.id,
'name': entity.name,
'key': entity.key,
'description': entity.description,
}
)
service = ApiKeyService(ap)
# Execute
result = await service.get_api_keys()
# Verify
assert len(result) == 2
assert result[0]['name'] == 'Test Key 1'
assert result[1]['name'] == 'Test Key 2'
class TestApiKeyServiceCreateApiKey:
"""Tests for create_api_key method."""
async def test_create_api_key_generates_key_with_prefix(self):
"""Creates API key with 'lbk_' prefix."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
created_key = Mock(spec=ApiKey)
created_key.id = 1
created_key.name = 'New Key'
created_key.key = 'lbk_fixed-token'
created_key.description = 'Test description'
select_result = Mock()
select_result.first = Mock(return_value=created_key)
insert_params = []
async def mock_execute(query):
params = query.compile().params
if {'name', 'key', 'description'}.issubset(params):
insert_params.append(params)
return Mock()
return select_result
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity: {
'id': 1,
'name': entity.name,
'key': entity.key,
'description': entity.description,
}
)
service = ApiKeyService(ap)
with patch('langbot.pkg.api.http.service.apikey.secrets.token_urlsafe', return_value='fixed-token'):
result = await service.create_api_key('New Key', 'Test description')
assert insert_params == [
{'name': 'New Key', 'key': 'lbk_fixed-token', 'description': 'Test description'}
]
assert result['key'].startswith('lbk_')
assert result['key'] == 'lbk_fixed-token'
assert result['name'] == 'New Key'
assert result['description'] == 'Test description'
async def test_create_api_key_without_description(self):
"""Creates API key with empty description when not provided."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
created_key = Mock(spec=ApiKey)
created_key.id = 1
created_key.name = 'No Desc Key'
created_key.key = 'lbk_no_desc_key'
created_key.description = ''
select_result = Mock()
select_result.first = Mock(return_value=created_key)
insert_result = Mock()
async def mock_execute(query):
if hasattr(query, 'values'):
return insert_result
return select_result
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
return_value={
'id': 1,
'name': 'No Desc Key',
'key': 'lbk_no_desc_key',
'description': '',
}
)
service = ApiKeyService(ap)
# Execute
result = await service.create_api_key('No Desc Key')
# Verify
assert result['description'] == ''
class TestApiKeyServiceGetApiKey:
"""Tests for get_api_key method."""
async def test_get_api_key_by_id_found(self):
"""Returns API key when found by ID."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
key = Mock(spec=ApiKey)
key.id = 1
key.name = 'Found Key'
key.key = 'lbk_found_key'
key.description = 'Found'
mock_result = Mock()
mock_result.first = Mock(return_value=key)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={
'id': 1,
'name': 'Found Key',
'key': 'lbk_found_key',
'description': 'Found',
}
)
service = ApiKeyService(ap)
# Execute
result = await service.get_api_key(1)
# Verify
assert result is not None
assert result['id'] == 1
assert result['name'] == 'Found Key'
async def test_get_api_key_by_id_not_found(self):
"""Returns None when API key not found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = Mock()
mock_result.first = Mock(return_value=None)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = ApiKeyService(ap)
# Execute
result = await service.get_api_key(999)
# Verify
assert result is None
async def test_get_api_key_by_id_zero(self):
"""Handles ID=0 (edge case) correctly."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = Mock()
mock_result.first = Mock(return_value=None)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = ApiKeyService(ap)
# Execute
result = await service.get_api_key(0)
# Verify - should return None (no key with ID 0)
assert result is None
class TestApiKeyServiceVerifyApiKey:
"""Tests for verify_api_key method."""
async def test_verify_api_key_valid(self):
"""Returns True for valid API key."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
key = Mock(spec=ApiKey)
mock_result = Mock()
mock_result.first = Mock(return_value=key)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = ApiKeyService(ap)
# Execute
result = await service.verify_api_key('lbk_valid_key')
# Verify
assert result is True
async def test_verify_api_key_invalid(self):
"""Returns False for invalid API key."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = Mock()
mock_result.first = Mock(return_value=None)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = ApiKeyService(ap)
# Execute
result = await service.verify_api_key('lbk_invalid_key')
# Verify
assert result is False
async def test_verify_api_key_empty_string(self):
"""Returns False for empty key string."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = Mock()
mock_result.first = Mock(return_value=None)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = ApiKeyService(ap)
# Execute
result = await service.verify_api_key('')
# Verify
assert result is False
async def test_verify_api_key_unknown_key(self):
"""Returns False when the key is not present in persistence."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = Mock()
mock_result.first = Mock(return_value=None)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = ApiKeyService(ap)
# Execute
result = await service.verify_api_key('unknown_key')
# Verify
assert result is False
class TestApiKeyServiceDeleteApiKey:
"""Tests for delete_api_key method."""
async def test_delete_api_key_by_id(self):
"""Deletes API key by ID."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.persistence_mgr.execute_async = AsyncMock()
service = ApiKeyService(ap)
# Execute
await service.delete_api_key(1)
# Verify - execute_async was called (delete operation)
ap.persistence_mgr.execute_async.assert_called_once()
async def test_delete_api_key_nonexistent_id(self):
"""Delete operation completes even for nonexistent ID (no error raised)."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.persistence_mgr.execute_async = AsyncMock()
service = ApiKeyService(ap)
# Execute - should not raise error
await service.delete_api_key(999)
# Verify - execute_async was called regardless
ap.persistence_mgr.execute_async.assert_called_once()
class TestApiKeyServiceUpdateApiKey:
"""Tests for update_api_key method."""
async def test_update_api_key_name_only(self):
"""Updates only the name field."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.persistence_mgr.execute_async = AsyncMock()
service = ApiKeyService(ap)
# Execute
await service.update_api_key(1, name='Updated Name')
# Verify - execute_async was called with update
ap.persistence_mgr.execute_async.assert_called_once()
async def test_update_api_key_description_only(self):
"""Updates only the description field."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.persistence_mgr.execute_async = AsyncMock()
service = ApiKeyService(ap)
# Execute
await service.update_api_key(1, description='Updated description')
# Verify
ap.persistence_mgr.execute_async.assert_called_once()
async def test_update_api_key_both_fields(self):
"""Updates both name and description."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.persistence_mgr.execute_async = AsyncMock()
service = ApiKeyService(ap)
# Execute
await service.update_api_key(1, name='New Name', description='New description')
# Verify
ap.persistence_mgr.execute_async.assert_called_once()
async def test_update_api_key_no_fields(self):
"""Does nothing when no fields provided."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.persistence_mgr.execute_async = AsyncMock()
service = ApiKeyService(ap)
# Execute
await service.update_api_key(1)
# Verify - no execute call since no update_data
ap.persistence_mgr.execute_async.assert_not_called()

View File

@@ -1,662 +0,0 @@
"""
Unit tests for BotService.
Tests bot CRUD operations with mocked persistence and runtime managers.
Source: src/langbot/pkg/api/http/service/bot.py
"""
from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, Mock, patch
from types import SimpleNamespace
import uuid
from langbot.pkg.api.http.service.bot import BotService
from langbot.pkg.entity.persistence.bot import Bot
pytestmark = pytest.mark.asyncio
def _create_mock_bot(
bot_uuid: str = None,
name: str = 'Test Bot',
description: str = 'Test Description',
adapter: str = 'telegram',
adapter_config: dict = None,
enable: bool = True,
use_pipeline_uuid: str = None,
use_pipeline_name: str = None,
) -> Mock:
"""Helper to create mock Bot entity."""
bot = Mock(spec=Bot)
bot.uuid = bot_uuid or str(uuid.uuid4())
bot.name = name
bot.description = description
bot.adapter = adapter
bot.adapter_config = adapter_config or {'token': 'test_token'}
bot.enable = enable
bot.use_pipeline_uuid = use_pipeline_uuid
bot.use_pipeline_name = use_pipeline_name
bot.pipeline_routing_rules = []
return bot
def _create_mock_result(items: list = None, first_item=None):
"""Create mock result object for persistence queries."""
result = Mock()
result.all = Mock(return_value=items or [])
result.first = Mock(return_value=first_item)
return result
class TestBotServiceGetBots:
"""Tests for get_bots method."""
async def test_get_bots_empty_list(self):
"""Returns empty list when no bots exist."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity, masked_columns=None: {
'uuid': entity.uuid,
'name': entity.name,
'adapter': entity.adapter,
}
)
service = BotService(ap)
# Execute
result = await service.get_bots()
# Verify
assert result == []
async def test_get_bots_returns_list_with_secrets(self):
"""Returns bot list including adapter_config by default."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
bot1 = _create_mock_bot(bot_uuid='uuid-1', name='Bot 1')
bot2 = _create_mock_bot(bot_uuid='uuid-2', name='Bot 2')
mock_result = _create_mock_result([bot1, bot2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity, masked_columns=None: {
'uuid': entity.uuid,
'name': entity.name,
'adapter': entity.adapter,
'adapter_config': entity.adapter_config if 'adapter_config' not in (masked_columns or []) else None,
}
)
service = BotService(ap)
# Execute
result = await service.get_bots(include_secret=True)
# Verify
assert len(result) == 2
assert result[0]['name'] == 'Bot 1'
assert result[0]['adapter_config'] is not None
async def test_get_bots_masks_secrets(self):
"""Returns bot list without adapter_config when include_secret=False."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
bot1 = _create_mock_bot(bot_uuid='uuid-1', name='Bot 1')
mock_result = _create_mock_result([bot1])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity, masked_columns=None: {
'uuid': entity.uuid,
'name': entity.name,
'adapter': entity.adapter,
'adapter_config': entity.adapter_config if 'adapter_config' not in (masked_columns or []) else None,
}
)
service = BotService(ap)
# Execute
result = await service.get_bots(include_secret=False)
# Verify - adapter_config should be masked
assert result[0]['adapter_config'] is None
class TestBotServiceGetBot:
"""Tests for get_bot method."""
async def test_get_bot_by_uuid_found(self):
"""Returns bot when found by UUID."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
bot = _create_mock_bot(bot_uuid='test-uuid', name='Found Bot')
mock_result = _create_mock_result(first_item=bot)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={
'uuid': 'test-uuid',
'name': 'Found Bot',
'adapter': 'telegram',
}
)
service = BotService(ap)
# Execute
result = await service.get_bot('test-uuid')
# Verify
assert result is not None
assert result['uuid'] == 'test-uuid'
assert result['name'] == 'Found Bot'
async def test_get_bot_by_uuid_not_found(self):
"""Returns None when bot not found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result(first_item=None)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = BotService(ap)
# Execute
result = await service.get_bot('nonexistent-uuid')
# Verify
assert result is None
class TestBotServiceGetRuntimeBotInfo:
"""Tests for get_runtime_bot_info method."""
async def test_get_runtime_bot_info_bot_not_found_raises(self):
"""Raises Exception when bot not found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result(first_item=None)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = BotService(ap)
# Mock get_bot to return None
service.get_bot = AsyncMock(return_value=None)
# Execute & Verify
with pytest.raises(Exception, match='Bot not found'):
await service.get_runtime_bot_info('nonexistent-uuid')
async def test_get_runtime_bot_info_returns_webhook_for_wecom(self):
"""Returns webhook URL for wecom adapter."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'api': {
'webhook_prefix': 'http://127.0.0.1:5300',
'extra_webhook_prefix': 'http://extra.example.com',
}
}
ap.platform_mgr = SimpleNamespace()
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=None)
bot_data = {
'uuid': 'wecom-uuid',
'name': 'WeCom Bot',
'adapter': 'wecom',
'adapter_config': {'token': 'test'},
}
service = BotService(ap)
service.get_bot = AsyncMock(return_value=bot_data)
# Execute
result = await service.get_runtime_bot_info('wecom-uuid')
# Verify
assert result['adapter_runtime_values']['webhook_url'] == '/bots/wecom-uuid'
assert result['adapter_runtime_values']['webhook_full_url'] == 'http://127.0.0.1:5300/bots/wecom-uuid'
async def test_get_runtime_bot_info_no_webhook_for_telegram(self):
"""Returns no webhook URL for non-webhook adapters like telegram."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {'api': {}}
ap.platform_mgr = SimpleNamespace()
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=None)
bot_data = {
'uuid': 'telegram-uuid',
'name': 'Telegram Bot',
'adapter': 'telegram',
'adapter_config': {'token': 'test'},
}
service = BotService(ap)
service.get_bot = AsyncMock(return_value=bot_data)
# Execute
result = await service.get_runtime_bot_info('telegram-uuid')
# Verify - no webhook for telegram
assert result['adapter_runtime_values']['webhook_url'] is None
assert result['adapter_runtime_values']['webhook_full_url'] is None
async def test_get_runtime_bot_info_with_runtime_bot(self):
"""Returns bot_account_id when runtime bot exists."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {'api': {}}
ap.platform_mgr = SimpleNamespace()
# Mock runtime bot with adapter
runtime_bot = SimpleNamespace()
runtime_bot.adapter = SimpleNamespace()
runtime_bot.adapter.bot_account_id = 'runtime-account-123'
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=runtime_bot)
bot_data = {
'uuid': 'runtime-uuid',
'name': 'Runtime Bot',
'adapter': 'telegram',
'adapter_config': {},
}
service = BotService(ap)
service.get_bot = AsyncMock(return_value=bot_data)
# Execute
result = await service.get_runtime_bot_info('runtime-uuid')
# Verify
assert result['adapter_runtime_values']['bot_account_id'] == 'runtime-account-123'
class TestBotServiceCreateBot:
"""Tests for create_bot method."""
async def test_create_bot_max_limit_reached_raises(self):
"""Raises ValueError when max_bots limit reached."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'system': {
'limitation': {
'max_bots': 2
}
}
}
ap.platform_mgr = SimpleNamespace()
ap.platform_mgr.load_bot = AsyncMock()
# Mock get_bots to return 2 bots already
bot1 = _create_mock_bot(bot_uuid='uuid-1')
bot2 = _create_mock_bot(bot_uuid='uuid-2')
mock_result = _create_mock_result([bot1, bot2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'uuid-1', 'name': 'Bot 1'}
)
service = BotService(ap)
# Execute & Verify
with pytest.raises(ValueError, match='Maximum number of bots'):
await service.create_bot({'name': 'New Bot'})
async def test_create_bot_no_limit(self):
"""Creates bot without limit check when max_bots=-1."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'system': {
'limitation': {
'max_bots': -1 # No limit
}
}
}
ap.platform_mgr = SimpleNamespace()
ap.platform_mgr.load_bot = AsyncMock()
# Mock pipeline query
pipeline_result = Mock()
pipeline_result.first = Mock(return_value=None)
# Mock bot query after insert
bot_result = Mock()
bot_result.first = Mock(return_value=_create_mock_bot())
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
if call_count <= 2:
return pipeline_result # First call: check pipeline
elif call_count == 3:
return Mock() # Insert
return bot_result # Get bot
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'new-uuid', 'name': 'New Bot'}
)
service = BotService(ap)
# Execute
bot_uuid = await service.create_bot({'name': 'New Bot', 'adapter': 'telegram', 'adapter_config': {}})
# Verify
assert bot_uuid is not None
assert len(bot_uuid) == 36 # UUID format
async def test_create_bot_sets_default_pipeline(self):
"""Sets default pipeline when one exists."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {'system': {'limitation': {'max_bots': -1}}}
ap.platform_mgr = SimpleNamespace()
ap.platform_mgr.load_bot = AsyncMock()
# Mock default pipeline
mock_pipeline = SimpleNamespace()
mock_pipeline.uuid = 'default-pipeline-uuid'
mock_pipeline.name = 'Default Pipeline'
pipeline_result = Mock()
pipeline_result.first = Mock(return_value=mock_pipeline)
# Mock bot after insert
bot_result = Mock()
bot_result.first = Mock(return_value=_create_mock_bot())
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
if call_count == 1:
return pipeline_result # Check default pipeline
elif call_count == 2:
return Mock() # Insert
return bot_result # Get bot
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
return_value={
'uuid': 'new-uuid',
'name': 'New Bot',
'use_pipeline_uuid': 'default-pipeline-uuid',
'use_pipeline_name': 'Default Pipeline',
}
)
service = BotService(ap)
# Execute
bot_data = {'name': 'New Bot', 'adapter': 'telegram', 'adapter_config': {}}
bot_uuid = await service.create_bot(bot_data)
# Verify - pipeline uuid and name were set
assert 'use_pipeline_uuid' in bot_data
assert 'use_pipeline_name' in bot_data
assert bot_uuid is not None # Verify UUID was returned
class TestBotServiceUpdateBot:
"""Tests for update_bot method."""
async def test_update_bot_removes_uuid_from_data(self):
"""Does not persist caller-provided uuid in update payload."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.platform_mgr = SimpleNamespace()
ap.platform_mgr.remove_bot = AsyncMock()
# Mock pipeline query - not updating pipeline
ap.persistence_mgr.execute_async = AsyncMock()
ap.sess_mgr = SimpleNamespace()
ap.sess_mgr.session_list = []
service = BotService(ap)
service.get_bot = AsyncMock(return_value={'uuid': 'test-uuid', 'name': 'Updated'})
# Create mock runtime bot
runtime_bot = SimpleNamespace()
runtime_bot.enable = False
ap.platform_mgr.load_bot = AsyncMock(return_value=runtime_bot)
# Execute
update_data = {'uuid': 'should-be-removed', 'name': 'Updated Name'}
await service.update_bot('test-uuid', update_data)
update_params = ap.persistence_mgr.execute_async.await_args_list[0].args[0].compile().params
assert update_params['name'] == 'Updated Name'
assert 'should-be-removed' not in update_params.values()
async def test_update_bot_pipeline_not_found_raises(self):
"""Raises Exception when updating with nonexistent pipeline UUID."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
# Mock pipeline query returns None
pipeline_result = Mock()
pipeline_result.first = Mock(return_value=None)
ap.persistence_mgr.execute_async = AsyncMock(return_value=pipeline_result)
service = BotService(ap)
# Execute & Verify
with pytest.raises(Exception, match='Pipeline not found'):
await service.update_bot('test-uuid', {'use_pipeline_uuid': 'nonexistent-pipeline'})
async def test_update_bot_sets_pipeline_name(self):
"""Sets use_pipeline_name when updating use_pipeline_uuid."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.platform_mgr = SimpleNamespace()
ap.platform_mgr.remove_bot = AsyncMock()
# Mock pipeline query
mock_pipeline = SimpleNamespace()
mock_pipeline.name = 'Updated Pipeline'
pipeline_result = Mock()
pipeline_result.first = Mock(return_value=mock_pipeline)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
if call_count == 1:
return pipeline_result
return Mock()
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.sess_mgr = SimpleNamespace()
ap.sess_mgr.session_list = []
service = BotService(ap)
service.get_bot = AsyncMock(return_value={'uuid': 'test-uuid'})
runtime_bot = SimpleNamespace()
runtime_bot.enable = False
ap.platform_mgr.load_bot = AsyncMock(return_value=runtime_bot)
# Execute
await service.update_bot('test-uuid', {'use_pipeline_uuid': 'pipeline-uuid'})
update_params = ap.persistence_mgr.execute_async.await_args_list[1].args[0].compile().params
assert update_params['use_pipeline_uuid'] == 'pipeline-uuid'
assert update_params['use_pipeline_name'] == 'Updated Pipeline'
class TestBotServiceDeleteBot:
"""Tests for delete_bot method."""
async def test_delete_bot_calls_remove_and_delete(self):
"""Calls both platform_mgr.remove_bot and persistence delete."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.persistence_mgr.execute_async = AsyncMock()
ap.platform_mgr = SimpleNamespace()
ap.platform_mgr.remove_bot = AsyncMock()
service = BotService(ap)
# Execute
await service.delete_bot('test-uuid')
# Verify
ap.platform_mgr.remove_bot.assert_called_once_with('test-uuid')
ap.persistence_mgr.execute_async.assert_called_once()
async def test_delete_bot_nonexistent_uuid(self):
"""Delete operation completes even for nonexistent UUID."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.persistence_mgr.execute_async = AsyncMock()
ap.platform_mgr = SimpleNamespace()
ap.platform_mgr.remove_bot = AsyncMock()
service = BotService(ap)
# Execute - should not raise
await service.delete_bot('nonexistent-uuid')
# Verify - both called regardless
ap.platform_mgr.remove_bot.assert_called_once()
class TestBotServiceListEventLogs:
"""Tests for list_event_logs method."""
async def test_list_event_logs_bot_not_found_raises(self):
"""Raises Exception when runtime bot not found."""
# Setup
ap = SimpleNamespace()
ap.platform_mgr = SimpleNamespace()
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=None)
service = BotService(ap)
# Execute & Verify
with pytest.raises(Exception, match='Bot not found'):
await service.list_event_logs('nonexistent-uuid', 0, 10)
async def test_list_event_logs_returns_logs(self):
"""Returns logs from runtime bot logger."""
# Setup
ap = SimpleNamespace()
ap.platform_mgr = SimpleNamespace()
# Mock runtime bot with logger
runtime_bot = SimpleNamespace()
runtime_bot.logger = SimpleNamespace()
runtime_bot.logger.get_logs = AsyncMock(return_value=(
[SimpleNamespace(to_json=Mock(return_value={'msg': 'log1'}))],
5
))
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=runtime_bot)
service = BotService(ap)
# Execute
logs, total = await service.list_event_logs('bot-uuid', 0, 10)
# Verify
assert len(logs) == 1
assert logs[0] == {'msg': 'log1'}
assert total == 5
class TestBotServiceSendMessage:
"""Tests for send_message method."""
async def test_send_message_bot_not_found_raises(self):
"""Raises Exception when bot not found."""
# Setup
ap = SimpleNamespace()
ap.platform_mgr = SimpleNamespace()
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=None)
service = BotService(ap)
# Execute & Verify
with pytest.raises(Exception, match='Bot not found'):
await service.send_message('nonexistent-uuid', 'group', '123', {'test': 'data'})
async def test_send_message_invalid_message_chain_raises(self):
"""Raises Exception when message_chain_data is invalid."""
# Setup
ap = SimpleNamespace()
ap.platform_mgr = SimpleNamespace()
runtime_bot = SimpleNamespace()
runtime_bot.adapter = SimpleNamespace()
runtime_bot.adapter.send_message = AsyncMock()
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=runtime_bot)
service = BotService(ap)
# Execute & Verify - invalid format should raise
with pytest.raises(Exception, match='Invalid message_chain format'):
await service.send_message('bot-uuid', 'group', '123', {'invalid': 'format'})
async def test_send_message_valid_call(self):
"""Sends message through adapter when all valid."""
# Setup
ap = SimpleNamespace()
ap.platform_mgr = SimpleNamespace()
runtime_bot = SimpleNamespace()
runtime_bot.adapter = SimpleNamespace()
runtime_bot.adapter.send_message = AsyncMock()
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=runtime_bot)
service = BotService(ap)
# Execute with valid message chain format
message_chain_data = {
'messages': [
{'type': 'text', 'data': {'text': 'Hello'}}
]
}
# Patch the import location - the module imports inside the function
with patch('langbot_plugin.api.entities.builtin.platform.message.MessageChain') as MockMessageChain:
mock_chain = Mock()
MockMessageChain.model_validate = Mock(return_value=mock_chain)
await service.send_message('bot-uuid', 'group', '123', message_chain_data)
# Verify adapter.send_message was called
runtime_bot.adapter.send_message.assert_called_once_with('group', '123', mock_chain)

View File

@@ -1,397 +0,0 @@
"""Unit tests for API knowledge service.
Tests cover:
- Knowledge base CRUD operations
- Capability checking
- Knowledge engine discovery
- File operations
"""
from __future__ import annotations
import pytest
from unittest.mock import Mock, AsyncMock
from importlib import import_module
def get_knowledge_service_module():
"""Lazy import to avoid circular import issues."""
return import_module('langbot.pkg.api.http.service.knowledge')
def create_mock_app():
"""Create mock Application for testing."""
mock_app = Mock()
mock_app.logger = Mock()
mock_app.rag_mgr = AsyncMock()
mock_app.persistence_mgr = AsyncMock()
mock_app.persistence_mgr.execute_async = AsyncMock()
mock_app.persistence_mgr.serialize_model = Mock(return_value={})
mock_app.plugin_connector = AsyncMock()
mock_app.plugin_connector.is_enable_plugin = True
return mock_app
class TestKnowledgeServiceInit:
"""Tests for KnowledgeService initialization."""
def test_init_stores_app_reference(self):
"""Test that __init__ stores Application reference."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
service = knowledge_module.KnowledgeService(mock_app)
assert service.ap is mock_app
class TestGetKnowledgeBases:
"""Tests for get_knowledge_bases method."""
@pytest.mark.asyncio
async def test_returns_all_kb_details(self):
"""Test that it returns all knowledge base details."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.rag_mgr.get_all_knowledge_base_details = AsyncMock(
return_value=[{'uuid': 'kb1', 'name': 'KB1'}]
)
service = knowledge_module.KnowledgeService(mock_app)
result = await service.get_knowledge_bases()
assert len(result) == 1
assert result[0]['uuid'] == 'kb1'
@pytest.mark.asyncio
async def test_returns_empty_list_when_no_kbs(self):
"""Test that it returns empty list when no knowledge bases."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.rag_mgr.get_all_knowledge_base_details = AsyncMock(return_value=[])
service = knowledge_module.KnowledgeService(mock_app)
result = await service.get_knowledge_bases()
assert result == []
class TestGetKnowledgeBase:
"""Tests for get_knowledge_base method."""
@pytest.mark.asyncio
async def test_returns_kb_details_by_uuid(self):
"""Test that it returns specific KB details."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(
return_value={'uuid': 'kb1', 'name': 'KB1'}
)
service = knowledge_module.KnowledgeService(mock_app)
result = await service.get_knowledge_base('kb1')
assert result['uuid'] == 'kb1'
@pytest.mark.asyncio
async def test_returns_none_when_not_found(self):
"""Test that it returns None when KB not found."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(return_value=None)
service = knowledge_module.KnowledgeService(mock_app)
result = await service.get_knowledge_base('nonexistent')
assert result is None
class TestCreateKnowledgeBase:
"""Tests for create_knowledge_base method."""
@pytest.mark.asyncio
async def test_creates_kb_with_required_fields(self):
"""Test creating KB with required plugin ID."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_kb = Mock()
mock_kb.uuid = 'new_kb_uuid'
mock_app.rag_mgr.create_knowledge_base = AsyncMock(return_value=mock_kb)
service = knowledge_module.KnowledgeService(mock_app)
kb_data = {
'name': 'Test KB',
'knowledge_engine_plugin_id': 'author/engine',
'description': 'Test description',
}
result = await service.create_knowledge_base(kb_data)
assert result == 'new_kb_uuid'
mock_app.rag_mgr.create_knowledge_base.assert_called_once()
@pytest.mark.asyncio
async def test_raises_when_missing_plugin_id(self):
"""Test that ValueError is raised when plugin ID missing."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
service = knowledge_module.KnowledgeService(mock_app)
with pytest.raises(ValueError) as exc_info:
await service.create_knowledge_base({'name': 'Test'})
assert 'knowledge_engine_plugin_id is required' in str(exc_info.value)
@pytest.mark.asyncio
async def test_creates_with_default_name(self):
"""Test that KB is created with default name if not provided."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_kb = Mock()
mock_kb.uuid = 'new_kb_uuid'
mock_app.rag_mgr.create_knowledge_base = AsyncMock(return_value=mock_kb)
service = knowledge_module.KnowledgeService(mock_app)
await service.create_knowledge_base({
'knowledge_engine_plugin_id': 'author/engine'
})
# Check that default name 'Untitled' was used
call_args = mock_app.rag_mgr.create_knowledge_base.call_args
assert call_args.kwargs['name'] == 'Untitled'
class TestUpdateKnowledgeBase:
"""Tests for update_knowledge_base method."""
@pytest.mark.asyncio
async def test_updates_mutable_fields_only(self):
"""Test that only mutable fields are updated."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(
return_value={'uuid': 'kb1', 'name': 'Updated'}
)
mock_app.rag_mgr.remove_knowledge_base_from_runtime = AsyncMock()
mock_app.rag_mgr.load_knowledge_base = AsyncMock()
service = knowledge_module.KnowledgeService(mock_app)
# Pass both mutable and immutable fields
await service.update_knowledge_base('kb1', {
'name': 'New Name',
'description': 'New desc',
'uuid': 'should_be_filtered', # immutable
})
# Check that only mutable fields were passed to update
call_args = mock_app.persistence_mgr.execute_async.call_args
assert call_args is not None
@pytest.mark.asyncio
async def test_returns_early_when_no_mutable_fields(self):
"""Test that update returns early when no mutable fields provided."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
service = knowledge_module.KnowledgeService(mock_app)
# Pass only immutable fields
await service.update_knowledge_base('kb1', {'uuid': 'should_be_filtered'})
# No DB update should be called
mock_app.persistence_mgr.execute_async.assert_not_called()
class TestCheckDocCapability:
"""Tests for _check_doc_capability method."""
@pytest.mark.asyncio
async def test_passes_when_capability_supported(self):
"""Test that check passes when doc_ingestion capability exists."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(
return_value={'knowledge_engine': {'capabilities': ['doc_ingestion']}}
)
service = knowledge_module.KnowledgeService(mock_app)
await service._check_doc_capability('kb1', 'document upload')
# No exception raised means success
@pytest.mark.asyncio
async def test_raises_when_kb_not_found(self):
"""Test that Exception is raised when KB not found."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(return_value=None)
service = knowledge_module.KnowledgeService(mock_app)
with pytest.raises(Exception) as exc_info:
await service._check_doc_capability('nonexistent', 'test operation')
assert 'Knowledge base not found' in str(exc_info.value)
@pytest.mark.asyncio
async def test_raises_when_capability_not_supported(self):
"""Test that Exception is raised when doc_ingestion not in capabilities."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(
return_value={'knowledge_engine': {'capabilities': ['other_capability']}}
)
service = knowledge_module.KnowledgeService(mock_app)
with pytest.raises(Exception) as exc_info:
await service._check_doc_capability('kb1', 'document upload')
assert 'does not support document upload' in str(exc_info.value)
class TestListKnowledgeEngines:
"""Tests for list_knowledge_engines method."""
@pytest.mark.asyncio
async def test_returns_engines_from_plugin_connector(self):
"""Test that it returns knowledge engines from plugin connector."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(
return_value=[{'id': 'engine1', 'name': 'Engine 1'}]
)
service = knowledge_module.KnowledgeService(mock_app)
result = await service.list_knowledge_engines()
assert len(result) == 1
assert result[0]['id'] == 'engine1'
@pytest.mark.asyncio
async def test_returns_empty_when_plugin_disabled(self):
"""Test that it returns empty list when plugin disabled."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.plugin_connector.is_enable_plugin = False
service = knowledge_module.KnowledgeService(mock_app)
result = await service.list_knowledge_engines()
assert result == []
@pytest.mark.asyncio
async def test_returns_empty_on_exception(self):
"""Test that it returns empty list and logs warning on exception."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(
side_effect=Exception('Connection error')
)
service = knowledge_module.KnowledgeService(mock_app)
result = await service.list_knowledge_engines()
assert result == []
mock_app.logger.warning.assert_called_once()
class TestListParsers:
"""Tests for list_parsers method."""
@pytest.mark.asyncio
async def test_returns_all_parsers(self):
"""Test that it returns all parsers when no MIME type filter."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.plugin_connector.list_parsers = AsyncMock(
return_value=[
{'id': 'parser1', 'supported_mime_types': ['text/plain']},
{'id': 'parser2', 'supported_mime_types': ['application/pdf']},
]
)
service = knowledge_module.KnowledgeService(mock_app)
result = await service.list_parsers()
assert len(result) == 2
@pytest.mark.asyncio
async def test_filters_by_mime_type(self):
"""Test that it filters parsers by MIME type."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.plugin_connector.list_parsers = AsyncMock(
return_value=[
{'id': 'parser1', 'supported_mime_types': ['text/plain']},
{'id': 'parser2', 'supported_mime_types': ['application/pdf']},
]
)
service = knowledge_module.KnowledgeService(mock_app)
result = await service.list_parsers(mime_type='application/pdf')
assert len(result) == 1
assert result[0]['id'] == 'parser2'
@pytest.mark.asyncio
async def test_returns_empty_when_plugin_disabled(self):
"""Test that it returns empty list when plugin disabled."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.plugin_connector.is_enable_plugin = False
service = knowledge_module.KnowledgeService(mock_app)
result = await service.list_parsers()
assert result == []
class TestGetEngineSchemas:
"""Tests for get_engine_creation_schema and get_engine_retrieval_schema."""
@pytest.mark.asyncio
async def test_returns_creation_schema(self):
"""Test that it returns creation schema for engine."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.plugin_connector.get_rag_creation_schema = AsyncMock(
return_value={'properties': {'name': {'type': 'string'}}}
)
service = knowledge_module.KnowledgeService(mock_app)
result = await service.get_engine_creation_schema('author/engine')
assert 'properties' in result
@pytest.mark.asyncio
async def test_returns_retrieval_schema(self):
"""Test that it returns retrieval schema for engine."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.plugin_connector.get_rag_retrieval_schema = AsyncMock(
return_value={'properties': {'top_k': {'type': 'integer'}}}
)
service = knowledge_module.KnowledgeService(mock_app)
result = await service.get_engine_retrieval_schema('author/engine')
assert 'properties' in result
@pytest.mark.asyncio
async def test_returns_empty_dict_on_exception(self):
"""Test that it returns empty dict and logs warning on exception."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.plugin_connector.get_rag_creation_schema = AsyncMock(
side_effect=Exception('Plugin error')
)
service = knowledge_module.KnowledgeService(mock_app)
result = await service.get_engine_creation_schema('author/engine')
assert result == {}
mock_app.logger.warning.assert_called_once()

View File

@@ -1,824 +0,0 @@
"""
Unit tests for MaintenanceService.
Tests storage maintenance and diagnostics including:
- Cleanup expired files
- Storage analysis
- File counting and sizing
- Monitoring counts
- Binary storage stats
Source: src/langbot/pkg/api/http/service/maintenance.py
"""
from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, Mock, patch, MagicMock
from types import SimpleNamespace
import datetime
from pathlib import Path
from langbot.pkg.api.http.service.maintenance import MaintenanceService
pytestmark = pytest.mark.asyncio
def _create_mock_result(scalar_value=None):
"""Create mock result object for persistence queries."""
result = Mock()
result.scalar = Mock(return_value=scalar_value)
return result
class TestMaintenanceServiceCleanupExpiredFiles:
"""Tests for cleanup_expired_files method."""
async def test_cleanup_expired_files_default_retention(self):
"""Uses default retention days when config not set."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {}
ap.storage_mgr = SimpleNamespace()
# Create a proper mock object with __class__.__name__
storage_provider = MagicMock()
storage_provider.__class__.__name__ = 'LocalStorageProvider'
ap.storage_mgr.storage_provider = storage_provider
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
service = MaintenanceService(ap)
# Mock the internal cleanup methods - one is async, one is not
service._cleanup_expired_uploaded_files = AsyncMock(return_value=0)
service._cleanup_expired_log_files = Mock(return_value=0) # NOT async!
# Execute
result = await service.cleanup_expired_files()
# Verify - returns counts
assert 'uploaded_files' in result
assert 'log_files' in result
assert result['uploaded_files'] == 0
assert result['log_files'] == 0
async def test_cleanup_expired_files_custom_retention(self):
"""Uses custom retention days from config."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'storage': {
'cleanup': {
'uploaded_file_retention_days': 14,
'log_retention_days': 7,
}
}
}
ap.storage_mgr = SimpleNamespace()
storage_provider = MagicMock()
storage_provider.__class__.__name__ = 'LocalStorageProvider'
ap.storage_mgr.storage_provider = storage_provider
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
service = MaintenanceService(ap)
# Mock the internal cleanup methods
service._cleanup_expired_uploaded_files = AsyncMock(return_value=2)
service._cleanup_expired_log_files = Mock(return_value=3) # NOT async
# Execute
result = await service.cleanup_expired_files()
# Verify
assert result['uploaded_files'] == 2
assert result['log_files'] == 3
async def test_cleanup_expired_files_s3_provider(self):
"""Handles S3StorageProvider correctly."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {}
ap.storage_mgr = SimpleNamespace()
# Mock S3 provider
s3_provider = MagicMock()
s3_provider.__class__.__name__ = 'S3StorageProvider'
s3_provider.delete = AsyncMock()
ap.storage_mgr.storage_provider = s3_provider
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
service = MaintenanceService(ap)
# Mock the internal cleanup methods
service._cleanup_expired_uploaded_files = AsyncMock(return_value=1)
service._cleanup_expired_log_files = Mock(return_value=0) # NOT async
# Execute
result = await service.cleanup_expired_files()
# Verify
assert result['uploaded_files'] == 1
assert result['log_files'] == 0
async def test_cleanup_expired_files_invalid_retention(self):
"""Uses default for invalid retention config."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'storage': {
'cleanup': {
'uploaded_file_retention_days': 'invalid', # Invalid
'log_retention_days': 0, # Invalid (less than 1)
}
}
}
ap.storage_mgr = SimpleNamespace()
storage_provider = MagicMock()
storage_provider.__class__.__name__ = 'LocalStorageProvider'
ap.storage_mgr.storage_provider = storage_provider
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
service = MaintenanceService(ap)
# Mock the internal cleanup methods
service._cleanup_expired_uploaded_files = AsyncMock(return_value=0)
service._cleanup_expired_log_files = Mock(return_value=0) # NOT async
# Execute
result = await service.cleanup_expired_files()
# Verify - warning logged, defaults used
assert ap.logger.warning.called
assert 'uploaded_files' in result
class TestMaintenanceServiceGetStorageAnalysis:
"""Tests for get_storage_analysis method."""
async def test_get_storage_analysis_basic(self):
"""Returns basic storage analysis."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'database': {'use': 'sqlite', 'sqlite': {'path': 'data/langbot.db'}}
}
ap.persistence_mgr = SimpleNamespace()
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
ap.task_mgr = SimpleNamespace()
ap.task_mgr.get_stats = Mock(return_value={'running': 0})
# Mock monitoring counts
count_result = _create_mock_result(scalar_value=10)
ap.persistence_mgr.execute_async = AsyncMock(return_value=count_result)
service = MaintenanceService(ap)
# Mock file operations
service._path_size = Mock(return_value=1000)
service._file_count = Mock(return_value=5)
service._monitoring_counts = AsyncMock(return_value={'messages': 10, 'errors': 0})
service._binary_storage_stats = AsyncMock(return_value={'count': 5, 'size_bytes': 500})
service._expired_uploaded_candidates = AsyncMock(return_value=[])
service._expired_log_candidates = Mock(return_value=[])
# Execute
result = await service.get_storage_analysis()
# Verify
assert 'generated_at' in result
assert 'cleanup_policy' in result
assert 'sections' in result
assert 'database' in result
assert 'cleanup_candidates' in result
async def test_get_storage_analysis_sections(self):
"""Returns all storage sections."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {'database': {'use': 'postgresql'}}
ap.persistence_mgr = SimpleNamespace()
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
ap.task_mgr = None
count_result = _create_mock_result(scalar_value=0)
ap.persistence_mgr.execute_async = AsyncMock(return_value=count_result)
service = MaintenanceService(ap)
service._path_size = Mock(return_value=0)
service._file_count = Mock(return_value=0)
service._monitoring_counts = AsyncMock(return_value={})
service._binary_storage_stats = AsyncMock(return_value={'count': 0, 'size_bytes': 0})
service._expired_uploaded_candidates = AsyncMock(return_value=[])
service._expired_log_candidates = Mock(return_value=[])
# Execute
result = await service.get_storage_analysis()
# Verify - all sections present
sections = {s['key'] for s in result['sections']}
assert 'database' in sections
assert 'logs' in sections
assert 'storage' in sections
assert 'vector_store' in sections
assert 'plugins' in sections
assert 'mcp' in sections
assert 'temp' in sections
async def test_get_storage_analysis_postgresql(self):
"""Handles PostgreSQL database type."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {'database': {'use': 'postgresql'}}
ap.persistence_mgr = SimpleNamespace()
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
ap.task_mgr = None
count_result = _create_mock_result(scalar_value=0)
ap.persistence_mgr.execute_async = AsyncMock(return_value=count_result)
service = MaintenanceService(ap)
service._path_size = Mock(return_value=0)
service._file_count = Mock(return_value=0)
service._monitoring_counts = AsyncMock(return_value={})
service._binary_storage_stats = AsyncMock(return_value={'count': 0, 'size_bytes': None})
service._expired_uploaded_candidates = AsyncMock(return_value=[])
service._expired_log_candidates = Mock(return_value=[])
# Execute
result = await service.get_storage_analysis()
# Verify
assert result['database']['type'] == 'postgresql'
async def test_get_storage_analysis_with_cleanup_candidates(self):
"""Returns cleanup candidates in analysis."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {}
ap.persistence_mgr = SimpleNamespace()
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
ap.task_mgr = None
count_result = _create_mock_result(scalar_value=0)
ap.persistence_mgr.execute_async = AsyncMock(return_value=count_result)
service = MaintenanceService(ap)
service._path_size = Mock(return_value=0)
service._file_count = Mock(return_value=0)
service._monitoring_counts = AsyncMock(return_value={})
service._binary_storage_stats = AsyncMock(return_value={'count': 0, 'size_bytes': 0})
service._expired_uploaded_candidates = AsyncMock(return_value=[
{'key': 'old_file', 'size_bytes': 100}
])
service._expired_log_candidates = Mock(return_value=[
{'name': 'old_log', 'size_bytes': 50}
])
# Execute
result = await service.get_storage_analysis()
# Verify
assert len(result['cleanup_candidates']['uploaded_files']) == 1
assert len(result['cleanup_candidates']['log_files']) == 1
class TestMaintenanceServiceMonitoringCounts:
"""Tests for _monitoring_counts method."""
async def test_monitoring_counts_returns_counts(self):
"""Returns counts for all monitoring tables."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
count_result = _create_mock_result(scalar_value=42)
ap.persistence_mgr.execute_async = AsyncMock(return_value=count_result)
service = MaintenanceService(ap)
# Execute
result = await service._monitoring_counts()
# Verify - all table keys present
assert 'messages' in result
assert 'llm_calls' in result
assert 'embedding_calls' in result
assert 'errors' in result
assert 'sessions' in result
assert 'feedback' in result
async def test_monitoring_counts_zero_results(self):
"""Returns zero counts when tables empty."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
count_result = _create_mock_result(scalar_value=0)
ap.persistence_mgr.execute_async = AsyncMock(return_value=count_result)
service = MaintenanceService(ap)
# Execute
result = await service._monitoring_counts()
# Verify - all zero
assert all(v == 0 for v in result.values())
class TestMaintenanceServiceBinaryStorageStats:
"""Tests for _binary_storage_stats method."""
async def test_binary_storage_stats_returns_stats(self):
"""Returns count and size for binary storage."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
# Mock count result
count_result = _create_mock_result(scalar_value=10)
# Mock size result
size_result = _create_mock_result(scalar_value=5000)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
if call_count == 1:
return count_result
return size_result
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
service = MaintenanceService(ap)
# Execute
result = await service._binary_storage_stats()
# Verify
assert result['count'] == 10
assert result['size_bytes'] == 5000
async def test_binary_storage_stats_size_error(self):
"""Handles error when calculating size."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
count_result = _create_mock_result(scalar_value=5)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
if call_count == 1:
return count_result
raise Exception('Size calculation error')
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
service = MaintenanceService(ap)
# Execute
result = await service._binary_storage_stats()
# Verify - warning logged, size_bytes None or 0
assert ap.logger.warning.called
assert result['count'] == 5
class TestMaintenanceServicePathSize:
"""Tests for _path_size method."""
def test_path_size_nonexistent_path(self):
"""Returns 0 for nonexistent path."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
service = MaintenanceService(ap)
# Execute
result = service._path_size(Path('/nonexistent/path'))
# Verify
assert result == 0
def test_path_size_single_file(self):
"""Returns size for single file."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
service = MaintenanceService(ap)
# Mock file
mock_stat = Mock()
mock_stat.st_size = 100
with patch.object(Path, 'exists', return_value=True):
with patch.object(Path, 'is_file', return_value=True):
with patch.object(Path, 'stat', return_value=mock_stat):
result = service._path_size(Path('test.txt'))
# Verify
assert result == 100
def test_path_size_directory(self):
"""Returns total size for directory."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
service = MaintenanceService(ap)
# Mock os.walk
with patch.object(Path, 'exists', return_value=True):
with patch.object(Path, 'is_file', return_value=False):
with patch('os.walk') as mock_walk:
mock_walk.return_value = [
('/test_dir', [], ['file1.txt', 'file2.txt']),
]
# Mock file stat
mock_stat = Mock()
mock_stat.st_size = 50
with patch.object(Path, 'stat', return_value=mock_stat):
result = service._path_size(Path('/test_dir'))
# Verify - 2 files * 50 bytes
assert result == 100
class TestMaintenanceServiceFileCount:
"""Tests for _file_count method."""
def test_file_count_nonexistent_path(self):
"""Returns 0 for nonexistent path."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
service = MaintenanceService(ap)
# Execute
result = service._file_count(Path('/nonexistent/path'))
# Verify
assert result == 0
def test_file_count_single_file(self):
"""Returns 1 for single file."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
service = MaintenanceService(ap)
with patch.object(Path, 'exists', return_value=True):
with patch.object(Path, 'is_file', return_value=True):
result = service._file_count(Path('test.txt'))
# Verify
assert result == 1
def test_file_count_directory(self):
"""Returns file count for directory."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
service = MaintenanceService(ap)
with patch.object(Path, 'exists', return_value=True):
with patch.object(Path, 'is_file', return_value=False):
with patch('os.walk') as mock_walk:
mock_walk.return_value = [
('/test_dir', [], ['file1.txt', 'file2.txt', 'file3.txt']),
]
result = service._file_count(Path('/test_dir'))
# Verify
assert result == 3
class TestMaintenanceServicePositiveInt:
"""Tests for _positive_int helper method."""
def test_positive_int_valid_value(self):
"""Returns valid positive integer."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
service = MaintenanceService(ap)
# Execute
result = service._positive_int(7, 5, 'test_param')
# Verify
assert result == 7
assert not ap.logger.warning.called
def test_positive_int_invalid_string(self):
"""Returns default for invalid string."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
service = MaintenanceService(ap)
# Execute
result = service._positive_int('invalid', 5, 'test_param')
# Verify
assert result == 5
assert ap.logger.warning.called
def test_positive_int_invalid_none(self):
"""Returns default for None."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
service = MaintenanceService(ap)
# Execute
result = service._positive_int(None, 5, 'test_param')
# Verify
assert result == 5
assert ap.logger.warning.called
def test_positive_int_negative_value(self):
"""Returns default for negative value."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
service = MaintenanceService(ap)
# Execute
result = service._positive_int(-1, 5, 'test_param')
# Verify
assert result == 5
assert ap.logger.warning.called
def test_positive_int_zero_value(self):
"""Returns default for zero value."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
service = MaintenanceService(ap)
# Execute
result = service._positive_int(0, 5, 'test_param')
# Verify
assert result == 5
assert ap.logger.warning.called
class TestMaintenanceServiceIsUploadedFileKey:
"""Tests for _is_uploaded_file_key helper method."""
def test_is_uploaded_file_key_valid(self):
"""Returns True for valid upload file key."""
# Setup
ap = SimpleNamespace()
service = MaintenanceService(ap)
# Execute - simple filename without path
result = service._is_uploaded_file_key('uploaded_file.txt')
# Verify
assert result is True
def test_is_uploaded_file_key_with_path(self):
"""Returns False for key with path separator."""
# Setup
ap = SimpleNamespace()
service = MaintenanceService(ap)
# Execute - key with path
result = service._is_uploaded_file_key('path/to/file.txt')
# Verify
assert result is False
def test_is_uploaded_file_key_plugin_config(self):
"""Returns False for plugin config prefix."""
# Setup
ap = SimpleNamespace()
service = MaintenanceService(ap)
# Execute - plugin config file
result = service._is_uploaded_file_key('plugin_config_some_plugin.json')
# Verify
assert result is False
class TestMaintenanceServiceExpiredLogCandidates:
"""Tests for _expired_log_candidates method."""
def test_expired_log_candidates_nonexistent_dir(self):
"""Returns empty list when logs dir not exists."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
service = MaintenanceService(ap)
with patch.object(Path, 'exists', return_value=False):
result = service._expired_log_candidates(3)
# Verify
assert result == []
def test_expired_log_candidates_matches_pattern(self):
"""Matches log file pattern correctly."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
service = MaintenanceService(ap)
# Mock directory with log files
old_date = datetime.date.today() - datetime.timedelta(days=10)
old_log_name = f'langbot-{old_date.isoformat()}.log'
recent_log_name = f'langbot-{datetime.date.today().isoformat()}.log'
mock_entry_old = Mock(spec=Path)
mock_entry_old.is_file = Mock(return_value=True)
mock_entry_old.name = old_log_name
mock_stat = Mock()
mock_stat.st_size = 1000
mock_entry_old.stat = Mock(return_value=mock_stat)
mock_entry_recent = Mock(spec=Path)
mock_entry_recent.is_file = Mock(return_value=True)
mock_entry_recent.name = recent_log_name
mock_stat2 = Mock()
mock_stat2.st_size = 500
mock_entry_recent.stat = Mock(return_value=mock_stat2)
# Non-log file
mock_entry_other = Mock(spec=Path)
mock_entry_other.is_file = Mock(return_value=True)
mock_entry_other.name = 'other_file.txt'
with patch.object(Path, 'exists', return_value=True):
with patch.object(Path, 'iterdir') as mock_iterdir:
mock_iterdir.return_value = [mock_entry_old, mock_entry_recent, mock_entry_other]
result = service._expired_log_candidates(3)
# Verify - only old log included
assert len(result) == 1
assert result[0]['name'] == old_log_name
def test_expired_log_candidates_includes_path(self):
"""Includes path when include_paths=True."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
service = MaintenanceService(ap)
old_date = datetime.date.today() - datetime.timedelta(days=10)
old_log_name = f'langbot-{old_date.isoformat()}.log'
mock_entry = Mock(spec=Path)
mock_entry.is_file = Mock(return_value=True)
mock_entry.name = old_log_name
mock_entry.__str__ = Mock(return_value='/data/logs/' + old_log_name)
mock_stat = Mock()
mock_stat.st_size = 1000
mock_entry.stat = Mock(return_value=mock_stat)
with patch.object(Path, 'exists', return_value=True):
with patch.object(Path, 'iterdir') as mock_iterdir:
mock_iterdir.return_value = [mock_entry]
result = service._expired_log_candidates(3, include_paths=True)
# Verify - path included
assert 'path' in result[0]
class TestMaintenanceServiceExpiredLocalUploadCandidates:
"""Tests for _expired_local_upload_candidates method."""
def test_expired_local_upload_candidates_nonexistent_dir(self):
"""Returns empty list when storage dir not exists."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
service = MaintenanceService(ap)
with patch.object(Path, 'exists', return_value=False):
result = service._expired_local_upload_candidates(7)
# Verify
assert result == []
def test_expired_local_upload_candidates_filters_uploaded(self):
"""Only returns uploaded files matching pattern."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
service = MaintenanceService(ap)
# Mock _is_uploaded_file_key
service._is_uploaded_file_key = Mock(side_effect=lambda key: 'plugin_config_' not in key and '/' not in key)
# Create mock files - one valid, one plugin config
mock_entry_valid = Mock(spec=Path)
mock_entry_valid.is_file = Mock(return_value=True)
mock_entry_valid.name = 'valid_upload.txt'
mock_stat = Mock()
mock_stat.st_size = 100
mock_stat.st_mtime = 0 # Very old
mock_entry_valid.stat = Mock(return_value=mock_stat)
mock_entry_plugin = Mock(spec=Path)
mock_entry_plugin.is_file = Mock(return_value=True)
mock_entry_plugin.name = 'plugin_config_test.json'
mock_stat2 = Mock()
mock_stat2.st_size = 200
mock_stat2.st_mtime = 0
mock_entry_plugin.stat = Mock(return_value=mock_stat2)
with patch.object(Path, 'exists', return_value=True):
with patch.object(Path, 'iterdir') as mock_iterdir:
mock_iterdir.return_value = [mock_entry_valid, mock_entry_plugin]
result = service._expired_local_upload_candidates(7)
# Verify - only valid upload included
assert len(result) == 1
assert result[0]['key'] == 'valid_upload.txt'
def test_expired_local_upload_candidates_includes_path(self):
"""Includes path when include_paths=True."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
service = MaintenanceService(ap)
service._is_uploaded_file_key = Mock(return_value=True)
mock_entry = Mock(spec=Path)
mock_entry.is_file = Mock(return_value=True)
mock_entry.name = 'old_file.txt'
mock_entry.__str__ = Mock(return_value='/data/storage/old_file.txt')
mock_stat = Mock()
mock_stat.st_size = 100
mock_stat.st_mtime = 0
mock_entry.stat = Mock(return_value=mock_stat)
with patch.object(Path, 'exists', return_value=True):
with patch.object(Path, 'iterdir') as mock_iterdir:
mock_iterdir.return_value = [mock_entry]
result = service._expired_local_upload_candidates(7, include_paths=True)
# Verify - path included
assert 'path' in result[0]

Some files were not shown because too many files have changed in this diff Show More