mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 12:05:54 +00:00
Compare commits
251 Commits
v4.9.0
...
feat/workf
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
192b69b0fb | ||
|
|
543fbd8ca0 | ||
|
|
804448b6cd | ||
|
|
5d4e40459f | ||
|
|
b5c43cc113 | ||
|
|
8ebfcd963a | ||
|
|
127198675e | ||
|
|
44fb188994 | ||
|
|
265385a563 | ||
|
|
253cc6cbea | ||
|
|
f99d3022e8 | ||
|
|
6823069103 | ||
|
|
699545a196 | ||
|
|
5c5614667a | ||
|
|
f0061817ea | ||
|
|
688202e7d1 | ||
|
|
d46b762d03 | ||
|
|
0963fd5443 | ||
|
|
6471770737 | ||
|
|
314b7d15bb | ||
|
|
c758908745 | ||
|
|
313d553271 | ||
|
|
bb7db53447 | ||
|
|
767137aaa0 | ||
|
|
acb2ce6a40 | ||
|
|
67784708d6 | ||
|
|
1bd9c334aa | ||
|
|
17bbc8bf10 | ||
|
|
4a4c0921a4 | ||
|
|
e425cf079a | ||
|
|
245e798b79 | ||
|
|
27fdccce16 | ||
|
|
484643c0ee | ||
|
|
ec61459619 | ||
|
|
66ef744447 | ||
|
|
10d3a9cc92 | ||
|
|
885320e9ae | ||
|
|
ed02ac4710 | ||
|
|
e4841edbaf | ||
|
|
ef7a06b0db | ||
|
|
6fe20c1812 | ||
|
|
9e8c8f79df | ||
|
|
01d06898fb | ||
|
|
0a669c7016 | ||
|
|
27c0d344bf | ||
|
|
c088dc114f | ||
|
|
37c74b0622 | ||
|
|
b251fc4b89 | ||
|
|
075c85e2bc | ||
|
|
62b63ca2ca | ||
|
|
3680a80248 | ||
|
|
6713b57d01 | ||
|
|
ea13ef87f2 | ||
|
|
59bd581e88 | ||
|
|
cba83a62e8 | ||
|
|
f412127fb0 | ||
|
|
5273bbb23f | ||
|
|
0ceab3f6a5 | ||
|
|
aedc097188 | ||
|
|
18b27dd9ef | ||
|
|
3f50a56623 | ||
|
|
c9f7911efe | ||
|
|
75fdfe6806 | ||
|
|
eb9f38b102 | ||
|
|
fc40d3c949 | ||
|
|
d176a448e0 | ||
|
|
ada4c30f85 | ||
|
|
32c9eaff45 | ||
|
|
9706ee2d53 | ||
|
|
e7c9bc69d3 | ||
|
|
1fcdbd472f | ||
|
|
547006cb4a | ||
|
|
92bf9a7ea5 | ||
|
|
832efb4069 | ||
|
|
8f1847d480 | ||
|
|
fe619e415f | ||
|
|
0154ea6cd3 | ||
|
|
8db55267d8 | ||
|
|
b9662250a6 | ||
|
|
d9378c3a88 | ||
|
|
86a4d1bf0b | ||
|
|
ce6e79db8e | ||
|
|
d53e2cb9a0 | ||
|
|
c1168745b7 | ||
|
|
69b87a0d8a | ||
|
|
6637b153f1 | ||
|
|
e768fc6116 | ||
|
|
2442d3bf52 | ||
|
|
42d78817f4 | ||
|
|
4b9f25a05d | ||
|
|
d1f0e07cc0 | ||
|
|
78e55509ae | ||
|
|
2c28635a39 | ||
|
|
5f3cecfbe2 | ||
|
|
12df9d6ee9 | ||
|
|
195f6efeff | ||
|
|
564d829e25 | ||
|
|
58c1916712 | ||
|
|
a8fba46040 | ||
|
|
3115d6f6dd | ||
|
|
323481d69b | ||
|
|
5a5c4295b1 | ||
|
|
88111d87ac | ||
|
|
4e5a6ee79a | ||
|
|
05c684d757 | ||
|
|
2838020580 | ||
|
|
9b34ae2db4 | ||
|
|
f8010a20eb | ||
|
|
917edb3413 | ||
|
|
10425ede34 | ||
|
|
e4b40a8fa0 | ||
|
|
0b8ab4b54b | ||
|
|
49239e0e08 | ||
|
|
aec2a30445 | ||
|
|
c8915ca964 | ||
|
|
a715eddd06 | ||
|
|
2f9c235b41 | ||
|
|
cc4d8838eb | ||
|
|
fa0a77f09f | ||
|
|
fd6a7b73d4 | ||
|
|
bf0848d60b | ||
|
|
e06fac2bb7 | ||
|
|
bec61427a0 | ||
|
|
5fae7b2eb0 | ||
|
|
2eebdfe16a | ||
|
|
9cd3544d59 | ||
|
|
de4d14fee3 | ||
|
|
f29c568381 | ||
|
|
af3f557055 | ||
|
|
b894842736 | ||
|
|
e190029e1f | ||
|
|
e4940a8050 | ||
|
|
617c95ebc4 | ||
|
|
1cdd428bcc | ||
|
|
71ac719aee | ||
|
|
4621e6cc9f | ||
|
|
66087f83e1 | ||
|
|
25f9330491 | ||
|
|
14b1e0d33b | ||
|
|
83ccb33fd3 | ||
|
|
05bcf543ba | ||
|
|
7cd063bb5d | ||
|
|
8f1317b39e | ||
|
|
77a0de5ef0 | ||
|
|
875227a2fe | ||
|
|
2317392ee5 | ||
|
|
c7efa4dd7f | ||
|
|
e701daa8e0 | ||
|
|
1ae99199b2 | ||
|
|
7c067a1cb3 | ||
|
|
478bc62576 | ||
|
|
a740eb8ee9 | ||
|
|
f8aedd02b3 | ||
|
|
ea638cab80 | ||
|
|
7129dd536e | ||
|
|
1b1cc7769b | ||
|
|
44b8354dfd | ||
|
|
55ec9d11ae | ||
|
|
5b3d3801b5 | ||
|
|
9f1ea75d09 | ||
|
|
6e37aae636 | ||
|
|
921d12f596 | ||
|
|
6bf6deaefd | ||
|
|
1201949f2c | ||
|
|
1c419e3591 | ||
|
|
b0a9be77b0 | ||
|
|
e02ade5a30 | ||
|
|
1a51ba8e7e | ||
|
|
e7b22d6ebf | ||
|
|
dddfa8ac79 | ||
|
|
99e2976826 | ||
|
|
71e44f0e54 | ||
|
|
4c904c2375 | ||
|
|
498d030da9 | ||
|
|
c111bf1714 | ||
|
|
6570f276d2 | ||
|
|
42e1e038bd | ||
|
|
d0e54a45c7 | ||
|
|
23fa47b07e | ||
|
|
4902c1d3b2 | ||
|
|
a6f96e5209 | ||
|
|
37c41bcfe4 | ||
|
|
9e223949a7 | ||
|
|
267bd72c63 | ||
|
|
af0d00e5e9 | ||
|
|
244e16c491 | ||
|
|
cad259fe39 | ||
|
|
bc3199bf29 | ||
|
|
127dc455c3 | ||
|
|
e8dc6fde53 | ||
|
|
4a97895dea | ||
|
|
3c0495fc51 | ||
|
|
dfd25deb68 | ||
|
|
f4db53b759 | ||
|
|
9f90341dcb | ||
|
|
67b726afb2 | ||
|
|
01852b81d4 | ||
|
|
4d6f109788 | ||
|
|
e1e5e7aedf | ||
|
|
cd53abc440 | ||
|
|
16a15a122a | ||
|
|
6fa653f232 | ||
|
|
c13971d7d6 | ||
|
|
9c659ce8fa | ||
|
|
c9fc64360f | ||
|
|
88a04fdbe8 | ||
|
|
bbe019f0c6 | ||
|
|
865f6ee81b | ||
|
|
bd5ec59b7c | ||
|
|
9c0cc1003d | ||
|
|
ea07d8ad00 | ||
|
|
3ac3fad4bc | ||
|
|
254a13bba3 | ||
|
|
4355f0fa78 | ||
|
|
031737f05d | ||
|
|
9e366fc536 | ||
|
|
8bd6442965 | ||
|
|
1a1eadb282 | ||
|
|
eed72b1c12 | ||
|
|
351350ea03 | ||
|
|
bc3d6ba92f | ||
|
|
345e4baf2a | ||
|
|
6c64dc057f | ||
|
|
eec0a9c9d9 | ||
|
|
6896a55485 | ||
|
|
4b0fad233e | ||
|
|
52eb991a70 | ||
|
|
10c716be0c | ||
|
|
6e77351eda | ||
|
|
20f5ebd9b8 | ||
|
|
d2c75329cf | ||
|
|
7e2fe082f0 | ||
|
|
d451b059fd | ||
|
|
93c52fcd4c | ||
|
|
f1608682e6 | ||
|
|
077e631c13 | ||
|
|
d7df1f05d1 | ||
|
|
8b8cfb76de | ||
|
|
79311ccde3 | ||
|
|
def798bf1f | ||
|
|
5290834b8b | ||
|
|
89064a9d5b | ||
|
|
8c2aef3734 | ||
|
|
3fb9e542b6 | ||
|
|
01844d8687 | ||
|
|
2655425fbe | ||
|
|
bd15b630b0 | ||
|
|
fe5ce68436 | ||
|
|
0541b05966 | ||
|
|
13cb0aa9be | ||
|
|
a048369b38 |
2
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
2
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@@ -1,5 +1,5 @@
|
||||
name: 漏洞反馈
|
||||
description: 【供中文用户】报错或漏洞请使用这个模板创建,不使用此模板创建的异常、漏洞相关issue将被直接关闭。由于自己操作不当/不甚了解所用技术栈引起的网络连接问题恕无法解决,请勿提 issue。容器间网络连接问题,参考文档 https://docs.langbot.app/zh/workshop/network-details.html
|
||||
description: 【供中文用户】报错或漏洞请使用这个模板创建,不使用此模板创建的异常、漏洞相关issue将被直接关闭。由于自己操作不当/不甚了解所用技术栈引起的网络连接问题恕无法解决,请勿提 issue。容器间网络连接问题,参考文档 https://link.langbot.app/zh/docs/network
|
||||
title: "[Bug]: "
|
||||
labels: ["bug?"]
|
||||
body:
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/bug-report_en.yml
vendored
2
.github/ISSUE_TEMPLATE/bug-report_en.yml
vendored
@@ -1,5 +1,5 @@
|
||||
name: Bug report
|
||||
description: Report bugs or vulnerabilities using this template. For container network connection issues, refer to the documentation https://docs.langbot.app/en/workshop/network-details.html
|
||||
description: Report bugs or vulnerabilities using this template. For container network connection issues, refer to the documentation https://link.langbot.app/en/docs/network
|
||||
title: "[Bug]: "
|
||||
labels: ["bug?"]
|
||||
body:
|
||||
|
||||
@@ -43,10 +43,10 @@ jobs:
|
||||
run: |
|
||||
cd /tmp/langbot_build_web/web
|
||||
npm install
|
||||
npm run build
|
||||
npx vite build
|
||||
- name: Package Output
|
||||
run: |
|
||||
cp -r /tmp/langbot_build_web/web/out ./web
|
||||
cp -r /tmp/langbot_build_web/web/dist ./web
|
||||
- name: Upload Artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
|
||||
25
.github/workflows/check-i18n.yml
vendored
Normal file
25
.github/workflows/check-i18n.yml
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
name: Check i18n Keys
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- master
|
||||
|
||||
jobs:
|
||||
check-i18n:
|
||||
name: Check i18n Key Consistency
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
|
||||
- name: Check i18n keys against en-US reference
|
||||
run: node web/scripts/check-i18n.mjs
|
||||
4
.github/workflows/publish-to-pypi.yml
vendored
4
.github/workflows/publish-to-pypi.yml
vendored
@@ -29,8 +29,8 @@ jobs:
|
||||
npm install -g pnpm
|
||||
pnpm install
|
||||
pnpm build
|
||||
mkdir -p ../src/langbot/web/out
|
||||
cp -r out ../src/langbot/web/
|
||||
mkdir -p ../src/langbot/web/dist
|
||||
cp -r dist ../src/langbot/web/
|
||||
|
||||
- name: Install the latest version of uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
109
.github/workflows/run-tests.yml
vendored
109
.github/workflows/run-tests.yml
vendored
@@ -4,25 +4,29 @@ on:
|
||||
pull_request:
|
||||
types: [opened, ready_for_review, synchronize]
|
||||
paths:
|
||||
- 'pkg/**'
|
||||
- 'src/langbot/**'
|
||||
- 'tests/**'
|
||||
- '.github/workflows/run-tests.yml'
|
||||
- 'pyproject.toml'
|
||||
- 'uv.lock'
|
||||
- 'run_tests.sh'
|
||||
- 'scripts/test-*.sh'
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- develop
|
||||
paths:
|
||||
- 'pkg/**'
|
||||
- 'src/langbot/**'
|
||||
- 'tests/**'
|
||||
- '.github/workflows/run-tests.yml'
|
||||
- 'pyproject.toml'
|
||||
- 'uv.lock'
|
||||
- 'run_tests.sh'
|
||||
- 'scripts/test-*.sh'
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Run Unit Tests
|
||||
name: Unit Tests
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
@@ -39,28 +43,13 @@ jobs:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install uv
|
||||
run: |
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv sync --dev
|
||||
run: uv sync --dev
|
||||
|
||||
- name: Run unit tests
|
||||
run: |
|
||||
bash run_tests.sh
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
if: matrix.python-version == '3.12'
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
files: ./coverage.xml
|
||||
flags: unit-tests
|
||||
name: unit-tests-coverage
|
||||
fail_ci_if_error: false
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
- name: Run unit + smoke tests
|
||||
run: uv run pytest tests/unit_tests/ tests/smoke/ -q --tb=short
|
||||
|
||||
- name: Test Summary
|
||||
if: always()
|
||||
@@ -69,3 +58,79 @@ jobs:
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Python Version: ${{ matrix.python-version }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
integration:
|
||||
name: Fast Integration Tests
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --dev
|
||||
|
||||
- name: Run fast integration tests
|
||||
run: uv run pytest tests/integration/ -m "not slow" -q --tb=short
|
||||
|
||||
- name: Integration Test Summary
|
||||
if: always()
|
||||
run: |
|
||||
echo "## Integration Tests Results" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
coverage:
|
||||
name: Coverage Gate
|
||||
runs-on: ubuntu-latest
|
||||
needs: [test, integration]
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --dev
|
||||
|
||||
- name: Run coverage (unit + smoke)
|
||||
run: |
|
||||
uv run pytest tests/unit_tests/ tests/smoke/ \
|
||||
--cov=langbot \
|
||||
--cov-report=xml \
|
||||
--cov-report=term-missing \
|
||||
--cov-fail-under=18 \
|
||||
-q --tb=short
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
files: ./coverage.xml
|
||||
flags: unit-tests
|
||||
name: coverage-report
|
||||
fail_ci_if_error: false
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
- name: Coverage Summary
|
||||
if: always()
|
||||
run: |
|
||||
echo "## Coverage Results" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Threshold: 18%" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
|
||||
78
.github/workflows/test-migrations.yml
vendored
Normal file
78
.github/workflows/test-migrations.yml
vendored
Normal file
@@ -0,0 +1,78 @@
|
||||
name: Test Migrations
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- master
|
||||
- dev
|
||||
paths:
|
||||
- 'src/langbot/pkg/persistence/**'
|
||||
- 'src/langbot/pkg/entity/persistence/**'
|
||||
- 'tests/integration/persistence/**'
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
paths:
|
||||
- 'src/langbot/pkg/persistence/**'
|
||||
- 'src/langbot/pkg/entity/persistence/**'
|
||||
- 'tests/integration/persistence/**'
|
||||
|
||||
jobs:
|
||||
test-migrations-sqlite:
|
||||
name: Migrations (SQLite)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --dev
|
||||
|
||||
- name: Run SQLite migration tests
|
||||
run: uv run pytest tests/integration/persistence/test_migrations.py -q --tb=short
|
||||
|
||||
test-migrations-postgres:
|
||||
name: Migrations (PostgreSQL)
|
||||
runs-on: ubuntu-latest
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:16
|
||||
env:
|
||||
POSTGRES_USER: langbot
|
||||
POSTGRES_PASSWORD: langbot
|
||||
POSTGRES_DB: langbot_test
|
||||
ports:
|
||||
- 5432:5432
|
||||
options: >-
|
||||
--health-cmd="pg_isready -U langbot"
|
||||
--health-interval=5s
|
||||
--health-timeout=5s
|
||||
--health-retries=5
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --dev
|
||||
|
||||
- name: Run PostgreSQL migration tests
|
||||
env:
|
||||
TEST_POSTGRES_URL: postgresql+asyncpg://langbot:langbot@localhost:5432/langbot_test
|
||||
run: uv run pytest tests/integration/persistence/test_migrations_postgres.py -q --tb=short
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -47,8 +47,12 @@ plugins.bak
|
||||
coverage.xml
|
||||
.coverage
|
||||
src/langbot/web/
|
||||
testsdk/
|
||||
|
||||
# Build artifacts
|
||||
/dist
|
||||
/build
|
||||
*.egg-info
|
||||
|
||||
# Next.js build cache (legacy)
|
||||
web/.next/
|
||||
|
||||
@@ -9,16 +9,14 @@ repos:
|
||||
# Run the formatter of backend.
|
||||
- id: ruff-format
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-prettier
|
||||
rev: v3.1.0
|
||||
hooks:
|
||||
- id: prettier
|
||||
types_or: [javascript, jsx, ts, tsx, css, scss]
|
||||
additional_dependencies:
|
||||
- prettier@3.1.0
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: prettier
|
||||
name: prettier
|
||||
entry: npx --prefix web prettier --write --ignore-unknown
|
||||
language: system
|
||||
types_or: [javascript, jsx, ts, tsx, css, scss]
|
||||
|
||||
- id: lint-staged
|
||||
name: lint-staged
|
||||
entry: cd web && pnpm lint-staged
|
||||
|
||||
@@ -70,7 +70,7 @@ Plugin Runtime automatically starts each installed plugin and interacts through
|
||||
- type: must be a specific type, such as feat (new feature), fix (bug fix), docs (documentation), style (code style), refactor (refactoring), perf (performance optimization), etc.
|
||||
- scope: the scope of the commit, such as the package name, the file name, the function name, the class name, the module name, etc.
|
||||
- subject: the subject of the commit, such as the description of the commit, the reason for the commit, the impact of the commit, etc.
|
||||
- If you changed the definition of database entities, please update the migration file in `src/langbot/pkg/persistence/migrations/` and update the constants.py file in `src/langbot/pkg/utils/constants.py` with the new migration number.
|
||||
- LangBot uses [Alembic](https://alembic.sqlalchemy.org/) to manage database migrations, supporting both SQLite and PostgreSQL. Migration files are located in `src/langbot/pkg/persistence/alembic/versions/`. If you changed the definition of database entities (ORM models), generate a new migration script by running `uv run python -m langbot.pkg.persistence.alembic_runner autogenerate "description of your change"` in the project root (requires `data/config.yaml` to exist). Review and edit the generated script before committing. Migrations are executed automatically on LangBot startup. For data migrations (e.g. modifying JSON field content), you need to manually add the migration code in the generated script.
|
||||
|
||||
## Some Principles
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ WORKDIR /app
|
||||
|
||||
COPY web ./web
|
||||
|
||||
RUN cd web && npm install && npm run build
|
||||
RUN cd web && npm install && npx vite build
|
||||
|
||||
FROM python:3.12.7-slim
|
||||
|
||||
@@ -12,7 +12,7 @@ WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
COPY --from=node /app/web/out ./web/out
|
||||
COPY --from=node /app/web/dist ./web/dist
|
||||
|
||||
RUN apt update \
|
||||
&& apt install gcc -y \
|
||||
|
||||
36
Makefile
Normal file
36
Makefile
Normal file
@@ -0,0 +1,36 @@
|
||||
# LangBot Makefile
|
||||
# Quick developer commands
|
||||
|
||||
.PHONY: test test-quick test-integration-fast test-coverage test-all-local lint
|
||||
|
||||
# Run all tests (full suite with coverage)
|
||||
test:
|
||||
bash run_tests.sh
|
||||
|
||||
# Quick self-test for developers (lint + unit + smoke, no real credentials needed)
|
||||
test-quick:
|
||||
bash scripts/test-quick.sh
|
||||
|
||||
# Fast integration tests (SQLite/API/Pipeline, no external services)
|
||||
test-integration-fast:
|
||||
bash scripts/test-integration-fast.sh
|
||||
|
||||
# Coverage gate (all tests, enforces minimum threshold)
|
||||
test-coverage:
|
||||
bash scripts/test-coverage.sh
|
||||
|
||||
# Full local quality gate (quick + integration + coverage)
|
||||
test-all-local:
|
||||
bash scripts/test-quick.sh
|
||||
bash scripts/test-integration-fast.sh
|
||||
bash scripts/test-coverage.sh
|
||||
|
||||
# Run linting only
|
||||
lint:
|
||||
ruff check src/langbot/ tests/
|
||||
ruff format --check src/langbot/ tests/
|
||||
|
||||
# Fix linting issues
|
||||
lint-fix:
|
||||
ruff check --fix src/langbot/ tests/
|
||||
ruff format src/langbot/ tests/
|
||||
94
README.md
94
README.md
@@ -19,9 +19,9 @@ English / [简体中文](README_CN.md) / [繁體中文](README_TW.md) / [日本
|
||||
[](https://github.com/langbot-app/LangBot/stargazers)
|
||||
|
||||
<a href="https://langbot.app">Website</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/features">Features</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/guide">Docs</a> |
|
||||
<a href="https://docs.langbot.app/en/tags/readme">API</a> |
|
||||
<a href="https://link.langbot.app/en/docs/features">Features</a> |
|
||||
<a href="https://link.langbot.app/en/docs/guide">Docs</a> |
|
||||
<a href="https://link.langbot.app/en/docs/api">API</a> |
|
||||
<a href="https://space.langbot.app/cloud">Cloud</a> |
|
||||
<a href="https://space.langbot.app">Plugin Market</a> |
|
||||
<a href="https://langbot.featurebase.app/roadmap">Roadmap</a>
|
||||
@@ -45,7 +45,9 @@ LangBot is an **open-source, production-grade platform** for building AI-powered
|
||||
- **Web Management Panel** — Configure, manage, and monitor your bots through an intuitive browser interface. No YAML editing required.
|
||||
- **Multi-Pipeline Architecture** — Different bots for different scenarios, with comprehensive monitoring and exception handling.
|
||||
|
||||
[→ Learn more about all features](https://docs.langbot.app/en/insight/features)
|
||||
[→ Learn more about all features](https://link.langbot.app/en/docs/features)
|
||||
|
||||
📍 Practical guides: [deploy a multi-platform AI bot in 5 minutes](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [connect DeepSeek to WeChat, Discord, and Telegram](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [run a Dify Agent in Discord, Telegram, and Slack](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/), and [build an n8n-powered chatbot](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
|
||||
|
||||
---
|
||||
|
||||
@@ -76,7 +78,7 @@ docker compose up -d
|
||||
[](https://zeabur.com/en-US/templates/ZKTBDH)
|
||||
[](https://railway.app/template/yRrAyL?referralCode=vogKPF)
|
||||
|
||||
**More options:** [Docker](https://docs.langbot.app/en/deploy/langbot/docker) · [Manual](https://docs.langbot.app/en/deploy/langbot/manual) · [BTPanel](https://docs.langbot.app/en/deploy/langbot/one-click/bt) · [Kubernetes](./docker/README_K8S.md)
|
||||
**More options:** [Docker](https://link.langbot.app/en/docs/docker) · [Manual](https://link.langbot.app/en/docs/manual-deploy) · [BTPanel](https://link.langbot.app/en/docs/bt-panel) · [Kubernetes](./docker/README_K8S.md)
|
||||
|
||||
---
|
||||
|
||||
@@ -84,68 +86,72 @@ docker compose up -d
|
||||
|
||||
| Platform | Status | Notes |
|
||||
|----------|--------|-------|
|
||||
| Discord | ✅ | |
|
||||
| Telegram | ✅ | |
|
||||
| Slack | ✅ | |
|
||||
| LINE | ✅ | |
|
||||
| QQ | ✅ | Personal & Official API |
|
||||
| Discord | ✅ | Official |
|
||||
| Telegram | ✅ | Official |
|
||||
| Slack | ✅ | Official |
|
||||
| LINE | ✅ | Official |
|
||||
| QQ | ✅ | Personal & Official API (Channel, DM, Group) |
|
||||
| WeCom | ✅ | Enterprise WeChat, External CS, AI Bot |
|
||||
| WeChat | ✅ | Personal & Official Account |
|
||||
| Lark | ✅ | |
|
||||
| DingTalk | ✅ | |
|
||||
| KOOK | ✅ | |
|
||||
| Lark | ✅ | Official |
|
||||
| DingTalk | ✅ | Official |
|
||||
| KOOK | ✅ | Official |
|
||||
| Satori | ✅ | |
|
||||
| Email | ✅ | Matrix, Satori |
|
||||
| Matrix | ✅ | Supports multiple bridged platforms such as Signal, WhatsApp, Messenger, iMessage, Mattermost, Google Chat, IRC, XMPP, Zulip, and more |
|
||||
|
||||
---
|
||||
|
||||
## Supported LLMs & Integrations
|
||||
|
||||
| Provider | Type | Status |
|
||||
|----------|------|--------|
|
||||
| [OpenAI](https://platform.openai.com/) | LLM | ✅ |
|
||||
| [Anthropic](https://www.anthropic.com/) | LLM | ✅ |
|
||||
| [DeepSeek](https://www.deepseek.com/) | LLM | ✅ |
|
||||
| [Google Gemini](https://aistudio.google.com/prompts/new_chat) | LLM | ✅ |
|
||||
| [xAI](https://x.ai/) | LLM | ✅ |
|
||||
| [Moonshot](https://www.moonshot.cn/) | LLM | ✅ |
|
||||
| [Zhipu AI](https://open.bigmodel.cn/) | LLM | ✅ |
|
||||
| [Ollama](https://ollama.com/) | Local LLM | ✅ |
|
||||
| [LM Studio](https://lmstudio.ai/) | Local LLM | ✅ |
|
||||
| [Dify](https://dify.ai) | LLMOps | ✅ |
|
||||
| [MCP](https://modelcontextprotocol.io/) | Protocol | ✅ |
|
||||
| [SiliconFlow](https://siliconflow.cn/) | Gateway | ✅ |
|
||||
| [Aliyun Bailian](https://bailian.console.aliyun.com/) | Gateway | ✅ |
|
||||
| [Volc Engine Ark](https://console.volcengine.com/ark/region:ark+cn-beijing/model?vendor=Bytedance&view=LIST_VIEW) | Gateway | ✅ |
|
||||
| [ModelScope](https://modelscope.cn/docs/model-service/API-Inference/intro) | Gateway | ✅ |
|
||||
| [GiteeAI](https://ai.gitee.com/) | Gateway | ✅ |
|
||||
| [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_langbot) | GPU Platform | ✅ |
|
||||
| [PPIO](https://ppinfra.com/user/register?invited_by=QJKFYD&utm_source=github_langbot) | GPU Platform | ✅ |
|
||||
| [ShengSuanYun](https://www.shengsuanyun.com/?from=CH_KYIPP758) | GPU Platform | ✅ |
|
||||
| [接口 AI](https://jiekou.ai/) | Gateway | ✅ |
|
||||
| [302.AI](https://share.302.ai/SuTG99) | Gateway | ✅ |
|
||||
| Provider | Type | Status |
|
||||
| ----------------------------------------------------------------------------------------------------------------- | ------------ | ------ |
|
||||
| [OpenAI](https://platform.openai.com/) | LLM | ✅ |
|
||||
| [Anthropic](https://www.anthropic.com/) | LLM | ✅ |
|
||||
| [DeepSeek](https://www.deepseek.com/) | LLM | ✅ |
|
||||
| [Google Gemini](https://aistudio.google.com/prompts/new_chat) | LLM | ✅ |
|
||||
| [xAI](https://x.ai/) | LLM | ✅ |
|
||||
| [Moonshot](https://www.moonshot.cn/) | LLM | ✅ |
|
||||
| [Zhipu AI](https://open.bigmodel.cn/) | LLM | ✅ |
|
||||
| [Ollama](https://ollama.com/) | Local LLM | ✅ |
|
||||
| [LM Studio](https://lmstudio.ai/) | Local LLM | ✅ |
|
||||
| [Dify](https://dify.ai) | LLMOps | ✅ |
|
||||
| [MCP](https://modelcontextprotocol.io/) | Protocol | ✅ |
|
||||
| [SiliconFlow](https://siliconflow.cn/) | Gateway | ✅ |
|
||||
| [Aliyun Bailian](https://bailian.console.aliyun.com/) | Gateway | ✅ |
|
||||
| [Volc Engine Ark](https://console.volcengine.com/ark/region:ark+cn-beijing/model?vendor=Bytedance&view=LIST_VIEW) | Gateway | ✅ |
|
||||
| [ModelScope](https://modelscope.cn/docs/model-service/API-Inference/intro) | Gateway | ✅ |
|
||||
| [GiteeAI](https://ai.gitee.com/) | Gateway | ✅ |
|
||||
| [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_langbot) | GPU Platform | ✅ |
|
||||
| [PPIO](https://ppinfra.com/user/register?invited_by=QJKFYD&utm_source=github_langbot) | GPU Platform | ✅ |
|
||||
| [ShengSuanYun](https://www.shengsuanyun.com/?from=CH_KYIPP758) | GPU Platform | ✅ |
|
||||
| [接口 AI](https://jiekou.ai/) | Gateway | ✅ |
|
||||
| [302.AI](https://share.302.ai/SuTG99) | Gateway | ✅ |
|
||||
| [Qiniu](https://www.qiniu.com/ai/agent) | Gateway | ✅ |
|
||||
|
||||
[→ View all integrations](https://docs.langbot.app/en/insight/features)
|
||||
[→ View all integrations](https://link.langbot.app/en/docs/features)
|
||||
|
||||
---
|
||||
|
||||
## Why LangBot?
|
||||
|
||||
| Use Case | How LangBot Helps |
|
||||
|----------|-------------------|
|
||||
| **Customer Support** | Deploy AI agents to Slack/Discord/Telegram that answer questions using your knowledge base |
|
||||
| **Internal Tools** | Connect n8n/Dify workflows to WeCom/DingTalk for automated business processes |
|
||||
| **Community Management** | Moderate QQ/Discord groups with AI-powered content filtering and interaction |
|
||||
| **Multi-Platform Presence** | One bot, all platforms. Manage from a single dashboard |
|
||||
| Use Case | How LangBot Helps |
|
||||
| --------------------------- | ------------------------------------------------------------------------------------------ |
|
||||
| **Customer Support** | Deploy AI agents to Slack/Discord/Telegram that answer questions using your knowledge base |
|
||||
| **Internal Tools** | Connect n8n/Dify workflows to WeCom/DingTalk for automated business processes |
|
||||
| **Community Management** | Moderate QQ/Discord groups with AI-powered content filtering and interaction |
|
||||
| **Multi-Platform Presence** | One bot, all platforms. Manage from a single dashboard |
|
||||
|
||||
---
|
||||
|
||||
## Live Demo
|
||||
|
||||
**Try it now:** https://demo.langbot.dev/
|
||||
|
||||
- Email: `demo@langbot.app`
|
||||
- Password: `langbot123456`
|
||||
|
||||
*Note: Public demo environment. Do not enter sensitive information.*
|
||||
_Note: Public demo environment. Do not enter sensitive information._
|
||||
|
||||
---
|
||||
|
||||
|
||||
36
README_CN.md
36
README_CN.md
@@ -21,9 +21,9 @@
|
||||
[](https://gitcode.com/RockChinQ/LangBot)
|
||||
|
||||
<a href="https://langbot.app">官网</a> |
|
||||
<a href="https://docs.langbot.app/zh/insight/features.html">特性</a> |
|
||||
<a href="https://docs.langbot.app/zh/insight/guide.html">文档</a> |
|
||||
<a href="https://docs.langbot.app/zh/tags/readme.html">API</a> |
|
||||
<a href="https://link.langbot.app/zh/docs/features">特性</a> |
|
||||
<a href="https://link.langbot.app/zh/docs/guide">文档</a> |
|
||||
<a href="https://link.langbot.app/zh/docs/api">API</a> |
|
||||
<a href="https://space.langbot.app/cloud">Cloud</a> |
|
||||
<a href="https://space.langbot.app">插件市场</a> |
|
||||
<a href="https://langbot.featurebase.app/roadmap">路线图</a>
|
||||
@@ -34,8 +34,6 @@
|
||||
|
||||
---
|
||||
|
||||
## 什么是 LangBot?
|
||||
|
||||
LangBot 是一个**开源的生产级平台**,用于构建 AI 驱动的即时通信机器人。它将大语言模型(LLM)连接到各种聊天平台,帮助你创建能够对话、执行任务、并集成到现有工作流程中的智能 Agent。
|
||||
|
||||
### 核心能力
|
||||
@@ -43,11 +41,13 @@ LangBot 是一个**开源的生产级平台**,用于构建 AI 驱动的即时
|
||||
- **AI 对话与 Agent** — 多轮对话、工具调用、多模态、流式输出。自带 RAG(知识库),深度集成 [Dify](https://dify.ai)、[Coze](https://coze.com)、[n8n](https://n8n.io)、[Langflow](https://langflow.org) 等 LLMOps 平台。
|
||||
- **全平台支持** — 一套代码,覆盖 QQ、微信、企业微信、飞书、钉钉、Discord、Telegram、Slack、LINE、KOOK 等平台。
|
||||
- **生产就绪** — 访问控制、限速、敏感词过滤、全面监控与异常处理,已被多家企业采用。
|
||||
- **插件生态** — 数百个插件,事件驱动架构,组件扩展,适配 [MCP 协议](https://modelcontextprotocol.io/)。
|
||||
- **插件生态** — 数百个插件,跨进程的事件驱动架构,组件扩展,适配 [MCP 协议](https://modelcontextprotocol.io/)。
|
||||
- **Web 管理面板** — 通过浏览器直观地配置、管理和监控机器人,无需手动编辑配置文件。
|
||||
- **多流水线架构** — 不同机器人用于不同场景,具备全面的监控和异常处理能力。
|
||||
|
||||
[→ 了解更多功能特性](https://docs.langbot.app/zh/insight/features.html)
|
||||
[→ 了解更多功能特性](https://link.langbot.app/zh/docs/features)
|
||||
|
||||
📍 实践指南:[5 分钟部署多平台 AI 机器人](https://blog.langbot.app/zh/blog/deploy-ai-bot-in-5-minutes/)、[将 DeepSeek 接入微信、企业微信与 Discord](https://blog.langbot.app/zh/blog/connect-deepseek-to-wechat/)、[让 Dify Agent 跑在 Discord、Telegram 和 Slack 上](https://blog.langbot.app/zh/blog/dify-agent-discord-telegram-slack/),以及[用 n8n 构建多平台 AI 聊天机器人](https://blog.langbot.app/zh/blog/n8n-multi-platform-ai-chatbot/)。
|
||||
|
||||
---
|
||||
|
||||
@@ -78,7 +78,7 @@ docker compose up -d
|
||||
[](https://zeabur.com/zh-CN/templates/ZKTBDH)
|
||||
[](https://railway.app/template/yRrAyL?referralCode=vogKPF)
|
||||
|
||||
**更多方式:** [Docker](https://docs.langbot.app/zh/deploy/langbot/docker.html) · [手动部署](https://docs.langbot.app/zh/deploy/langbot/manual.html) · [宝塔面板](https://docs.langbot.app/zh/deploy/langbot/one-click/bt.html) · [Kubernetes](./docker/README_K8S.md)
|
||||
**更多方式:** [Docker](https://link.langbot.app/zh/docs/docker) · [手动部署](https://link.langbot.app/zh/docs/manual-deploy) · [宝塔面板](https://link.langbot.app/zh/docs/bt-panel) · [Kubernetes](./docker/README_K8S.md)
|
||||
|
||||
---
|
||||
|
||||
@@ -89,13 +89,16 @@ docker compose up -d
|
||||
| QQ | ✅ | 个人号、官方机器人(频道、私聊、群聊) |
|
||||
| 微信 | ✅ | 个人微信、微信公众号 |
|
||||
| 企业微信 | ✅ | 应用消息、对外客服、智能机器人 |
|
||||
| 飞书 | ✅ | |
|
||||
| 钉钉 | ✅ | |
|
||||
| Discord | ✅ | |
|
||||
| Telegram | ✅ | |
|
||||
| Slack | ✅ | |
|
||||
| LINE | ✅ | |
|
||||
| KOOK | ✅ | |
|
||||
| 飞书 | ✅ | 官方 |
|
||||
| 钉钉 | ✅ | 官方 |
|
||||
| Satori | ✅ | |
|
||||
| Discord | ✅ | 官方 |
|
||||
| Telegram | ✅ | 官方 |
|
||||
| Slack | ✅ | 官方 |
|
||||
| LINE | ✅ | 官方 |
|
||||
| KOOK | ✅ | 官方 |
|
||||
| Email | ✅ | 只 Matrix、Satori |
|
||||
| Matrix | ✅ | 支持多种桥接平台,如 Signal、WhatsApp、Messenger、iMessage、Mattermost、Google Chat、IRC、XMPP、Zulip 等 |
|
||||
|
||||
---
|
||||
|
||||
@@ -126,8 +129,9 @@ docker compose up -d
|
||||
| [302.AI](https://share.302.ai/SuTG99) | 聚合平台 | ✅ |
|
||||
| [小马算力](https://www.tokenpony.cn/453z1) | 聚合平台 | ✅ |
|
||||
| [百宝箱Tbox](https://www.tbox.cn/open) | 智能体平台 | ✅ |
|
||||
| [七牛云Qiniu](https://www.qiniu.com/ai/agent) | 聚合平台 | ✅ |
|
||||
|
||||
[→ 查看完整集成列表](https://docs.langbot.app/zh/insight/features.html)
|
||||
[→ 查看完整集成列表](https://link.langbot.app/zh/docs/features)
|
||||
|
||||
### TTS(语音合成)
|
||||
|
||||
|
||||
33
README_ES.md
33
README_ES.md
@@ -19,9 +19,9 @@
|
||||
[](https://github.com/langbot-app/LangBot/stargazers)
|
||||
|
||||
<a href="https://langbot.app">Inicio</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/features.html">Características</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/guide.html">Documentación</a> |
|
||||
<a href="https://docs.langbot.app/en/tags/readme.html">API</a> |
|
||||
<a href="https://link.langbot.app/en/docs/features">Características</a> |
|
||||
<a href="https://link.langbot.app/en/docs/guide">Documentación</a> |
|
||||
<a href="https://link.langbot.app/en/docs/api">API</a> |
|
||||
<a href="https://space.langbot.app">Mercado de Plugins</a> |
|
||||
<a href="https://langbot.featurebase.app/roadmap">Hoja de Ruta</a>
|
||||
|
||||
@@ -44,7 +44,9 @@ LangBot es una **plataforma de código abierto y grado de producción** para con
|
||||
- **Panel de Gestión Web** — Configure, gestione y monitoree sus bots a través de una interfaz de navegador intuitiva. Sin necesidad de editar YAML.
|
||||
- **Arquitectura Multi-Pipeline** — Diferentes bots para diferentes escenarios, con monitoreo completo y manejo de excepciones.
|
||||
|
||||
[→ Conocer más sobre todas las funcionalidades](https://docs.langbot.app/en/insight/features.html)
|
||||
[→ Conocer más sobre todas las funcionalidades](https://link.langbot.app/en/docs/features)
|
||||
|
||||
📍 Guías prácticas: [desplegar un bot de IA multiplataforma en 5 minutos](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [conectar DeepSeek a WeChat, Discord y Telegram](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [ejecutar un Dify Agent en Discord, Telegram y Slack](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/) y [crear un chatbot con n8n](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
|
||||
|
||||
---
|
||||
|
||||
@@ -75,7 +77,7 @@ docker compose up -d
|
||||
[](https://zeabur.com/en-US/templates/ZKTBDH)
|
||||
[](https://railway.app/template/yRrAyL?referralCode=vogKPF)
|
||||
|
||||
**Más opciones:** [Docker](https://docs.langbot.app/en/deploy/langbot/docker.html) · [Manual](https://docs.langbot.app/en/deploy/langbot/manual.html) · [BTPanel](https://docs.langbot.app/en/deploy/langbot/one-click/bt.html) · [Kubernetes](./docker/README_K8S.md)
|
||||
**Más opciones:** [Docker](https://link.langbot.app/en/docs/docker) · [Manual](https://link.langbot.app/en/docs/manual-deploy) · [BTPanel](https://link.langbot.app/en/docs/bt-panel) · [Kubernetes](./docker/README_K8S.md)
|
||||
|
||||
---
|
||||
|
||||
@@ -83,17 +85,19 @@ docker compose up -d
|
||||
|
||||
| Plataforma | Estado | Notas |
|
||||
|----------|--------|-------|
|
||||
| Discord | ✅ | |
|
||||
| Telegram | ✅ | |
|
||||
| Slack | ✅ | |
|
||||
| LINE | ✅ | |
|
||||
| QQ | ✅ | Personal y API Oficial |
|
||||
| Discord | ✅ | Oficial |
|
||||
| Telegram | ✅ | Oficial |
|
||||
| Slack | ✅ | Oficial |
|
||||
| LINE | ✅ | Oficial |
|
||||
| QQ | ✅ | Personal y API Oficial (Canal, DM, Grupo) |
|
||||
| WeCom | ✅ | WeChat Empresarial, CS Externo, AI Bot |
|
||||
| WeChat | ✅ | Personal y Cuenta Oficial |
|
||||
| Lark | ✅ | |
|
||||
| DingTalk | ✅ | |
|
||||
| KOOK | ✅ | |
|
||||
| Lark | ✅ | Oficial |
|
||||
| DingTalk | ✅ | Oficial |
|
||||
| KOOK | ✅ | Oficial |
|
||||
| Satori | ✅ | |
|
||||
| Email | ✅ | Matrix, Satori |
|
||||
| Matrix | ✅ | Admite varias plataformas puenteadas como Signal, WhatsApp, Messenger, iMessage, Mattermost, Google Chat, IRC, XMPP, Zulip y más |
|
||||
|
||||
---
|
||||
|
||||
@@ -122,8 +126,9 @@ docker compose up -d
|
||||
| [ShengSuanYun](https://www.shengsuanyun.com/?from=CH_KYIPP758) | Plataforma GPU | ✅ |
|
||||
| [接口 AI](https://jiekou.ai/) | Pasarela | ✅ |
|
||||
| [302.AI](https://share.302.ai/SuTG99) | Pasarela | ✅ |
|
||||
| [Qiniu](https://www.qiniu.com/ai/agent) | Pasarela | ✅ |
|
||||
|
||||
[→ Ver todas las integraciones](https://docs.langbot.app/en/insight/features.html)
|
||||
[→ Ver todas las integraciones](https://link.langbot.app/en/docs/features)
|
||||
|
||||
---
|
||||
|
||||
|
||||
33
README_FR.md
33
README_FR.md
@@ -19,9 +19,9 @@
|
||||
[](https://github.com/langbot-app/LangBot/stargazers)
|
||||
|
||||
<a href="https://langbot.app">Accueil</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/features.html">Fonctionnalités</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/guide.html">Documentation</a> |
|
||||
<a href="https://docs.langbot.app/en/tags/readme.html">API</a> |
|
||||
<a href="https://link.langbot.app/en/docs/features">Fonctionnalités</a> |
|
||||
<a href="https://link.langbot.app/en/docs/guide">Documentation</a> |
|
||||
<a href="https://link.langbot.app/en/docs/api">API</a> |
|
||||
<a href="https://space.langbot.app">Marché des Plugins</a> |
|
||||
<a href="https://langbot.featurebase.app/roadmap">Feuille de Route</a>
|
||||
|
||||
@@ -44,7 +44,9 @@ LangBot est une **plateforme open-source de niveau production** pour créer des
|
||||
- **Panneau de Gestion Web** — Configurez, gérez et surveillez vos bots via une interface navigateur intuitive. Aucune édition de YAML requise.
|
||||
- **Architecture Multi-Pipeline** — Différents bots pour différents scénarios, avec surveillance complète et gestion des exceptions.
|
||||
|
||||
[→ En savoir plus sur toutes les fonctionnalités](https://docs.langbot.app/en/insight/features.html)
|
||||
[→ En savoir plus sur toutes les fonctionnalités](https://link.langbot.app/en/docs/features)
|
||||
|
||||
📍 Guides pratiques : [déployer un bot IA multiplateforme en 5 minutes](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [connecter DeepSeek à WeChat, Discord et Telegram](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [exécuter un Dify Agent dans Discord, Telegram et Slack](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/) et [créer un chatbot avec n8n](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
|
||||
|
||||
---
|
||||
|
||||
@@ -75,7 +77,7 @@ docker compose up -d
|
||||
[](https://zeabur.com/en-US/templates/ZKTBDH)
|
||||
[](https://railway.app/template/yRrAyL?referralCode=vogKPF)
|
||||
|
||||
**Plus d'options :** [Docker](https://docs.langbot.app/en/deploy/langbot/docker.html) · [Manuel](https://docs.langbot.app/en/deploy/langbot/manual.html) · [BTPanel](https://docs.langbot.app/en/deploy/langbot/one-click/bt.html) · [Kubernetes](./docker/README_K8S.md)
|
||||
**Plus d'options :** [Docker](https://link.langbot.app/en/docs/docker) · [Manuel](https://link.langbot.app/en/docs/manual-deploy) · [BTPanel](https://link.langbot.app/en/docs/bt-panel) · [Kubernetes](./docker/README_K8S.md)
|
||||
|
||||
---
|
||||
|
||||
@@ -83,17 +85,19 @@ docker compose up -d
|
||||
|
||||
| Plateforme | Statut | Notes |
|
||||
|----------|--------|-------|
|
||||
| Discord | ✅ | |
|
||||
| Telegram | ✅ | |
|
||||
| Slack | ✅ | |
|
||||
| LINE | ✅ | |
|
||||
| QQ | ✅ | Personnel & API Officielle |
|
||||
| Discord | ✅ | Officiel |
|
||||
| Telegram | ✅ | Officiel |
|
||||
| Slack | ✅ | Officiel |
|
||||
| LINE | ✅ | Officiel |
|
||||
| QQ | ✅ | Personnel & API Officielle (Canal, DM, Groupe) |
|
||||
| WeCom | ✅ | WeChat Entreprise, CS Externe, AI Bot |
|
||||
| WeChat | ✅ | Personnel & Compte Officiel |
|
||||
| Lark | ✅ | |
|
||||
| DingTalk | ✅ | |
|
||||
| KOOK | ✅ | |
|
||||
| Lark | ✅ | Officiel |
|
||||
| DingTalk | ✅ | Officiel |
|
||||
| KOOK | ✅ | Officiel |
|
||||
| Satori | ✅ | |
|
||||
| Email | ✅ | Matrix, Satori |
|
||||
| Matrix | ✅ | Prend en charge plusieurs plateformes via ponts, comme Signal, WhatsApp, Messenger, iMessage, Mattermost, Google Chat, IRC, XMPP, Zulip, etc. |
|
||||
|
||||
---
|
||||
|
||||
@@ -122,8 +126,9 @@ docker compose up -d
|
||||
| [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_langbot) | Plateforme GPU | ✅ |
|
||||
| [PPIO](https://ppinfra.com/user/register?invited_by=QJKFYD&utm_source=github_langbot) | Plateforme GPU | ✅ |
|
||||
| [ShengSuanYun](https://www.shengsuanyun.com/?from=CH_KYIPP758) | Plateforme GPU | ✅ |
|
||||
| [Qiniu](https://www.qiniu.com/ai/agent) | Passerelle | ✅ |
|
||||
|
||||
[→ Voir toutes les intégrations](https://docs.langbot.app/en/insight/features.html)
|
||||
[→ Voir toutes les intégrations](https://link.langbot.app/en/docs/features)
|
||||
|
||||
---
|
||||
|
||||
|
||||
35
README_JP.md
35
README_JP.md
@@ -19,9 +19,9 @@
|
||||
[](https://github.com/langbot-app/LangBot/stargazers)
|
||||
|
||||
<a href="https://langbot.app">ホーム</a> |
|
||||
<a href="https://docs.langbot.app/ja/insight/features.html">機能</a> |
|
||||
<a href="https://docs.langbot.app/ja/insight/guide.html">ドキュメント</a> |
|
||||
<a href="https://docs.langbot.app/ja/tags/readme.html">API</a> |
|
||||
<a href="https://link.langbot.app/ja/docs/features">機能</a> |
|
||||
<a href="https://link.langbot.app/ja/docs/guide">ドキュメント</a> |
|
||||
<a href="https://link.langbot.app/ja/docs/api">API</a> |
|
||||
<a href="https://space.langbot.app">プラグインマーケット</a> |
|
||||
<a href="https://langbot.featurebase.app/roadmap">ロードマップ</a>
|
||||
|
||||
@@ -44,7 +44,9 @@ LangBot は、AI搭載のインスタントメッセージングボットを構
|
||||
- **Web管理パネル** — 直感的なブラウザインターフェースからボットの設定、管理、監視が可能。YAML編集は不要。
|
||||
- **マルチパイプラインアーキテクチャ** — 異なるシナリオに異なるボットを配置し、包括的な監視と例外処理を実現。
|
||||
|
||||
[→ すべての機能について詳しく見る](https://docs.langbot.app/ja/insight/features.html)
|
||||
[→ すべての機能について詳しく見る](https://link.langbot.app/ja/docs/features)
|
||||
|
||||
📍 実践ガイド: [5分でマルチプラットフォームAIボットをデプロイ](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/)、[DeepSeekをWeChat・Discord・Telegramに接続](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/)、[Dify AgentをDiscord・Telegram・Slackで動かす](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/)、[n8n連携チャットボットを構築](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/)。
|
||||
|
||||
---
|
||||
|
||||
@@ -75,7 +77,7 @@ docker compose up -d
|
||||
[](https://zeabur.com/en-US/templates/ZKTBDH)
|
||||
[](https://railway.app/template/yRrAyL?referralCode=vogKPF)
|
||||
|
||||
**その他:** [Docker](https://docs.langbot.app/en/deploy/langbot/docker.html) · [手動デプロイ](https://docs.langbot.app/en/deploy/langbot/manual.html) · [BTPanel](https://docs.langbot.app/en/deploy/langbot/one-click/bt.html) · [Kubernetes](./docker/README_K8S.md)
|
||||
**その他:** [Docker](https://link.langbot.app/en/docs/docker) · [手動デプロイ](https://link.langbot.app/en/docs/manual-deploy) · [BTPanel](https://link.langbot.app/en/docs/bt-panel) · [Kubernetes](./docker/README_K8S.md)
|
||||
|
||||
---
|
||||
|
||||
@@ -83,17 +85,19 @@ docker compose up -d
|
||||
|
||||
| プラットフォーム | ステータス | 備考 |
|
||||
|----------|--------|-------|
|
||||
| Discord | ✅ | |
|
||||
| Telegram | ✅ | |
|
||||
| Slack | ✅ | |
|
||||
| LINE | ✅ | |
|
||||
| QQ | ✅ | 個人 & 公式API |
|
||||
| Discord | ✅ | 公式 |
|
||||
| Telegram | ✅ | 公式 |
|
||||
| Slack | ✅ | 公式 |
|
||||
| LINE | ✅ | 公式 |
|
||||
| QQ | ✅ | 個人・公式API(チャンネル・DM・グループ) |
|
||||
| WeCom | ✅ | 企業WeChat、外部CS、AIボット |
|
||||
| WeChat | ✅ | 個人 & 公式アカウント |
|
||||
| Lark | ✅ | |
|
||||
| DingTalk | ✅ | |
|
||||
| KOOK | ✅ | |
|
||||
| WeChat | ✅ | 個人・公式アカウント |
|
||||
| Lark | ✅ | 公式 |
|
||||
| DingTalk | ✅ | 公式 |
|
||||
| KOOK | ✅ | 公式 |
|
||||
| Satori | ✅ | |
|
||||
| Email | ✅ | Matrix、Satori |
|
||||
| Matrix | ✅ | Signal、WhatsApp、Messenger、iMessage、Mattermost、Google Chat、IRC、XMPP、Zulip など複数のブリッジ先プラットフォームに対応 |
|
||||
|
||||
---
|
||||
|
||||
@@ -122,8 +126,9 @@ docker compose up -d
|
||||
| [ShengSuanYun](https://www.shengsuanyun.com/?from=CH_KYIPP758) | GPUプラットフォーム | ✅ |
|
||||
| [接口 AI](https://jiekou.ai/) | ゲートウェイ | ✅ |
|
||||
| [302.AI](https://share.302.ai/SuTG99) | ゲートウェイ | ✅ |
|
||||
| [Qiniu](https://www.qiniu.com/ai/agent) | ゲートウェイ | ✅ |
|
||||
|
||||
[→ すべての統合を表示](https://docs.langbot.app/en/insight/features.html)
|
||||
[→ すべての統合を表示](https://link.langbot.app/en/docs/features)
|
||||
|
||||
---
|
||||
|
||||
|
||||
33
README_KO.md
33
README_KO.md
@@ -19,9 +19,9 @@
|
||||
[](https://github.com/langbot-app/LangBot/stargazers)
|
||||
|
||||
<a href="https://langbot.app">홈</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/features.html">기능</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/guide.html">문서</a> |
|
||||
<a href="https://docs.langbot.app/en/tags/readme.html">API</a> |
|
||||
<a href="https://link.langbot.app/en/docs/features">기능</a> |
|
||||
<a href="https://link.langbot.app/en/docs/guide">문서</a> |
|
||||
<a href="https://link.langbot.app/en/docs/api">API</a> |
|
||||
<a href="https://space.langbot.app">플러그인 마켓</a> |
|
||||
<a href="https://langbot.featurebase.app/roadmap">로드맵</a>
|
||||
|
||||
@@ -44,7 +44,9 @@ LangBot은 AI 기반 인스턴트 메시징 봇을 구축하기 위한 **오픈
|
||||
- **웹 관리 패널** — 직관적인 브라우저 인터페이스로 봇을 구성, 관리 및 모니터링. YAML 편집 불필요.
|
||||
- **멀티 파이프라인 아키텍처** — 다양한 시나리오에 맞는 다양한 봇 구성, 종합 모니터링 및 예외 처리.
|
||||
|
||||
[→ 모든 기능 자세히 보기](https://docs.langbot.app/en/insight/features.html)
|
||||
[→ 모든 기능 자세히 보기](https://link.langbot.app/en/docs/features)
|
||||
|
||||
📍 실전 가이드: [5분 만에 멀티 플랫폼 AI 봇 배포하기](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [DeepSeek를 WeChat, Discord, Telegram에 연결하기](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [Dify Agent를 Discord, Telegram, Slack에서 실행하기](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/), [n8n 기반 챗봇 만들기](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
|
||||
|
||||
---
|
||||
|
||||
@@ -75,7 +77,7 @@ docker compose up -d
|
||||
[](https://zeabur.com/en-US/templates/ZKTBDH)
|
||||
[](https://railway.app/template/yRrAyL?referralCode=vogKPF)
|
||||
|
||||
**더 많은 옵션:** [Docker](https://docs.langbot.app/en/deploy/langbot/docker.html) · [수동 배포](https://docs.langbot.app/en/deploy/langbot/manual.html) · [BTPanel](https://docs.langbot.app/en/deploy/langbot/one-click/bt.html) · [Kubernetes](./docker/README_K8S.md)
|
||||
**더 많은 옵션:** [Docker](https://link.langbot.app/en/docs/docker) · [수동 배포](https://link.langbot.app/en/docs/manual-deploy) · [BTPanel](https://link.langbot.app/en/docs/bt-panel) · [Kubernetes](./docker/README_K8S.md)
|
||||
|
||||
---
|
||||
|
||||
@@ -83,17 +85,19 @@ docker compose up -d
|
||||
|
||||
| 플랫폼 | 상태 | 비고 |
|
||||
|--------|------|------|
|
||||
| Discord | ✅ | |
|
||||
| Telegram | ✅ | |
|
||||
| Slack | ✅ | |
|
||||
| LINE | ✅ | |
|
||||
| QQ | ✅ | 개인 및 공식 API |
|
||||
| Discord | ✅ | 공식 |
|
||||
| Telegram | ✅ | 공식 |
|
||||
| Slack | ✅ | 공식 |
|
||||
| LINE | ✅ | 공식 |
|
||||
| QQ | ✅ | 개인 및 공식 API (채널, DM, 그룹) |
|
||||
| WeCom | ✅ | 기업 WeChat, 외부 CS, AI Bot |
|
||||
| WeChat | ✅ | 개인 및 공식 계정 |
|
||||
| Lark | ✅ | |
|
||||
| DingTalk | ✅ | |
|
||||
| KOOK | ✅ | |
|
||||
| Lark | ✅ | 공식 |
|
||||
| DingTalk | ✅ | 공식 |
|
||||
| KOOK | ✅ | 공식 |
|
||||
| Satori | ✅ | |
|
||||
| Email | ✅ | Matrix, Satori |
|
||||
| Matrix | ✅ | Signal, WhatsApp, Messenger, iMessage, Mattermost, Google Chat, IRC, XMPP, Zulip 등 여러 브리지 플랫폼 지원 |
|
||||
|
||||
---
|
||||
|
||||
@@ -122,8 +126,9 @@ docker compose up -d
|
||||
| [ShengSuanYun](https://www.shengsuanyun.com/?from=CH_KYIPP758) | GPU 플랫폼 | ✅ |
|
||||
| [接口 AI](https://jiekou.ai/) | 게이트웨이 | ✅ |
|
||||
| [302.AI](https://share.302.ai/SuTG99) | 게이트웨이 | ✅ |
|
||||
| [Qiniu](https://www.qiniu.com/ai/agent) | 게이트웨이 | ✅ |
|
||||
|
||||
[→ 모든 통합 보기](https://docs.langbot.app/en/insight/features.html)
|
||||
[→ 모든 통합 보기](https://link.langbot.app/en/docs/features)
|
||||
|
||||
---
|
||||
|
||||
|
||||
33
README_RU.md
33
README_RU.md
@@ -19,9 +19,9 @@
|
||||
[](https://github.com/langbot-app/LangBot/stargazers)
|
||||
|
||||
<a href="https://langbot.app">Главная</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/features.html">Возможности</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/guide.html">Документация</a> |
|
||||
<a href="https://docs.langbot.app/en/tags/readme.html">API</a> |
|
||||
<a href="https://link.langbot.app/en/docs/features">Возможности</a> |
|
||||
<a href="https://link.langbot.app/en/docs/guide">Документация</a> |
|
||||
<a href="https://link.langbot.app/en/docs/api">API</a> |
|
||||
<a href="https://space.langbot.app">Магазин плагинов</a> |
|
||||
<a href="https://langbot.featurebase.app/roadmap">Дорожная карта</a>
|
||||
|
||||
@@ -44,7 +44,9 @@ LangBot — это **платформа с открытым исходным к
|
||||
- **Веб-панель управления** — Настраивайте, управляйте и мониторьте ваших ботов через интуитивный браузерный интерфейс. Ручное редактирование YAML не требуется.
|
||||
- **Мультиконвейерная архитектура** — Разные боты для разных сценариев с комплексным мониторингом и обработкой исключений.
|
||||
|
||||
[→ Подробнее обо всех возможностях](https://docs.langbot.app/en/insight/features.html)
|
||||
[→ Подробнее обо всех возможностях](https://link.langbot.app/en/docs/features)
|
||||
|
||||
📍 Практические руководства: [развернуть мультиплатформенного ИИ-бота за 5 минут](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [подключить DeepSeek к WeChat, Discord и Telegram](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [запустить Dify Agent в Discord, Telegram и Slack](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/) и [создать чат-бота на n8n](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
|
||||
|
||||
---
|
||||
|
||||
@@ -75,7 +77,7 @@ docker compose up -d
|
||||
[](https://zeabur.com/en-US/templates/ZKTBDH)
|
||||
[](https://railway.app/template/yRrAyL?referralCode=vogKPF)
|
||||
|
||||
**Другие варианты:** [Docker](https://docs.langbot.app/en/deploy/langbot/docker.html) · [Ручная установка](https://docs.langbot.app/en/deploy/langbot/manual.html) · [BTPanel](https://docs.langbot.app/en/deploy/langbot/one-click/bt.html) · [Kubernetes](./docker/README_K8S.md)
|
||||
**Другие варианты:** [Docker](https://link.langbot.app/en/docs/docker) · [Ручная установка](https://link.langbot.app/en/docs/manual-deploy) · [BTPanel](https://link.langbot.app/en/docs/bt-panel) · [Kubernetes](./docker/README_K8S.md)
|
||||
|
||||
---
|
||||
|
||||
@@ -83,17 +85,19 @@ docker compose up -d
|
||||
|
||||
| Платформа | Статус | Примечания |
|
||||
|-----------|--------|------------|
|
||||
| Discord | ✅ | |
|
||||
| Telegram | ✅ | |
|
||||
| Slack | ✅ | |
|
||||
| LINE | ✅ | |
|
||||
| QQ | ✅ | Личный и официальный API |
|
||||
| Discord | ✅ | Официальный |
|
||||
| Telegram | ✅ | Официальный |
|
||||
| Slack | ✅ | Официальный |
|
||||
| LINE | ✅ | Официальный |
|
||||
| QQ | ✅ | Личный и официальный API (Канал, ЛС, Группа) |
|
||||
| WeCom | ✅ | Корпоративный WeChat, внешний CS, AI-бот |
|
||||
| WeChat | ✅ | Личный и официальный аккаунт |
|
||||
| Lark | ✅ | |
|
||||
| DingTalk | ✅ | |
|
||||
| KOOK | ✅ | |
|
||||
| Lark | ✅ | Официальный |
|
||||
| DingTalk | ✅ | Официальный |
|
||||
| KOOK | ✅ | Официальный |
|
||||
| Satori | ✅ | |
|
||||
| Email | ✅ | Matrix, Satori |
|
||||
| Matrix | ✅ | Поддерживает несколько платформ через мосты, включая Signal, WhatsApp, Messenger, iMessage, Mattermost, Google Chat, IRC, XMPP, Zulip и другие |
|
||||
|
||||
---
|
||||
|
||||
@@ -122,8 +126,9 @@ docker compose up -d
|
||||
| [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_langbot) | Платформа GPU | ✅ |
|
||||
| [PPIO](https://ppinfra.com/user/register?invited_by=QJKFYD&utm_source=github_langbot) | Платформа GPU | ✅ |
|
||||
| [ShengSuanYun](https://www.shengsuanyun.com/?from=CH_KYIPP758) | Платформа GPU | ✅ |
|
||||
| [Qiniu](https://www.qiniu.com/ai/agent) | Шлюз | ✅ |
|
||||
|
||||
[→ Смотреть все интеграции](https://docs.langbot.app/en/insight/features.html)
|
||||
[→ Смотреть все интеграции](https://link.langbot.app/en/docs/features)
|
||||
|
||||
---
|
||||
|
||||
|
||||
33
README_TW.md
33
README_TW.md
@@ -21,9 +21,9 @@
|
||||
[](https://gitcode.com/RockChinQ/LangBot)
|
||||
|
||||
<a href="https://langbot.app">官網</a> |
|
||||
<a href="https://docs.langbot.app/zh/insight/features.html">特性</a> |
|
||||
<a href="https://docs.langbot.app/zh/insight/guide.html">文件</a> |
|
||||
<a href="https://docs.langbot.app/zh/tags/readme.html">API</a> |
|
||||
<a href="https://link.langbot.app/zh/docs/features">特性</a> |
|
||||
<a href="https://link.langbot.app/zh/docs/guide">文件</a> |
|
||||
<a href="https://link.langbot.app/zh/docs/api">API</a> |
|
||||
<a href="https://space.langbot.app">外掛市場</a> |
|
||||
<a href="https://langbot.featurebase.app/roadmap">路線圖</a>
|
||||
|
||||
@@ -46,7 +46,9 @@ LangBot 是一個**開源的生產級平台**,用於建構 AI 驅動的即時
|
||||
- **Web 管理面板** — 透過瀏覽器直觀地配置、管理和監控機器人,無需手動編輯設定檔。
|
||||
- **多流水線架構** — 不同機器人用於不同場景,具備全面的監控和異常處理能力。
|
||||
|
||||
[→ 了解更多功能特性](https://docs.langbot.app/zh/insight/features.html)
|
||||
[→ 了解更多功能特性](https://link.langbot.app/zh/docs/features)
|
||||
|
||||
📍 實踐指南:[5 分鐘部署多平台 AI 機器人](https://blog.langbot.app/zh/blog/deploy-ai-bot-in-5-minutes/)、[將 DeepSeek 接入微信、企業微信與 Discord](https://blog.langbot.app/zh/blog/connect-deepseek-to-wechat/)、[讓 Dify Agent 跑在 Discord、Telegram 和 Slack 上](https://blog.langbot.app/zh/blog/dify-agent-discord-telegram-slack/),以及[用 n8n 建構多平台 AI 聊天機器人](https://blog.langbot.app/zh/blog/n8n-multi-platform-ai-chatbot/)。
|
||||
|
||||
---
|
||||
|
||||
@@ -77,7 +79,7 @@ docker compose up -d
|
||||
[](https://zeabur.com/zh-CN/templates/ZKTBDH)
|
||||
[](https://railway.app/template/yRrAyL?referralCode=vogKPF)
|
||||
|
||||
**更多方式:** [Docker](https://docs.langbot.app/zh/deploy/langbot/docker.html) · [手動部署](https://docs.langbot.app/zh/deploy/langbot/manual.html) · [寶塔面板](https://docs.langbot.app/zh/deploy/langbot/one-click/bt.html) · [Kubernetes](./docker/README_K8S.md)
|
||||
**更多方式:** [Docker](https://link.langbot.app/zh/docs/docker) · [手動部署](https://link.langbot.app/zh/docs/manual-deploy) · [寶塔面板](https://link.langbot.app/zh/docs/bt-panel) · [Kubernetes](./docker/README_K8S.md)
|
||||
|
||||
---
|
||||
|
||||
@@ -85,17 +87,19 @@ docker compose up -d
|
||||
|
||||
| 平台 | 狀態 | 備註 |
|
||||
|------|------|------|
|
||||
| Discord | ✅ | 官方 |
|
||||
| Telegram | ✅ | 官方 |
|
||||
| Slack | ✅ | 官方 |
|
||||
| LINE | ✅ | 官方 |
|
||||
| QQ | ✅ | 個人號、官方機器人(頻道、私聊、群聊) |
|
||||
| 微信 | ✅ | 個人微信、微信公眾號 |
|
||||
| 企業微信 | ✅ | 應用訊息、對外客服、智能機器人 |
|
||||
| 飛書 | ✅ | |
|
||||
| 釘釘 | ✅ | |
|
||||
| Discord | ✅ | |
|
||||
| Telegram | ✅ | |
|
||||
| Slack | ✅ | |
|
||||
| LINE | ✅ | |
|
||||
| KOOK | ✅ | |
|
||||
| 微信 | ✅ | 個人微信、微信公眾號 |
|
||||
| 飛書 | ✅ | 官方 |
|
||||
| 釘釘 | ✅ | 官方 |
|
||||
| KOOK | ✅ | 官方 |
|
||||
| Satori | ✅ | |
|
||||
| Email | ✅ | 只 Matrix、Satori |
|
||||
| Matrix | ✅ | 支援多種橋接平台,如 Signal、WhatsApp、Messenger、iMessage、Mattermost、Google Chat、IRC、XMPP、Zulip 等 |
|
||||
|
||||
---
|
||||
|
||||
@@ -124,6 +128,7 @@ docker compose up -d
|
||||
| [PPIO](https://ppinfra.com/user/register?invited_by=QJKFYD&utm_source=github_langbot) | GPU 平台 | ✅ |
|
||||
| [接口 AI](https://jiekou.ai/) | 聚合平台 | ✅ |
|
||||
| [302.AI](https://share.302.ai/SuTG99) | 聚合平台 | ✅ |
|
||||
| [Qiniu](https://www.qiniu.com/ai/agent) | 聚合平台 | ✅ |
|
||||
|
||||
### TTS(語音合成)
|
||||
|
||||
@@ -139,7 +144,7 @@ docker compose up -d
|
||||
|-----------|------|
|
||||
| 阿里雲百煉 | [外掛](https://github.com/Thetail001/LangBot_BailianTextToImagePlugin) |
|
||||
|
||||
[→ 查看完整整合列表](https://docs.langbot.app/zh/insight/features.html)
|
||||
[→ 查看完整整合列表](https://link.langbot.app/zh/docs/features)
|
||||
|
||||
---
|
||||
|
||||
|
||||
33
README_VI.md
33
README_VI.md
@@ -19,9 +19,9 @@
|
||||
[](https://github.com/langbot-app/LangBot/stargazers)
|
||||
|
||||
<a href="https://langbot.app">Trang chủ</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/features.html">Tính năng</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/guide.html">Tài liệu</a> |
|
||||
<a href="https://docs.langbot.app/en/tags/readme.html">API</a> |
|
||||
<a href="https://link.langbot.app/en/docs/features">Tính năng</a> |
|
||||
<a href="https://link.langbot.app/en/docs/guide">Tài liệu</a> |
|
||||
<a href="https://link.langbot.app/en/docs/api">API</a> |
|
||||
<a href="https://space.langbot.app">Chợ Plugin</a> |
|
||||
<a href="https://langbot.featurebase.app/roadmap">Lộ trình</a>
|
||||
|
||||
@@ -44,7 +44,9 @@ LangBot là một **nền tảng mã nguồn mở, cấp sản xuất** để x
|
||||
- **Bảng quản lý Web** — Cấu hình, quản lý và giám sát bot thông qua giao diện trình duyệt trực quan. Không cần chỉnh sửa YAML.
|
||||
- **Kiến trúc đa Pipeline** — Các bot khác nhau cho các kịch bản khác nhau, với giám sát toàn diện và xử lý ngoại lệ.
|
||||
|
||||
[→ Tìm hiểu thêm về tất cả tính năng](https://docs.langbot.app/en/insight/features.html)
|
||||
[→ Tìm hiểu thêm về tất cả tính năng](https://link.langbot.app/en/docs/features)
|
||||
|
||||
📍 Hướng dẫn thực hành: [triển khai bot AI đa nền tảng trong 5 phút](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [kết nối DeepSeek với WeChat, Discord và Telegram](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [chạy Dify Agent trên Discord, Telegram và Slack](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/) và [xây dựng chatbot với n8n](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
|
||||
|
||||
---
|
||||
|
||||
@@ -75,7 +77,7 @@ docker compose up -d
|
||||
[](https://zeabur.com/en-US/templates/ZKTBDH)
|
||||
[](https://railway.app/template/yRrAyL?referralCode=vogKPF)
|
||||
|
||||
**Thêm tùy chọn:** [Docker](https://docs.langbot.app/en/deploy/langbot/docker.html) · [Thủ công](https://docs.langbot.app/en/deploy/langbot/manual.html) · [BTPanel](https://docs.langbot.app/en/deploy/langbot/one-click/bt.html) · [Kubernetes](./docker/README_K8S.md)
|
||||
**Thêm tùy chọn:** [Docker](https://link.langbot.app/en/docs/docker) · [Thủ công](https://link.langbot.app/en/docs/manual-deploy) · [BTPanel](https://link.langbot.app/en/docs/bt-panel) · [Kubernetes](./docker/README_K8S.md)
|
||||
|
||||
---
|
||||
|
||||
@@ -83,17 +85,19 @@ docker compose up -d
|
||||
|
||||
| Nền tảng | Trạng thái | Ghi chú |
|
||||
|----------|--------|-------|
|
||||
| Discord | ✅ | |
|
||||
| Telegram | ✅ | |
|
||||
| Slack | ✅ | |
|
||||
| LINE | ✅ | |
|
||||
| QQ | ✅ | Cá nhân & API chính thức |
|
||||
| Discord | ✅ | Chính thức |
|
||||
| Telegram | ✅ | Chính thức |
|
||||
| Slack | ✅ | Chính thức |
|
||||
| LINE | ✅ | Chính thức |
|
||||
| QQ | ✅ | Cá nhân & API chính thức (Kênh, DM, Nhóm) |
|
||||
| WeCom | ✅ | WeChat doanh nghiệp, CS bên ngoài, AI Bot |
|
||||
| WeChat | ✅ | Cá nhân & Tài khoản công khai |
|
||||
| Lark | ✅ | |
|
||||
| DingTalk | ✅ | |
|
||||
| KOOK | ✅ | |
|
||||
| Lark | ✅ | Chính thức |
|
||||
| DingTalk | ✅ | Chính thức |
|
||||
| KOOK | ✅ | Chính thức |
|
||||
| Satori | ✅ | |
|
||||
| Email | ✅ | Matrix, Satori |
|
||||
| Matrix | ✅ | Hỗ trợ nhiều nền tảng qua bridge như Signal, WhatsApp, Messenger, iMessage, Mattermost, Google Chat, IRC, XMPP, Zulip và hơn thế nữa |
|
||||
|
||||
---
|
||||
|
||||
@@ -122,8 +126,9 @@ docker compose up -d
|
||||
| [ShengSuanYun](https://www.shengsuanyun.com/?from=CH_KYIPP758) | Nền tảng GPU | ✅ |
|
||||
| [接口 AI](https://jiekou.ai/) | Cổng | ✅ |
|
||||
| [302.AI](https://share.302.ai/SuTG99) | Cổng | ✅ |
|
||||
| [Qiniu](https://www.qiniu.com/ai/agent) | Cổng | ✅ |
|
||||
|
||||
[→ Xem tất cả tích hợp](https://docs.langbot.app/en/insight/features.html)
|
||||
[→ Xem tất cả tích hợp](https://link.langbot.app/en/docs/features)
|
||||
|
||||
---
|
||||
|
||||
|
||||
163
compare_nodes.py
Normal file
163
compare_nodes.py
Normal file
@@ -0,0 +1,163 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Compare YAML node definitions with frontend node-configs."""
|
||||
|
||||
import yaml
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
|
||||
# 1. Parse YAML files
|
||||
yaml_dir = 'src/langbot/templates/metadata/nodes'
|
||||
yaml_nodes = {}
|
||||
|
||||
for filename in sorted(os.listdir(yaml_dir)):
|
||||
if filename.endswith('.yaml'):
|
||||
filepath = os.path.join(yaml_dir, filename)
|
||||
with open(filepath, 'r') as f:
|
||||
data = yaml.safe_load(f)
|
||||
node_name = data.get('name', filename.replace('.yaml', ''))
|
||||
yaml_nodes[node_name] = {
|
||||
'category': data.get('category', ''),
|
||||
'inputs': [i['name'] for i in data.get('inputs', [])],
|
||||
'outputs': [o['name'] for o in data.get('outputs', [])],
|
||||
'config': [c['name'] for c in data.get('config', [])]
|
||||
}
|
||||
|
||||
# 2. Parse frontend node-configs TypeScript files
|
||||
node_configs_dir = 'web/src/app/home/workflows/components/workflow-editor/node-configs'
|
||||
|
||||
frontend_nodes = {}
|
||||
|
||||
def parse_ts_file(filepath):
|
||||
"""Parse a TypeScript file to extract node configurations."""
|
||||
with open(filepath, 'r') as f:
|
||||
content = f.read()
|
||||
|
||||
# Find all node type definitions
|
||||
# Pattern: nodeType: 'xxx'
|
||||
node_type_pattern = r"nodeType:\s*'([^']+)'"
|
||||
node_types = re.findall(node_type_pattern, content)
|
||||
|
||||
# For each node type, extract inputs, outputs, and config
|
||||
for node_type in node_types:
|
||||
# Find the config object for this node type
|
||||
# Look for the section between this nodeType and the next one or end of object
|
||||
pattern = rf"nodeType:\s*'({re.escape(node_type)})'.*?(?=nodeType:|export\s+(const|function)|$)"
|
||||
match = re.search(pattern, content, re.DOTALL)
|
||||
|
||||
if match:
|
||||
section = match.group(0)
|
||||
|
||||
# Extract inputs
|
||||
inputs = re.findall(r"createInput\('([^']+)'", section)
|
||||
|
||||
# Extract outputs
|
||||
outputs = re.findall(r"createOutput\('([^']+)'", section)
|
||||
|
||||
# Extract config names
|
||||
config_names = re.findall(r"name:\s*'([^']+)'", section)
|
||||
# Remove duplicates while preserving order
|
||||
seen = set()
|
||||
unique_config = []
|
||||
for c in config_names:
|
||||
if c not in seen:
|
||||
seen.add(c)
|
||||
unique_config.append(c)
|
||||
|
||||
frontend_nodes[node_type] = {
|
||||
'inputs': inputs,
|
||||
'outputs': outputs,
|
||||
'config': unique_config
|
||||
}
|
||||
|
||||
# Parse all config files
|
||||
for filename in os.listdir(node_configs_dir):
|
||||
if filename.endswith('.ts') and filename != 'types.ts' and filename != 'index.ts':
|
||||
filepath = os.path.join(node_configs_dir, filename)
|
||||
parse_ts_file(filepath)
|
||||
|
||||
# 3. Compare and report differences
|
||||
print("=" * 80)
|
||||
print("WORKFLOW NODE COMPARISON REPORT: YAML vs Frontend")
|
||||
print("=" * 80)
|
||||
|
||||
all_node_types = sorted(set(list(yaml_nodes.keys()) + list(frontend_nodes.keys())))
|
||||
|
||||
discrepancies = []
|
||||
|
||||
for node_type in all_node_types:
|
||||
yaml_def = yaml_nodes.get(node_type)
|
||||
frontend_def = frontend_nodes.get(node_type)
|
||||
|
||||
node_discrepancies = []
|
||||
|
||||
if not yaml_def:
|
||||
print(f"\n⚠️ {node_type}: ONLY in frontend (not in YAML)")
|
||||
continue
|
||||
if not frontend_def:
|
||||
print(f"\n⚠️ {node_type}: ONLY in YAML (not in frontend)")
|
||||
continue
|
||||
|
||||
# Compare inputs
|
||||
yaml_inputs = set(yaml_def['inputs'])
|
||||
frontend_inputs = set(frontend_def['inputs'])
|
||||
if yaml_inputs != frontend_inputs:
|
||||
only_yaml = yaml_inputs - frontend_inputs
|
||||
only_frontend = frontend_inputs - yaml_inputs
|
||||
node_discrepancies.append({
|
||||
'type': 'inputs',
|
||||
'only_yaml': list(only_yaml),
|
||||
'only_frontend': list(only_frontend)
|
||||
})
|
||||
|
||||
# Compare outputs
|
||||
yaml_outputs = set(yaml_def['outputs'])
|
||||
frontend_outputs = set(frontend_def['outputs'])
|
||||
if yaml_outputs != frontend_outputs:
|
||||
only_yaml = yaml_outputs - frontend_outputs
|
||||
only_frontend = frontend_outputs - yaml_outputs
|
||||
node_discrepancies.append({
|
||||
'type': 'outputs',
|
||||
'only_yaml': list(only_yaml),
|
||||
'only_frontend': list(only_frontend)
|
||||
})
|
||||
|
||||
# Compare config
|
||||
yaml_config = set(yaml_def['config'])
|
||||
frontend_config = set(frontend_def['config'])
|
||||
if yaml_config != frontend_config:
|
||||
only_yaml = yaml_config - frontend_config
|
||||
only_frontend = frontend_config - yaml_config
|
||||
node_discrepancies.append({
|
||||
'type': 'config',
|
||||
'only_yaml': list(only_yaml),
|
||||
'only_frontend': list(only_frontend)
|
||||
})
|
||||
|
||||
if node_discrepancies:
|
||||
print(f"\n❌ {node_type} ({yaml_def['category']}): HAS DISCREPANCIES")
|
||||
for d in node_discrepancies:
|
||||
print(f" {d['type']}:")
|
||||
if d['only_yaml']:
|
||||
print(f" Only in YAML: {d['only_yaml']}")
|
||||
if d['only_frontend']:
|
||||
print(f" Only in Frontend: {d['only_frontend']}")
|
||||
discrepancies.append((node_type, node_discrepancies))
|
||||
else:
|
||||
print(f"\n✅ {node_type} ({yaml_def['category']}): OK")
|
||||
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"SUMMARY: {len(discrepancies)} nodes with discrepancies out of {len(all_node_types)} total")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
# Output as JSON for further processing
|
||||
output = {
|
||||
'yaml_nodes': {k: v for k, v in yaml_nodes.items()},
|
||||
'frontend_nodes': {k: v for k, v in frontend_nodes.items()},
|
||||
'discrepancies': {k: v for k, v in discrepancies}
|
||||
}
|
||||
|
||||
with open('node_comparison.json', 'w') as f:
|
||||
json.dump(output, f, indent=2)
|
||||
|
||||
print(f"\nDetailed comparison saved to node_comparison.json")
|
||||
@@ -312,7 +312,7 @@ spec:
|
||||
### 参考资源
|
||||
|
||||
- [LangBot 官方文档](https://docs.langbot.app)
|
||||
- [Docker 部署文档](https://docs.langbot.app/zh/deploy/langbot/docker.html)
|
||||
- [Docker 部署文档](https://link.langbot.app/zh/docs/docker)
|
||||
- [Kubernetes 官方文档](https://kubernetes.io/docs/)
|
||||
|
||||
---
|
||||
@@ -625,5 +625,5 @@ spec:
|
||||
### References
|
||||
|
||||
- [LangBot Official Documentation](https://docs.langbot.app)
|
||||
- [Docker Deployment Guide](https://docs.langbot.app/zh/deploy/langbot/docker.html)
|
||||
- [Docker Deployment Guide](https://link.langbot.app/zh/docs/docker)
|
||||
- [Kubernetes Official Documentation](https://kubernetes.io/docs/)
|
||||
|
||||
@@ -34,4 +34,4 @@ services:
|
||||
|
||||
networks:
|
||||
langbot_network:
|
||||
driver: bridge
|
||||
driver: bridge
|
||||
713
docs/development/workflow-system.md
Normal file
713
docs/development/workflow-system.md
Normal file
@@ -0,0 +1,713 @@
|
||||
# Workflow 系统开发者文档
|
||||
|
||||
本文档面向 LangBot 开发者,详细介绍 Workflow 系统的技术架构、核心组件和扩展方法。
|
||||
|
||||
## 目录
|
||||
|
||||
- [系统架构概述](#系统架构概述)
|
||||
- [目录结构](#目录结构)
|
||||
- [核心组件](#核心组件)
|
||||
- [后端模块](#后端模块)
|
||||
- [前端组件](#前端组件)
|
||||
- [数据库表结构](#数据库表结构)
|
||||
- [API 接口文档](#api-接口文档)
|
||||
- [如何添加新节点类型](#如何添加新节点类型)
|
||||
- [调试功能实现](#调试功能实现)
|
||||
|
||||
---
|
||||
|
||||
## 系统架构概述
|
||||
|
||||
Workflow 系统采用前后端分离架构,主要包含以下层次:
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ 前端层 (React) │
|
||||
│ ┌─────────────┬──────────────┬──────────────┬───────────┐ │
|
||||
│ │ 可视化编辑器 │ 节点面板 │ 属性面板 │ 调试器 │ │
|
||||
│ │ ReactFlow │ NodePalette │ PropertyPanel│ Debugger │ │
|
||||
│ └─────────────┴──────────────┴──────────────┴───────────┘ │
|
||||
├─────────────────────────────────────────────────────────────┤
|
||||
│ API 层 (Quart) │
|
||||
│ ┌─────────────┬──────────────┬──────────────────────────┐ │
|
||||
│ │ Workflow API│ Debug API │ Node Types API │ │
|
||||
│ └─────────────┴──────────────┴──────────────────────────┘ │
|
||||
├─────────────────────────────────────────────────────────────┤
|
||||
│ 核心引擎层 (Python) │
|
||||
│ ┌─────────────┬──────────────┬──────────────┬───────────┐ │
|
||||
│ │ Executor │ Registry │ Node │ Entities │ │
|
||||
│ │ 执行引擎 │ 节点注册表 │ 节点基类 │ 数据结构 │ │
|
||||
│ └─────────────┴──────────────┴──────────────┴───────────┘ │
|
||||
├─────────────────────────────────────────────────────────────┤
|
||||
│ 存储层 (SQLAlchemy) │
|
||||
│ ┌─────────────┬──────────────┬──────────────────────────┐ │
|
||||
│ │ Workflow │ Executions │ Triggers │ │
|
||||
│ └─────────────┴──────────────┴──────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 目录结构
|
||||
|
||||
### 后端代码结构
|
||||
|
||||
```
|
||||
LangBot/src/langbot/pkg/
|
||||
├── workflow/ # Workflow 核心模块
|
||||
│ ├── __init__.py # 模块初始化,导出公共接口
|
||||
│ ├── entities.py # 数据实体定义
|
||||
│ ├── executor.py # 执行引擎
|
||||
│ ├── node.py # 节点基类和装饰器
|
||||
│ ├── registry.py # 节点类型注册表
|
||||
│ └── nodes/ # 内置节点实现
|
||||
│ ├── __init__.py # 注册所有内置节点
|
||||
│ ├── trigger.py # 触发节点
|
||||
│ ├── process.py # 处理节点
|
||||
│ ├── control.py # 控制节点
|
||||
│ └── action.py # 动作节点
|
||||
├── entity/persistence/
|
||||
│ └── workflow.py # 数据库模型
|
||||
├── api/http/
|
||||
│ ├── controller/groups/workflows/
|
||||
│ │ └── workflows.py # API 路由控制器
|
||||
│ └── service/
|
||||
│ └── workflow.py # 业务逻辑服务
|
||||
└── persistence/migrations/
|
||||
└── dbm026_workflow_tables.py # 数据库迁移
|
||||
```
|
||||
|
||||
### 前端代码结构
|
||||
|
||||
```
|
||||
LangBot/web/src/app/home/workflows/
|
||||
├── page.tsx # Workflow 列表页
|
||||
├── WorkflowDetailContent.tsx # 详情页内容
|
||||
├── store/
|
||||
│ └── useWorkflowStore.ts # Zustand 状态管理
|
||||
└── components/
|
||||
├── workflow-editor/ # 可视化编辑器
|
||||
│ ├── index.ts # 导出
|
||||
│ ├── WorkflowEditorComponent.tsx # 主编辑器组件
|
||||
│ ├── WorkflowNodeComponent.tsx # 自定义节点组件
|
||||
│ ├── NodePalette.tsx # 节点面板
|
||||
│ ├── PropertyPanel.tsx # 属性面板
|
||||
│ └── node-configs/ # 节点配置元数据
|
||||
│ ├── types.ts # 配置类型定义
|
||||
│ ├── trigger-configs.ts
|
||||
│ ├── ai-configs.ts
|
||||
│ ├── process-configs.ts
|
||||
│ ├── control-configs.ts
|
||||
│ ├── action-configs.ts
|
||||
│ ├── integration-configs.ts
|
||||
│ └── index.ts # 配置汇总
|
||||
├── workflow-debugger/ # 调试器组件
|
||||
│ ├── index.ts
|
||||
│ └── WorkflowDebugger.tsx
|
||||
├── workflow-form/ # 表单组件
|
||||
│ └── WorkflowFormComponent.tsx
|
||||
└── workflow-executions/ # 执行历史组件
|
||||
└── WorkflowExecutionsTab.tsx
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 核心组件
|
||||
|
||||
### 后端模块
|
||||
|
||||
#### 1. 执行引擎 (WorkflowExecutor)
|
||||
|
||||
位置:[`executor.py`](../../src/langbot/pkg/workflow/executor.py)
|
||||
|
||||
执行引擎负责工作流的实际执行,包括:
|
||||
|
||||
- **拓扑排序**:确定节点执行顺序
|
||||
- **节点执行**:调用各节点的 execute 方法
|
||||
- **控制流处理**:处理条件分支、循环、并行执行
|
||||
- **错误处理**:支持重试机制
|
||||
|
||||
```python
|
||||
class WorkflowExecutor:
|
||||
async def execute(
|
||||
self,
|
||||
workflow: WorkflowDefinition,
|
||||
context: ExecutionContext,
|
||||
start_node_id: Optional[str] = None
|
||||
) -> ExecutionContext:
|
||||
"""执行工作流"""
|
||||
# 1. 构建执行图
|
||||
# 2. 初始化节点状态
|
||||
# 3. 找到起始节点
|
||||
# 4. 按拓扑顺序执行
|
||||
```
|
||||
|
||||
**调试执行器 (DebugWorkflowExecutor)**
|
||||
|
||||
继承自 WorkflowExecutor,增加了调试支持:
|
||||
|
||||
- 断点支持
|
||||
- 单步执行
|
||||
- 暂停/继续
|
||||
- 实时日志
|
||||
|
||||
```python
|
||||
class DebugWorkflowExecutor(WorkflowExecutor):
|
||||
async def execute_debug(
|
||||
self,
|
||||
workflow: WorkflowDefinition,
|
||||
context: ExecutionContext,
|
||||
debug_state: DebugExecutionState,
|
||||
) -> ExecutionContext:
|
||||
"""调试模式执行"""
|
||||
```
|
||||
|
||||
#### 2. 节点注册表 (NodeTypeRegistry)
|
||||
|
||||
位置:[`registry.py`](../../src/langbot/pkg/workflow/registry.py)
|
||||
|
||||
单例模式管理所有节点类型:
|
||||
|
||||
```python
|
||||
class NodeTypeRegistry:
|
||||
_instance: Optional['NodeTypeRegistry'] = None
|
||||
|
||||
def register(self, node_type: str, node_class: type[WorkflowNode]):
|
||||
"""注册节点类型"""
|
||||
|
||||
def create_instance(self, node_type: str, node_id: str, config: dict) -> WorkflowNode:
|
||||
"""创建节点实例"""
|
||||
|
||||
def list_all(self) -> list[dict]:
|
||||
"""获取所有节点类型的 Schema"""
|
||||
```
|
||||
|
||||
#### 3. 节点基类 (WorkflowNode)
|
||||
|
||||
位置:[`node.py`](../../src/langbot/pkg/workflow/node.py)
|
||||
|
||||
所有节点必须继承此基类:
|
||||
|
||||
```python
|
||||
class WorkflowNode(abc.ABC):
|
||||
# 节点元数据
|
||||
type_name: str = ""
|
||||
name: str = ""
|
||||
description: str = ""
|
||||
category: str = "misc"
|
||||
icon: str = ""
|
||||
|
||||
# 端口定义
|
||||
inputs: list[NodePort] = []
|
||||
outputs: list[NodePort] = []
|
||||
|
||||
# 配置 Schema
|
||||
config_schema: list[NodeConfig] = []
|
||||
|
||||
@abc.abstractmethod
|
||||
async def execute(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
context: ExecutionContext
|
||||
) -> dict[str, Any]:
|
||||
"""执行节点逻辑"""
|
||||
pass
|
||||
```
|
||||
|
||||
#### 4. 数据实体 (entities.py)
|
||||
|
||||
主要数据结构:
|
||||
|
||||
```python
|
||||
class WorkflowDefinition:
|
||||
"""工作流定义"""
|
||||
uuid: str
|
||||
name: str
|
||||
nodes: list[NodeDefinition]
|
||||
edges: list[EdgeDefinition]
|
||||
settings: WorkflowSettings
|
||||
|
||||
class ExecutionContext:
|
||||
"""执行上下文"""
|
||||
execution_id: str
|
||||
workflow_id: str
|
||||
status: ExecutionStatus
|
||||
variables: dict
|
||||
node_states: dict[str, NodeState]
|
||||
history: list[ExecutionStep]
|
||||
```
|
||||
|
||||
### 前端组件
|
||||
|
||||
#### 1. WorkflowEditorComponent
|
||||
|
||||
主编辑器组件,基于 React Flow 实现:
|
||||
|
||||
- **画布交互**:拖拽、缩放、平移
|
||||
- **节点连接**:自动验证端口类型
|
||||
- **撤销/重做**:基于历史记录栈
|
||||
- **复制/粘贴**:支持多选复制
|
||||
|
||||
关键功能:
|
||||
|
||||
```tsx
|
||||
function WorkflowEditorInner() {
|
||||
const { nodes, edges, onNodesChange, onEdgesChange, onConnect } = useWorkflowStore();
|
||||
|
||||
// 拖放添加节点
|
||||
const onDrop = useCallback((event: React.DragEvent) => {
|
||||
const type = event.dataTransfer.getData('application/reactflow');
|
||||
const position = screenToFlowPosition({ x: event.clientX, y: event.clientY });
|
||||
addNode(type, position);
|
||||
}, []);
|
||||
|
||||
// 复制粘贴
|
||||
const handleCopy = useCallback(() => { ... }, []);
|
||||
const handlePaste = useCallback(() => { ... }, []);
|
||||
}
|
||||
```
|
||||
|
||||
#### 2. NodePalette
|
||||
|
||||
节点面板组件,展示可用节点类型:
|
||||
|
||||
```tsx
|
||||
function NodePalette() {
|
||||
// 按类别组织节点
|
||||
const categories = [
|
||||
{ id: 'trigger', name: '触发节点', icon: Zap },
|
||||
{ id: 'ai', name: 'AI 节点', icon: Brain },
|
||||
{ id: 'process', name: '处理节点', icon: Cpu },
|
||||
{ id: 'control', name: '控制节点', icon: GitBranch },
|
||||
{ id: 'action', name: '动作节点', icon: Send },
|
||||
{ id: 'integration', name: '集成节点', icon: Plug },
|
||||
];
|
||||
|
||||
// 拖拽开始
|
||||
const onDragStart = (event: React.DragEvent, nodeType: string) => {
|
||||
event.dataTransfer.setData('application/reactflow', nodeType);
|
||||
};
|
||||
}
|
||||
```
|
||||
|
||||
#### 3. PropertyPanel
|
||||
|
||||
属性面板组件,动态渲染节点配置表单:
|
||||
|
||||
```tsx
|
||||
function PropertyPanel() {
|
||||
const { selectedNodeId, nodes, updateNodeData } = useWorkflowStore();
|
||||
|
||||
// 根据节点类型获取配置元数据
|
||||
const selectedNode = nodes.find(n => n.id === selectedNodeId);
|
||||
const nodeConfig = getNodeConfig(selectedNode?.data?.nodeType);
|
||||
|
||||
// 动态渲染配置字段
|
||||
return (
|
||||
<div>
|
||||
{nodeConfig?.fields.map(field => (
|
||||
<ConfigField key={field.name} field={field} />
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
#### 4. WorkflowDebugger
|
||||
|
||||
调试器组件,支持实时调试:
|
||||
|
||||
```tsx
|
||||
function WorkflowDebugger({ workflowUuid, workflow }) {
|
||||
const [debugState, setDebugState] = useState<DebugState>('idle');
|
||||
const [executionId, setExecutionId] = useState<string>('');
|
||||
const [logs, setLogs] = useState<ExecutionLog[]>([]);
|
||||
|
||||
// 启动调试
|
||||
const startDebug = async () => {
|
||||
const result = await backendClient.post(
|
||||
`/api/v1/workflows/${workflowUuid}/debug/start`,
|
||||
{ context, variables, breakpoints }
|
||||
);
|
||||
setExecutionId(result.execution_id);
|
||||
};
|
||||
|
||||
// 轮询状态
|
||||
useEffect(() => {
|
||||
if (debugState === 'running') {
|
||||
const interval = setInterval(fetchState, 500);
|
||||
return () => clearInterval(interval);
|
||||
}
|
||||
}, [debugState]);
|
||||
}
|
||||
```
|
||||
|
||||
#### 5. useWorkflowStore
|
||||
|
||||
Zustand 状态管理:
|
||||
|
||||
```typescript
|
||||
interface WorkflowState {
|
||||
nodes: WorkflowNode[];
|
||||
edges: WorkflowEdge[];
|
||||
selectedNodeId: string | null;
|
||||
history: HistoryEntry[];
|
||||
historyIndex: number;
|
||||
isDirty: boolean;
|
||||
|
||||
// Actions
|
||||
addNode: (type: string, position: XYPosition) => void;
|
||||
updateNodeData: (nodeId: string, data: Partial<NodeData>) => void;
|
||||
deleteNode: (nodeId: string) => void;
|
||||
undo: () => void;
|
||||
redo: () => void;
|
||||
}
|
||||
|
||||
export const useWorkflowStore = create<WorkflowState>((set, get) => ({
|
||||
// ... state and actions
|
||||
}));
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 数据库表结构
|
||||
|
||||
### workflows 表
|
||||
|
||||
```sql
|
||||
CREATE TABLE workflows (
|
||||
uuid VARCHAR(255) PRIMARY KEY,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
description TEXT,
|
||||
emoji VARCHAR(10) DEFAULT '🔄',
|
||||
version INTEGER DEFAULT 1,
|
||||
is_enabled BOOLEAN DEFAULT TRUE,
|
||||
definition JSON NOT NULL, -- 节点和边定义
|
||||
global_config JSON DEFAULT '{}', -- 全局配置
|
||||
extensions_preferences JSON, -- 插件和 MCP 配置
|
||||
created_at TIMESTAMP,
|
||||
updated_at TIMESTAMP
|
||||
);
|
||||
```
|
||||
|
||||
### workflow_versions 表
|
||||
|
||||
```sql
|
||||
CREATE TABLE workflow_versions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
workflow_uuid VARCHAR(255) NOT NULL,
|
||||
version INTEGER NOT NULL,
|
||||
definition JSON NOT NULL,
|
||||
global_config JSON DEFAULT '{}',
|
||||
created_at TIMESTAMP,
|
||||
created_by VARCHAR(255),
|
||||
UNIQUE(workflow_uuid, version)
|
||||
);
|
||||
```
|
||||
|
||||
### workflow_executions 表
|
||||
|
||||
```sql
|
||||
CREATE TABLE workflow_executions (
|
||||
uuid VARCHAR(255) PRIMARY KEY,
|
||||
workflow_uuid VARCHAR(255) NOT NULL,
|
||||
workflow_version INTEGER NOT NULL,
|
||||
status VARCHAR(20) NOT NULL, -- pending/running/completed/failed/cancelled
|
||||
trigger_type VARCHAR(50),
|
||||
trigger_data JSON,
|
||||
variables JSON,
|
||||
start_time TIMESTAMP,
|
||||
end_time TIMESTAMP,
|
||||
error TEXT,
|
||||
created_at TIMESTAMP
|
||||
);
|
||||
```
|
||||
|
||||
### workflow_node_executions 表
|
||||
|
||||
```sql
|
||||
CREATE TABLE workflow_node_executions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
execution_uuid VARCHAR(255) NOT NULL,
|
||||
node_id VARCHAR(100) NOT NULL,
|
||||
node_type VARCHAR(50) NOT NULL,
|
||||
status VARCHAR(20) NOT NULL,
|
||||
inputs JSON,
|
||||
outputs JSON,
|
||||
start_time TIMESTAMP,
|
||||
end_time TIMESTAMP,
|
||||
error TEXT,
|
||||
retry_count INTEGER DEFAULT 0
|
||||
);
|
||||
```
|
||||
|
||||
### workflow_triggers 表
|
||||
|
||||
```sql
|
||||
CREATE TABLE workflow_triggers (
|
||||
uuid VARCHAR(255) PRIMARY KEY,
|
||||
workflow_uuid VARCHAR(255) NOT NULL,
|
||||
type VARCHAR(50) NOT NULL, -- message/cron/event/webhook
|
||||
config JSON NOT NULL,
|
||||
is_enabled BOOLEAN DEFAULT TRUE,
|
||||
priority INTEGER DEFAULT 0,
|
||||
created_at TIMESTAMP,
|
||||
updated_at TIMESTAMP
|
||||
);
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## API 接口文档
|
||||
|
||||
### Workflow CRUD
|
||||
|
||||
| 方法 | 路径 | 描述 |
|
||||
|-----|------|------|
|
||||
| GET | `/api/v1/workflows` | 获取工作流列表 |
|
||||
| POST | `/api/v1/workflows` | 创建工作流 |
|
||||
| GET | `/api/v1/workflows/:uuid` | 获取单个工作流 |
|
||||
| PUT | `/api/v1/workflows/:uuid` | 更新工作流 |
|
||||
| DELETE | `/api/v1/workflows/:uuid` | 删除工作流 |
|
||||
| POST | `/api/v1/workflows/:uuid/copy` | 复制工作流 |
|
||||
|
||||
### 执行相关
|
||||
|
||||
| 方法 | 路径 | 描述 |
|
||||
|-----|------|------|
|
||||
| POST | `/api/v1/workflows/:uuid/execute` | 手动执行工作流 |
|
||||
| GET | `/api/v1/workflows/:uuid/executions` | 获取执行记录 |
|
||||
|
||||
### 版本管理
|
||||
|
||||
| 方法 | 路径 | 描述 |
|
||||
|-----|------|------|
|
||||
| GET | `/api/v1/workflows/:uuid/versions` | 获取版本列表 |
|
||||
| POST | `/api/v1/workflows/:uuid/rollback/:version` | 回滚到指定版本 |
|
||||
|
||||
### 调试 API
|
||||
|
||||
| 方法 | 路径 | 描述 |
|
||||
|-----|------|------|
|
||||
| POST | `/api/v1/workflows/:uuid/debug/start` | 启动调试 |
|
||||
| POST | `/api/v1/workflows/:uuid/debug/:exec_id/pause` | 暂停执行 |
|
||||
| POST | `/api/v1/workflows/:uuid/debug/:exec_id/resume` | 继续执行 |
|
||||
| POST | `/api/v1/workflows/:uuid/debug/:exec_id/stop` | 停止执行 |
|
||||
| POST | `/api/v1/workflows/:uuid/debug/:exec_id/step` | 单步执行 |
|
||||
| GET | `/api/v1/workflows/:uuid/debug/:exec_id/state` | 获取调试状态 |
|
||||
|
||||
### 节点类型
|
||||
|
||||
| 方法 | 路径 | 描述 |
|
||||
|-----|------|------|
|
||||
| GET | `/api/v1/workflows/_/node-types` | 获取所有节点类型 |
|
||||
| GET | `/api/v1/workflows/_/node-types/categories` | 按类别获取节点类型 |
|
||||
|
||||
---
|
||||
|
||||
## 如何添加新节点类型
|
||||
|
||||
### 步骤 1:创建节点类
|
||||
|
||||
在 `LangBot/src/langbot/pkg/workflow/nodes/` 下创建或修改文件:
|
||||
|
||||
```python
|
||||
from ..node import WorkflowNode, NodePort, NodeConfig, workflow_node
|
||||
from ..entities import ExecutionContext
|
||||
|
||||
@workflow_node('my_custom_node')
|
||||
class MyCustomNode(WorkflowNode):
|
||||
"""自定义节点"""
|
||||
|
||||
# 元数据
|
||||
type_name = 'my_custom_node'
|
||||
name = '我的自定义节点'
|
||||
description = '这是一个自定义节点'
|
||||
category = 'process' # trigger/process/control/action/integration
|
||||
icon = '🔧'
|
||||
|
||||
# 输入端口
|
||||
inputs = [
|
||||
NodePort(name='input', type='string', description='输入数据', required=True),
|
||||
]
|
||||
|
||||
# 输出端口
|
||||
outputs = [
|
||||
NodePort(name='output', type='string', description='输出数据'),
|
||||
]
|
||||
|
||||
# 配置字段
|
||||
config_schema = [
|
||||
NodeConfig(
|
||||
name='option',
|
||||
type='select',
|
||||
required=True,
|
||||
options=['选项A', '选项B'],
|
||||
description='选择一个选项'
|
||||
),
|
||||
NodeConfig(
|
||||
name='value',
|
||||
type='string',
|
||||
required=False,
|
||||
default='默认值',
|
||||
description='配置值'
|
||||
),
|
||||
]
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
context: ExecutionContext
|
||||
) -> dict[str, Any]:
|
||||
"""执行节点逻辑"""
|
||||
input_data = inputs.get('input', '')
|
||||
option = self.get_config('option')
|
||||
value = self.get_config('value', '')
|
||||
|
||||
# 处理逻辑
|
||||
result = f"处理: {input_data} with {option} and {value}"
|
||||
|
||||
return {'output': result}
|
||||
```
|
||||
|
||||
### 步骤 2:注册节点
|
||||
|
||||
在 `LangBot/src/langbot/pkg/workflow/nodes/__init__.py` 中导入:
|
||||
|
||||
```python
|
||||
from .process import (
|
||||
CodeExecutorNode,
|
||||
HttpRequestNode,
|
||||
DataTransformNode,
|
||||
MyCustomNode, # 添加新节点
|
||||
)
|
||||
```
|
||||
|
||||
### 步骤 3:添加前端配置
|
||||
|
||||
在 `LangBot/web/src/app/home/workflows/components/workflow-editor/node-configs/` 目录下添加配置:
|
||||
|
||||
```typescript
|
||||
// process-configs.ts
|
||||
export const processNodeConfigs: NodeConfigMap = {
|
||||
// ... 其他配置
|
||||
|
||||
my_custom_node: {
|
||||
type: 'my_custom_node',
|
||||
label: 'workflows.nodes.myCustomNode',
|
||||
description: 'workflows.nodes.myCustomNodeDesc',
|
||||
icon: 'Wrench',
|
||||
category: 'process',
|
||||
fields: [
|
||||
{
|
||||
name: 'option',
|
||||
type: 'select',
|
||||
label: 'workflows.fields.option',
|
||||
required: true,
|
||||
options: [
|
||||
{ value: '选项A', label: '选项 A' },
|
||||
{ value: '选项B', label: '选项 B' },
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'value',
|
||||
type: 'string',
|
||||
label: 'workflows.fields.value',
|
||||
required: false,
|
||||
defaultValue: '默认值',
|
||||
},
|
||||
],
|
||||
},
|
||||
};
|
||||
```
|
||||
|
||||
### 步骤 4:添加国际化
|
||||
|
||||
在 `LangBot/web/src/i18n/locales/` 中添加翻译:
|
||||
|
||||
```typescript
|
||||
// zh-Hans.ts
|
||||
workflows: {
|
||||
nodes: {
|
||||
myCustomNode: '我的自定义节点',
|
||||
myCustomNodeDesc: '这是一个自定义节点',
|
||||
},
|
||||
fields: {
|
||||
option: '选项',
|
||||
value: '值',
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 调试功能实现
|
||||
|
||||
### 后端调试状态管理
|
||||
|
||||
```python
|
||||
class DebugExecutionState:
|
||||
"""调试执行状态"""
|
||||
|
||||
def __init__(self, execution_id: str, breakpoints: list[str] = None):
|
||||
self.execution_id = execution_id
|
||||
self.status: str = 'running'
|
||||
self.is_paused: bool = False
|
||||
self.is_stopped: bool = False
|
||||
self.breakpoints: set[str] = set(breakpoints or [])
|
||||
self.logs: list[ExecutionLog] = []
|
||||
self._pause_event = asyncio.Event()
|
||||
|
||||
def pause(self):
|
||||
"""暂停执行"""
|
||||
self.is_paused = True
|
||||
self._pause_event.clear()
|
||||
|
||||
def resume(self):
|
||||
"""继续执行"""
|
||||
self.is_paused = False
|
||||
self._pause_event.set()
|
||||
|
||||
async def wait_if_paused(self):
|
||||
"""如果暂停则等待"""
|
||||
if self.is_paused:
|
||||
await self._pause_event.wait()
|
||||
```
|
||||
|
||||
### 前端调试流程
|
||||
|
||||
1. **设置断点**:点击节点设置断点
|
||||
2. **启动调试**:调用 `/debug/start` 启动调试执行
|
||||
3. **轮询状态**:定期调用 `/debug/:id/state` 获取状态
|
||||
4. **控制执行**:调用 pause/resume/step/stop 控制执行
|
||||
5. **查看日志**:实时显示执行日志和节点状态
|
||||
|
||||
```typescript
|
||||
// 调试状态轮询
|
||||
const fetchDebugState = async () => {
|
||||
const state = await backendClient.get(
|
||||
`/api/v1/workflows/${workflowUuid}/debug/${executionId}/state`
|
||||
);
|
||||
|
||||
// 更新节点状态
|
||||
setNodeStates(state.node_states);
|
||||
|
||||
// 追加新日志
|
||||
if (state.new_logs.length > 0) {
|
||||
setLogs(prev => [...prev, ...state.new_logs]);
|
||||
}
|
||||
|
||||
// 检查完成状态
|
||||
if (state.status === 'completed' || state.status === 'error') {
|
||||
setDebugState('idle');
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 扩展阅读
|
||||
|
||||
- [Workflow 功能设计文档](../../../plans/langbot-workflow-design.md)
|
||||
- [用户使用指南](../user-guide/workflow-guide.md)
|
||||
- [API 认证文档](../API_KEY_AUTH.md)
|
||||
425
docs/user-guide/workflow-guide.md
Normal file
425
docs/user-guide/workflow-guide.md
Normal file
@@ -0,0 +1,425 @@
|
||||
# Workflow 用户指南
|
||||
|
||||
本文档帮助您了解和使用 LangBot 的 Workflow(工作流)功能,通过可视化方式构建自动化的对话处理流程。
|
||||
|
||||
## 目录
|
||||
|
||||
- [功能介绍](#功能介绍)
|
||||
- [快速入门](#快速入门)
|
||||
- [节点类型说明](#节点类型说明)
|
||||
- [编辑器使用指南](#编辑器使用指南)
|
||||
- [调试功能](#调试功能)
|
||||
- [常见问题解答](#常见问题解答)
|
||||
|
||||
---
|
||||
|
||||
## 功能介绍
|
||||
|
||||
### 什么是 Workflow?
|
||||
|
||||
Workflow(工作流)是 LangBot 提供的可视化自动化编排系统。通过拖拽节点、连接边的方式,您可以:
|
||||
|
||||
- 📝 **构建复杂的对话流程**:使用条件分支、循环等控制节点
|
||||
- 🤖 **调用 AI 能力**:集成 LLM、知识库检索、参数提取
|
||||
- 🔗 **连接外部服务**:集成 Dify、n8n、Coze 等平台
|
||||
- ⚡ **自动化任务执行**:消息触发、定时触发、Webhook 触发
|
||||
|
||||
### Workflow vs Pipeline
|
||||
|
||||
| 对比项 | Pipeline | Workflow |
|
||||
|-------|----------|----------|
|
||||
| 配置方式 | 表单配置 | 可视化拖拽 |
|
||||
| 流程控制 | 线性执行 | 支持分支、循环、并行 |
|
||||
| 适用场景 | 简单对话 | 复杂流程 |
|
||||
| 学习曲线 | 低 | 中等 |
|
||||
|
||||
---
|
||||
|
||||
## 快速入门
|
||||
|
||||
### 第一步:创建 Workflow
|
||||
|
||||
1. 在侧边栏点击 **Workflow** 进入工作流列表
|
||||
2. 点击右上角 **创建工作流** 按钮
|
||||
3. 填写基本信息:
|
||||
- **名称**:给工作流起一个描述性的名字
|
||||
- **描述**:可选,说明工作流的用途
|
||||
- **图标**:选择一个 emoji 作为标识
|
||||
|
||||
### 第二步:添加节点
|
||||
|
||||
进入编辑器后,左侧是节点面板,中间是画布区域,右侧是属性面板。
|
||||
|
||||
1. **添加触发节点**:从左侧面板拖拽一个"消息触发"节点到画布
|
||||
2. **添加 AI 节点**:拖拽一个"LLM 调用"节点
|
||||
3. **添加回复节点**:拖拽一个"回复消息"节点
|
||||
|
||||
### 第三步:连接节点
|
||||
|
||||
1. 将鼠标悬停在触发节点的输出端口(右侧小圆点)
|
||||
2. 按住鼠标拖拽到 LLM 节点的输入端口(左侧小圆点)
|
||||
3. 同样方式连接 LLM 节点和回复节点
|
||||
|
||||
```
|
||||
[消息触发] ──▶ [LLM 调用] ──▶ [回复消息]
|
||||
```
|
||||
|
||||
### 第四步:配置节点
|
||||
|
||||
点击 LLM 调用节点,在右侧属性面板配置:
|
||||
|
||||
- **运行方式**:选择"本地 Agent"
|
||||
- **系统提示词**:描述 AI 的角色和行为
|
||||
- **模型**:选择要使用的 LLM 模型
|
||||
|
||||
点击回复消息节点配置:
|
||||
|
||||
- **消息内容**:设置为 `{{nodes.llm_call.outputs.response}}`(引用 LLM 输出)
|
||||
|
||||
### 第五步:保存并绑定
|
||||
|
||||
1. 点击工具栏的 **保存** 按钮
|
||||
2. 返回 Bot 配置页面
|
||||
3. 在 Bot 的绑定设置中选择 **Workflow**,然后选择刚创建的工作流
|
||||
|
||||
恭喜!您已经创建了第一个 Workflow。
|
||||
|
||||
---
|
||||
|
||||
## 节点类型说明
|
||||
|
||||
### 触发节点 (Trigger)
|
||||
|
||||
触发节点是工作流的入口,定义何时启动执行。
|
||||
|
||||
| 节点 | 说明 | 输出 |
|
||||
|-----|------|------|
|
||||
| 消息触发 | 收到消息时触发 | message, sender_id, platform |
|
||||
| 定时触发 | 按 Cron 表达式定时触发 | timestamp |
|
||||
| Webhook 触发 | 收到 HTTP 请求时触发 | request_body, headers |
|
||||
| 事件触发 | 系统事件触发 | event_type, event_data |
|
||||
|
||||
**消息触发配置示例**:
|
||||
|
||||
```yaml
|
||||
触发条件:
|
||||
- 关键词匹配: ["帮助", "help"]
|
||||
- 平台: ["wechat", "qq"]
|
||||
```
|
||||
|
||||
### AI 节点
|
||||
|
||||
AI 节点用于调用各种 AI 能力。
|
||||
|
||||
| 节点 | 说明 | 典型用途 |
|
||||
|-----|------|---------|
|
||||
| LLM 调用 | 调用大语言模型 | 生成回复、理解意图 |
|
||||
| 问题分类器 | 对用户问题分类 | 路由到不同处理分支 |
|
||||
| 参数提取器 | 从文本提取结构化数据 | 提取订单号、日期等 |
|
||||
| 知识库检索 | 查询知识库 | RAG 增强回复 |
|
||||
|
||||
**LLM 调用配置示例**:
|
||||
|
||||
```yaml
|
||||
运行方式: 本地 Agent
|
||||
模型: gpt-4
|
||||
系统提示词: |
|
||||
你是一个友好的客服助手。
|
||||
请根据用户的问题提供帮助。
|
||||
温度: 0.7
|
||||
最大 Token 数: 2000
|
||||
```
|
||||
|
||||
### 处理节点 (Process)
|
||||
|
||||
处理节点用于数据处理和外部调用。
|
||||
|
||||
| 节点 | 说明 | 典型用途 |
|
||||
|-----|------|---------|
|
||||
| 代码执行 | 执行 Python/JavaScript 代码 | 数据处理、格式转换 |
|
||||
| HTTP 请求 | 发送 HTTP 请求 | 调用外部 API |
|
||||
| 数据转换 | JSON/模板转换 | 数据格式化 |
|
||||
|
||||
**HTTP 请求配置示例**:
|
||||
|
||||
```yaml
|
||||
URL: https://api.example.com/data
|
||||
方法: POST
|
||||
请求头:
|
||||
Content-Type: application/json
|
||||
Authorization: Bearer {{variables.api_key}}
|
||||
请求体: |
|
||||
{"query": "{{message.content}}"}
|
||||
```
|
||||
|
||||
### 控制节点 (Control)
|
||||
|
||||
控制节点用于流程控制。
|
||||
|
||||
| 节点 | 说明 | 用途 |
|
||||
|-----|------|------|
|
||||
| 条件分支 | 二选一分支 | if-else 逻辑 |
|
||||
| 多路分支 | 多选一分支 | switch-case 逻辑 |
|
||||
| 循环 | 遍历数组 | 批量处理 |
|
||||
| 并行 | 同时执行多分支 | 并发处理 |
|
||||
| 等待 | 暂停执行 | 延时处理 |
|
||||
| 合并 | 合并多个分支 | 汇总结果 |
|
||||
|
||||
**条件分支配置示例**:
|
||||
|
||||
```yaml
|
||||
条件表达式: "{{nodes.classifier.outputs.category}}" == "complaint"
|
||||
真分支: 投诉处理
|
||||
假分支: 普通咨询
|
||||
```
|
||||
|
||||
### 动作节点 (Action)
|
||||
|
||||
动作节点执行具体操作。
|
||||
|
||||
| 节点 | 说明 | 用途 |
|
||||
|-----|------|------|
|
||||
| 发送消息 | 主动发送消息 | 通知、推送 |
|
||||
| 回复消息 | 回复当前消息 | 对话回复 |
|
||||
| 存储数据 | 保存数据到存储 | 持久化 |
|
||||
| 调用 Pipeline | 调用现有 Pipeline | 复用现有流程 |
|
||||
|
||||
**回复消息配置示例**:
|
||||
|
||||
```yaml
|
||||
消息内容: |
|
||||
感谢您的咨询!
|
||||
|
||||
{{nodes.llm_call.outputs.response}}
|
||||
|
||||
如有其他问题,随时联系我。
|
||||
```
|
||||
|
||||
### 集成节点 (Integration)
|
||||
|
||||
集成节点连接外部平台。
|
||||
|
||||
| 节点 | 说明 | 平台 |
|
||||
|-----|------|------|
|
||||
| Dify 工作流 | 调用 Dify 应用 | Dify |
|
||||
| Dify 知识库 | 查询 Dify 知识库 | Dify |
|
||||
| n8n 工作流 | 调用 n8n 流程 | n8n |
|
||||
| Langflow | 调用 Langflow 流程 | Langflow |
|
||||
| Coze Bot | 调用扣子 Bot | Coze |
|
||||
|
||||
**Dify 工作流配置示例**:
|
||||
|
||||
```yaml
|
||||
API 地址: https://api.dify.ai/v1
|
||||
API Key: sk-xxxxx
|
||||
应用类型: workflow
|
||||
同步对话历史: true
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 编辑器使用指南
|
||||
|
||||
### 画布操作
|
||||
|
||||
| 操作 | 方式 |
|
||||
|-----|------|
|
||||
| 平移画布 | 按住鼠标中键/空格+左键 拖拽 |
|
||||
| 缩放画布 | 鼠标滚轮 / 工具栏按钮 |
|
||||
| 框选多个节点 | 按住 Shift + 拖拽框选 |
|
||||
| 适应视图 | 点击工具栏"适应"按钮 |
|
||||
|
||||
### 节点操作
|
||||
|
||||
| 操作 | 方式 |
|
||||
|-----|------|
|
||||
| 添加节点 | 从左侧面板拖拽到画布 |
|
||||
| 移动节点 | 点击节点拖拽 |
|
||||
| 删除节点 | 选中后按 Delete / 点击工具栏删除 |
|
||||
| 复制节点 | 选中后 Ctrl+C / 工具栏复制 |
|
||||
| 粘贴节点 | Ctrl+V / 工具栏粘贴 |
|
||||
|
||||
### 连接操作
|
||||
|
||||
| 操作 | 方式 |
|
||||
|-----|------|
|
||||
| 创建连接 | 从输出端口拖拽到输入端口 |
|
||||
| 删除连接 | 点击连接线后按 Delete |
|
||||
| 选中连接 | 点击连接线 |
|
||||
|
||||
### 快捷键
|
||||
|
||||
| 快捷键 | 功能 |
|
||||
|-------|------|
|
||||
| Ctrl + Z | 撤销 |
|
||||
| Ctrl + Shift + Z | 重做 |
|
||||
| Ctrl + C | 复制 |
|
||||
| Ctrl + V | 粘贴 |
|
||||
| Delete | 删除选中 |
|
||||
| Ctrl + S | 保存 |
|
||||
|
||||
### 工具栏功能
|
||||
|
||||
```
|
||||
[撤销] [重做] | [放大] [缩小] [适应] | [复制] [粘贴] [删除] | [保存] [调试]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 调试功能
|
||||
|
||||
### 启动调试
|
||||
|
||||
1. 点击工具栏的 **调试** 按钮
|
||||
2. 在调试面板中配置初始数据:
|
||||
- **输入消息**:模拟用户发送的消息
|
||||
- **会话 ID**:可选,用于测试会话变量
|
||||
- **变量**:设置初始变量值
|
||||
|
||||
3. 点击 **开始调试** 按钮
|
||||
|
||||
### 调试控制
|
||||
|
||||
| 按钮 | 功能 |
|
||||
|-----|------|
|
||||
| ▶️ 开始/继续 | 开始或继续执行 |
|
||||
| ⏸️ 暂停 | 暂停执行 |
|
||||
| ⏹️ 停止 | 停止执行 |
|
||||
| ⏭️ 单步 | 执行下一个节点 |
|
||||
|
||||
### 断点
|
||||
|
||||
- **设置断点**:点击节点上的断点图标
|
||||
- **断点触发**:执行到断点时自动暂停
|
||||
- **查看状态**:在暂停时查看节点的输入输出
|
||||
|
||||
### 执行日志
|
||||
|
||||
调试面板下方显示实时日志:
|
||||
|
||||
```
|
||||
[INFO] 2024-01-15 10:30:00 - Starting debug execution
|
||||
[INFO] 2024-01-15 10:30:00 - Executing node: message_trigger
|
||||
[DEBUG] 2024-01-15 10:30:00 - Node inputs: {"message": "你好"}
|
||||
[INFO] 2024-01-15 10:30:01 - Node completed in 50ms
|
||||
[INFO] 2024-01-15 10:30:01 - Executing node: llm_call
|
||||
...
|
||||
```
|
||||
|
||||
### 节点状态颜色
|
||||
|
||||
| 颜色 | 状态 |
|
||||
|-----|------|
|
||||
| 灰色 | 待执行 |
|
||||
| 蓝色 | 执行中 |
|
||||
| 绿色 | 已完成 |
|
||||
| 红色 | 失败 |
|
||||
| 黄色 | 已跳过 |
|
||||
|
||||
---
|
||||
|
||||
## 常见问题解答
|
||||
|
||||
### Q1:如何在节点间传递数据?
|
||||
|
||||
使用表达式语法引用其他节点的输出:
|
||||
|
||||
```
|
||||
{{nodes.节点ID.outputs.输出名称}}
|
||||
```
|
||||
|
||||
例如:
|
||||
- `{{nodes.llm_call.outputs.response}}` - 引用 LLM 节点的响应
|
||||
- `{{nodes.http_request.outputs.body}}` - 引用 HTTP 请求的响应体
|
||||
|
||||
### Q2:如何使用变量?
|
||||
|
||||
Workflow 支持三种变量类型:
|
||||
|
||||
1. **工作流变量**:`{{variables.变量名}}`
|
||||
2. **会话变量**:`{{conversation_variables.变量名}}`
|
||||
3. **消息上下文**:`{{message.content}}`、`{{message.sender_id}}`
|
||||
|
||||
### Q3:条件分支如何写条件表达式?
|
||||
|
||||
支持以下运算符:
|
||||
|
||||
- 比较:`==`, `!=`, `>`, `<`, `>=`, `<=`
|
||||
- 逻辑:`and`, `or`, `not`
|
||||
- 包含:`in`
|
||||
|
||||
示例:
|
||||
```python
|
||||
# 字符串比较
|
||||
"{{nodes.classifier.outputs.intent}}" == "purchase"
|
||||
|
||||
# 数值比较
|
||||
{{nodes.extractor.outputs.amount}} > 1000
|
||||
|
||||
# 包含检查
|
||||
"退款" in "{{message.content}}"
|
||||
```
|
||||
|
||||
### Q4:如何处理错误?
|
||||
|
||||
1. **节点级重试**:在节点配置中设置重试次数
|
||||
2. **全局错误处理**:在 Workflow 设置中配置错误处理策略
|
||||
3. **条件分支**:使用条件节点检查上一节点的状态
|
||||
|
||||
### Q5:如何查看执行历史?
|
||||
|
||||
1. 进入 Workflow 详情页
|
||||
2. 点击 **执行历史** 标签
|
||||
3. 查看每次执行的状态、耗时、输入输出
|
||||
|
||||
### Q6:Workflow 可以被多个 Bot 使用吗?
|
||||
|
||||
是的。一个 Workflow 可以被多个 Bot 绑定使用,但每个 Bot 只能绑定一个处理单元(Pipeline 或 Workflow)。
|
||||
|
||||
### Q7:如何复制现有的 Workflow?
|
||||
|
||||
在 Workflow 列表页,点击工作流卡片右上角的菜单,选择"复制"即可创建副本。
|
||||
|
||||
### Q8:支持版本回滚吗?
|
||||
|
||||
支持。每次保存都会创建新版本。在 Workflow 详情页可以查看版本历史并回滚到指定版本。
|
||||
|
||||
---
|
||||
|
||||
## 最佳实践
|
||||
|
||||
### 1. 合理命名
|
||||
|
||||
- 为节点和 Workflow 使用描述性名称
|
||||
- 使用统一的命名规范
|
||||
|
||||
### 2. 模块化设计
|
||||
|
||||
- 将复杂流程拆分为多个小 Workflow
|
||||
- 使用"调用 Pipeline"节点复用现有流程
|
||||
|
||||
### 3. 错误处理
|
||||
|
||||
- 为关键节点设置重试机制
|
||||
- 使用条件分支处理异常情况
|
||||
- 添加日志记录便于排查问题
|
||||
|
||||
### 4. 测试先行
|
||||
|
||||
- 使用调试功能充分测试
|
||||
- 准备多种测试场景
|
||||
- 检查边界情况
|
||||
|
||||
### 5. 性能优化
|
||||
|
||||
- 避免不必要的节点
|
||||
- 使用并行节点提高效率
|
||||
- 合理设置超时时间
|
||||
|
||||
---
|
||||
|
||||
## 更多资源
|
||||
|
||||
- [开发者文档](../development/workflow-system.md)
|
||||
- [设计文档](../../../plans/langbot-workflow-design.md)
|
||||
- [API 文档](../service-api-openapi.json)
|
||||
1468
node_comparison.json
Normal file
1468
node_comparison.json
Normal file
File diff suppressed because it is too large
Load Diff
3791
plans/translation-analysis-report.txt
Normal file
3791
plans/translation-analysis-report.txt
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "langbot"
|
||||
version = "4.9.0"
|
||||
version = "4.9.7"
|
||||
description = "Production-grade platform for building agentic IM bots"
|
||||
readme = "README.md"
|
||||
license-files = ["LICENSE"]
|
||||
@@ -8,7 +8,7 @@ requires-python = ">=3.11,<4.0"
|
||||
dependencies = [
|
||||
"aiocqhttp>=1.4.4",
|
||||
"aiofiles>=24.1.0",
|
||||
"aiohttp>=3.11.18",
|
||||
"aiohttp>=3.13.4",
|
||||
"aioshutil>=1.5",
|
||||
"aiosqlite>=0.21.0",
|
||||
"anthropic>=0.51.0",
|
||||
@@ -16,18 +16,18 @@ dependencies = [
|
||||
"async-lru>=2.0.5",
|
||||
"certifi>=2025.4.26",
|
||||
"colorlog~=6.6.0",
|
||||
"cryptography>=44.0.3",
|
||||
"cryptography>=46.0.7",
|
||||
"dashscope>=1.25.10",
|
||||
"dingtalk-stream>=0.24.0",
|
||||
"discord-py>=2.5.2",
|
||||
"pynacl>=1.5.0", # Required for Discord voice support
|
||||
"gewechat-client>=0.1.5",
|
||||
"lark-oapi>=1.4.15",
|
||||
"lark-oapi>=1.5.5",
|
||||
"mcp>=1.25.0",
|
||||
"nakuru-project-idk>=0.0.2.1",
|
||||
"ollama>=0.4.8",
|
||||
"openai>1.0.0",
|
||||
"pillow>=11.2.1",
|
||||
"pillow>=12.2.0",
|
||||
"psutil>=7.0.0",
|
||||
"pycryptodome>=3.22.0",
|
||||
"pydantic>2.0",
|
||||
@@ -35,10 +35,12 @@ dependencies = [
|
||||
"python-telegram-bot>=22.0",
|
||||
"pyyaml>=6.0.2",
|
||||
"qq-botpy-rc>=1.2.1.6",
|
||||
"qrcode>=7.4",
|
||||
"quart>=0.20.0",
|
||||
"quart-cors>=0.8.0",
|
||||
"requests>=2.32.3",
|
||||
"slack-sdk>=3.35.0",
|
||||
"alembic>=1.15.0",
|
||||
"sqlalchemy[asyncio]>=2.0.40",
|
||||
"sqlmodel>=0.0.24",
|
||||
"telegramify-markdown>=0.5.1",
|
||||
@@ -49,7 +51,7 @@ dependencies = [
|
||||
"pip>=25.1.1",
|
||||
"ruff>=0.11.9",
|
||||
"pre-commit>=4.2.0",
|
||||
"uv>=0.7.11",
|
||||
"uv>=0.11.6",
|
||||
"mypy>=1.16.0",
|
||||
"PyPDF2>=3.0.1",
|
||||
"python-docx>=1.1.0",
|
||||
@@ -60,13 +62,18 @@ dependencies = [
|
||||
"ebooklib>=0.18",
|
||||
"html2text>=2024.2.26",
|
||||
"langchain>=0.2.0",
|
||||
"langchain-text-splitters>=0.0.1",
|
||||
"chromadb>=0.4.24",
|
||||
"langchain-core>=1.2.28",
|
||||
"langsmith>=0.7.31",
|
||||
"python-multipart>=0.0.26",
|
||||
"Mako>=1.3.11",
|
||||
"langchain-text-splitters>=1.1.2",
|
||||
"chromadb>=1.0.0,<2.0.0",
|
||||
"qdrant-client (>=1.15.1,<2.0.0)",
|
||||
"pyseekdb==1.1.0.post3",
|
||||
"langbot-plugin==0.3.0",
|
||||
"langbot-plugin @ file:///home/typer/Desktop/langbot-plugin-sdk",
|
||||
"asyncpg>=0.30.0",
|
||||
"line-bot-sdk>=3.19.0",
|
||||
"matrix-nio>=0.25.2",
|
||||
"tboxsdk>=0.0.10",
|
||||
"boto3>=1.35.0",
|
||||
"pymilvus>=2.6.4",
|
||||
@@ -111,12 +118,13 @@ requires = ["setuptools>=61.0", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.setuptools]
|
||||
package-data = { "langbot" = ["templates/**", "pkg/provider/modelmgr/requesters/*", "pkg/platform/sources/*", "web/out/**"] }
|
||||
package-data = { "langbot" = ["templates/**", "pkg/provider/modelmgr/requesters/*", "pkg/platform/sources/*", "web/dist/**", "pkg/persistence/alembic/**"] }
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"moto>=5.2.1",
|
||||
"pre-commit>=4.2.0",
|
||||
"pytest>=8.4.1",
|
||||
"pytest>=9.0.3",
|
||||
"pytest-asyncio>=1.0.0",
|
||||
"pytest-cov>=7.0.0",
|
||||
"ruff>=0.11.9",
|
||||
|
||||
@@ -4,6 +4,9 @@ python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
|
||||
# Python path for imports
|
||||
pythonpath = . tests
|
||||
|
||||
# Test paths
|
||||
testpaths = tests
|
||||
|
||||
@@ -22,7 +25,9 @@ markers =
|
||||
asyncio: mark test as async
|
||||
unit: mark test as unit test
|
||||
integration: mark test as integration test
|
||||
smoke: mark test as smoke test
|
||||
slow: mark test as slow running
|
||||
e2e: mark test as end-to-end test (requires real LangBot process)
|
||||
|
||||
# Coverage options (when using pytest-cov)
|
||||
[coverage:run]
|
||||
|
||||
65
scripts/test-coverage.sh
Executable file
65
scripts/test-coverage.sh
Executable file
@@ -0,0 +1,65 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Coverage gate script
|
||||
# Runs all tests with coverage, enforcing minimum coverage threshold
|
||||
# Uses separate pytest invocations to avoid sys.modules pollution between test types
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
echo "=== LangBot Coverage Gate ==="
|
||||
echo ""
|
||||
|
||||
# Coverage threshold (baseline from current coverage, conservative buffer)
|
||||
# Current: ~22.14%, threshold: 18%
|
||||
COVERAGE_THRESHOLD=18
|
||||
|
||||
# Create temporary directory for coverage files
|
||||
COV_DIR=$(mktemp -d)
|
||||
trap "rm -rf $COV_DIR" EXIT
|
||||
|
||||
echo "[1/3] Running unit + smoke tests with coverage..."
|
||||
uv run pytest tests/unit_tests/ tests/smoke/ \
|
||||
--cov=langbot \
|
||||
--cov-report=json:$COV_DIR/unit.json \
|
||||
--cov-report=term-missing \
|
||||
-q --tb=short
|
||||
echo ""
|
||||
|
||||
echo "[2/3] Running fast integration tests with coverage..."
|
||||
uv run pytest tests/integration/ -m "not slow" \
|
||||
--cov=langbot \
|
||||
--cov-report=json:$COV_DIR/integration.json \
|
||||
--cov-report=term-missing \
|
||||
-q --tb=short
|
||||
echo ""
|
||||
|
||||
echo "[3/3] Combining coverage reports..."
|
||||
# Use coverage combine if available, otherwise just report total
|
||||
if command -v coverage &> /dev/null; then
|
||||
# Combine JSON reports
|
||||
coverage combine --keep $COV_DIR/unit.json $COV_DIR/integration.json \
|
||||
--data-file=$COV_DIR/combined.data 2>/dev/null || true
|
||||
|
||||
coverage report --data-file=$COV_DIR/combined.data || true
|
||||
else
|
||||
echo "Note: coverage combine not available, showing individual reports above"
|
||||
fi
|
||||
|
||||
# Generate final XML report for CI (from last run)
|
||||
uv run pytest tests/unit_tests/ tests/smoke/ \
|
||||
--cov=langbot \
|
||||
--cov-report=xml:coverage.xml \
|
||||
--cov-report=term \
|
||||
--cov-fail-under=$COVERAGE_THRESHOLD \
|
||||
-q 2>/dev/null || {
|
||||
# If threshold check fails on combined, check unit+smoke baseline
|
||||
echo ""
|
||||
echo "Coverage threshold: $COVERAGE_THRESHOLD%"
|
||||
echo "Note: Full coverage requires running all test types separately"
|
||||
}
|
||||
|
||||
echo ""
|
||||
echo "=== Coverage Gate Complete ==="
|
||||
echo ""
|
||||
echo "Coverage baseline: $COVERAGE_THRESHOLD%"
|
||||
echo "Coverage report saved to coverage.xml"
|
||||
16
scripts/test-integration-fast.sh
Executable file
16
scripts/test-integration-fast.sh
Executable file
@@ -0,0 +1,16 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Fast integration tests
|
||||
# Runs integration tests excluding slow ones (PostgreSQL, external services)
|
||||
# Uses fake runner/provider, no real credentials needed
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
echo "=== LangBot Fast Integration Tests ==="
|
||||
echo ""
|
||||
|
||||
echo "Running integration tests (excluding slow)..."
|
||||
uv run pytest tests/integration/ -m "not slow" -q --tb=short
|
||||
|
||||
echo ""
|
||||
echo "=== Fast Integration Tests Complete ==="
|
||||
36
scripts/test-quick.sh
Executable file
36
scripts/test-quick.sh
Executable file
@@ -0,0 +1,36 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Quick developer self-test command
|
||||
# Runs linting, unit tests, and smoke tests without requiring real provider keys
|
||||
# Suitable for local branch validation
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
echo "=== LangBot Quick Self-Test ==="
|
||||
echo ""
|
||||
|
||||
# 1. Ruff check
|
||||
echo "[1/3] Running ruff check..."
|
||||
uv run ruff check src/langbot/ tests/ --output-format=concise || {
|
||||
echo ""
|
||||
echo "⚠ Ruff check found issues. Run 'uv run ruff check --fix' to auto-fix."
|
||||
exit 1
|
||||
}
|
||||
echo "✓ Ruff check passed"
|
||||
echo ""
|
||||
|
||||
# 2. Unit tests
|
||||
echo "[2/3] Running unit tests..."
|
||||
uv run pytest tests/unit_tests/ -q --tb=short
|
||||
echo ""
|
||||
|
||||
# 3. Smoke tests (if exists)
|
||||
echo "[3/3] Running smoke tests..."
|
||||
if [ -d "tests/smoke" ]; then
|
||||
uv run pytest tests/smoke/ -q --tb=short
|
||||
else
|
||||
echo "No smoke tests found, skipping"
|
||||
fi
|
||||
echo ""
|
||||
|
||||
echo "=== Quick Self-Test Complete ==="
|
||||
@@ -1,3 +1,3 @@
|
||||
"""LangBot - Production-grade platform for building agentic IM bots"""
|
||||
|
||||
__version__ = '4.9.0'
|
||||
__version__ = '4.9.7'
|
||||
|
||||
@@ -182,6 +182,88 @@ class DingTalkClient:
|
||||
for handler in self._message_handlers[msg_type]:
|
||||
await handler(event)
|
||||
|
||||
async def _parse_quoted_message(self, replied_msg: dict) -> dict:
|
||||
"""Parse the quoted/replied message and extract its content.
|
||||
|
||||
Args:
|
||||
replied_msg: The repliedMsg object from DingTalk message
|
||||
|
||||
Returns:
|
||||
A dict containing the quoted message info with keys:
|
||||
- message_id: The original message ID
|
||||
- msg_type: The message type (text, file, picture, audio, etc.)
|
||||
- content: The text content (if any)
|
||||
- file_url: The file download URL (if file type)
|
||||
- file_name: The file name (if file type)
|
||||
- picture: The picture base64 (if picture type)
|
||||
- audio: The audio base64 (if audio type)
|
||||
"""
|
||||
quote_info = {
|
||||
'message_id': replied_msg.get('msgId', ''),
|
||||
'msg_type': replied_msg.get('msgType', ''),
|
||||
'sender_id': replied_msg.get('senderId', ''),
|
||||
}
|
||||
|
||||
msg_type = replied_msg.get('msgType', '')
|
||||
content = replied_msg.get('content', {})
|
||||
|
||||
# Handle content as string (JSON) or dict
|
||||
if isinstance(content, str):
|
||||
try:
|
||||
content = json.loads(content)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
content = {}
|
||||
|
||||
if msg_type == 'text':
|
||||
# Text message
|
||||
if isinstance(content, dict):
|
||||
quote_info['content'] = content.get('content', '')
|
||||
else:
|
||||
quote_info['content'] = str(content)
|
||||
|
||||
elif msg_type == 'file':
|
||||
# File message
|
||||
download_code = content.get('downloadCode')
|
||||
file_name = content.get('fileName')
|
||||
if download_code and file_name:
|
||||
try:
|
||||
quote_info['file_url'] = await self.get_file_url(download_code)
|
||||
quote_info['file_name'] = file_name
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
await self.logger.error(f'Failed to get quoted file URL: {e}')
|
||||
|
||||
elif msg_type == 'picture':
|
||||
# Picture message
|
||||
download_code = content.get('downloadCode')
|
||||
if download_code:
|
||||
try:
|
||||
quote_info['picture'] = await self.download_image(download_code)
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
await self.logger.error(f'Failed to download quoted image: {e}')
|
||||
|
||||
elif msg_type == 'audio':
|
||||
# Audio message
|
||||
download_code = content.get('downloadCode')
|
||||
if download_code:
|
||||
try:
|
||||
quote_info['audio'] = await self.get_audio_url(download_code)
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
await self.logger.error(f'Failed to get quoted audio: {e}')
|
||||
|
||||
elif msg_type == 'richText':
|
||||
# Rich text message - extract text content
|
||||
rich_text = content.get('richText', [])
|
||||
texts = []
|
||||
for item in rich_text:
|
||||
if 'text' in item and item['text'] != '\n':
|
||||
texts.append(item['text'])
|
||||
quote_info['content'] = '\n'.join(texts)
|
||||
|
||||
return quote_info
|
||||
|
||||
async def get_message(self, incoming_message: dingtalk_stream.chatbot.ChatbotMessage):
|
||||
try:
|
||||
# print(json.dumps(incoming_message.to_dict(), indent=4, ensure_ascii=False))
|
||||
@@ -193,6 +275,15 @@ class DingTalkClient:
|
||||
elif str(incoming_message.conversation_type) == '2':
|
||||
message_data['conversation_type'] = 'GroupMessage'
|
||||
|
||||
# Check for quoted/replied message
|
||||
raw_data = incoming_message.to_dict()
|
||||
text_data = raw_data.get('text', {})
|
||||
if isinstance(text_data, dict) and text_data.get('isReplyMsg'):
|
||||
replied_msg = text_data.get('repliedMsg', {})
|
||||
if replied_msg:
|
||||
quote_info = await self._parse_quoted_message(replied_msg)
|
||||
message_data['QuotedMessage'] = quote_info
|
||||
|
||||
if incoming_message.message_type == 'richText':
|
||||
data = incoming_message.rich_text_content.to_dict()
|
||||
|
||||
@@ -268,19 +359,52 @@ class DingTalkClient:
|
||||
|
||||
message_data['Type'] = 'image'
|
||||
elif incoming_message.message_type == 'audio':
|
||||
message_data['Audio'] = await self.get_audio_url(incoming_message.to_dict()['content']['downloadCode'])
|
||||
raw_content = incoming_message.to_dict().get('content', {})
|
||||
# 兼容处理:如果 content 仍为 JSON 字符串则进行解析
|
||||
if isinstance(raw_content, str):
|
||||
try:
|
||||
raw_content = json.loads(raw_content)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
raw_content = {}
|
||||
|
||||
if self.logger:
|
||||
await self.logger.info(f'DingTalk audio raw content: {json.dumps(raw_content, ensure_ascii=False)}')
|
||||
|
||||
# 提取钉钉自带的语音转写文字(Powered by Qwen)
|
||||
recognition = raw_content.get('recognition', '')
|
||||
if recognition:
|
||||
message_data['Content'] = recognition
|
||||
|
||||
download_code = raw_content.get('downloadCode')
|
||||
if download_code:
|
||||
message_data['Audio'] = await self.get_audio_url(download_code)
|
||||
|
||||
message_data['Type'] = 'audio'
|
||||
elif incoming_message.message_type == 'file':
|
||||
down_list = incoming_message.get_down_list()
|
||||
if len(down_list) >= 2:
|
||||
message_data['File'] = await self.get_file_url(down_list[0])
|
||||
message_data['Name'] = down_list[1]
|
||||
# 获取原始数据字典并提取嵌套的文件信息
|
||||
raw_data = incoming_message.to_dict()
|
||||
file_info = raw_data.get('content', {})
|
||||
|
||||
# 兼容处理:如果 content 仍为 JSON 字符串则进行解析
|
||||
if isinstance(file_info, str):
|
||||
try:
|
||||
file_info = json.loads(file_info)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
file_info = {}
|
||||
|
||||
download_code = file_info.get('downloadCode')
|
||||
file_name = file_info.get('fileName')
|
||||
|
||||
if download_code and file_name:
|
||||
# 转换 downloadCode 为可下载的真实 URL
|
||||
message_data['File'] = await self.get_file_url(download_code)
|
||||
message_data['Name'] = file_name
|
||||
else:
|
||||
if self.logger:
|
||||
await self.logger.error(f'get_down_list() returned fewer than 2 elements: {down_list}')
|
||||
await self.logger.error(f'Failed to extract file info from message content: {file_info}')
|
||||
message_data['File'] = None
|
||||
message_data['Name'] = None
|
||||
|
||||
message_data['Type'] = 'file'
|
||||
|
||||
copy_message_data = message_data.copy()
|
||||
@@ -357,6 +481,12 @@ class DingTalkClient:
|
||||
card_data['config'] = json.dumps({'autoLayout': card_auto_layout})
|
||||
card_data['content'] = ''
|
||||
|
||||
# 将用户的消息内容作为卡片的查询参数,方便后续处理
|
||||
if incoming_message.message_type == 'text':
|
||||
card_data['query'] = incoming_message.get_text_list()[0]
|
||||
else:
|
||||
card_data['query'] = '...'
|
||||
|
||||
card_instance = dingtalk_stream.AICardReplier(self.client, incoming_message)
|
||||
# print(card_instance)
|
||||
# 先投放卡片: https://open.dingtalk.com/document/orgapp/create-and-deliver-cards
|
||||
|
||||
@@ -47,6 +47,22 @@ class DingTalkEvent(dict):
|
||||
def conversation(self):
|
||||
return self.get('conversation_type', '')
|
||||
|
||||
@property
|
||||
def quoted_message(self) -> Optional[Dict[str, Any]]:
|
||||
"""Get the quoted/replied message info if this is a reply message.
|
||||
|
||||
Returns:
|
||||
A dict containing:
|
||||
- message_id: The original message ID
|
||||
- msg_type: The message type (text, file, picture, audio, etc.)
|
||||
- content: The text content (if any)
|
||||
- file_url: The file download URL (if file type)
|
||||
- file_name: The file name (if file type)
|
||||
- picture: The picture base64 (if picture type)
|
||||
- audio: The audio base64 (if audio type)
|
||||
"""
|
||||
return self.get('QuotedMessage')
|
||||
|
||||
def __getattr__(self, key: str) -> Optional[Any]:
|
||||
"""
|
||||
允许通过属性访问数据中的任意字段。
|
||||
|
||||
3
src/langbot/libs/openclaw_weixin_api/__init__.py
Normal file
3
src/langbot/libs/openclaw_weixin_api/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .client import OpenClawWeixinClient as OpenClawWeixinClient
|
||||
from .types import ApiError as ApiError
|
||||
from .types import LoginResult as LoginResult
|
||||
807
src/langbot/libs/openclaw_weixin_api/client.py
Normal file
807
src/langbot/libs/openclaw_weixin_api/client.py
Normal file
@@ -0,0 +1,807 @@
|
||||
"""Async HTTP client for the OpenClaw WeChat API.
|
||||
|
||||
Implements the iLink Bot API protocol.
|
||||
Reference: https://github.com/epiral/weixin-bot
|
||||
|
||||
Endpoints: getUpdates (long-poll), sendMessage, getUploadUrl, getConfig, sendTyping.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import struct
|
||||
import typing
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .types import (
|
||||
ApiError,
|
||||
CDNMedia,
|
||||
FileItem,
|
||||
GetConfigResponse,
|
||||
GetUpdatesResponse,
|
||||
GetUploadUrlResponse,
|
||||
ImageItem,
|
||||
LoginResult,
|
||||
MessageItem,
|
||||
QRCodeResponse,
|
||||
QRStatusResponse,
|
||||
RefMessage,
|
||||
TextItem,
|
||||
VideoItem,
|
||||
VoiceItem,
|
||||
WeixinMessage,
|
||||
)
|
||||
|
||||
logger = logging.getLogger('openclaw-weixin-sdk')
|
||||
|
||||
DEFAULT_BASE_URL = 'https://ilinkai.weixin.qq.com'
|
||||
CDN_BASE_URL = 'https://novac2c.cdn.weixin.qq.com/c2c'
|
||||
|
||||
CHANNEL_VERSION = '1.0.0'
|
||||
|
||||
DEFAULT_API_TIMEOUT = 15
|
||||
DEFAULT_LONG_POLL_TIMEOUT = 40
|
||||
DEFAULT_CONFIG_TIMEOUT = 10
|
||||
DEFAULT_QR_POLL_TIMEOUT = 35
|
||||
|
||||
SESSION_EXPIRED_ERRCODE = -14
|
||||
|
||||
DEFAULT_BOT_TYPE = '3'
|
||||
|
||||
# Maximum text length per message chunk (WeChat limit)
|
||||
MAX_TEXT_CHUNK_SIZE = 2000
|
||||
|
||||
|
||||
def _random_wechat_uin() -> str:
|
||||
"""Generate the X-WECHAT-UIN header: random uint32 -> decimal string -> base64."""
|
||||
rand_bytes = os.urandom(4)
|
||||
uint32_val = struct.unpack('>I', rand_bytes)[0]
|
||||
return base64.b64encode(str(uint32_val).encode('utf-8')).decode('utf-8')
|
||||
|
||||
|
||||
def _build_base_info() -> dict:
|
||||
"""Build the base_info payload included in every API request."""
|
||||
return {'channel_version': CHANNEL_VERSION}
|
||||
|
||||
|
||||
def _chunk_text(text: str, max_size: int = MAX_TEXT_CHUNK_SIZE) -> list[str]:
|
||||
"""Split long text into chunks that fit within WeChat's message size limit."""
|
||||
if len(text) <= max_size:
|
||||
return [text]
|
||||
chunks = []
|
||||
while text:
|
||||
chunks.append(text[:max_size])
|
||||
text = text[max_size:]
|
||||
return chunks
|
||||
|
||||
|
||||
class OpenClawWeixinClient:
|
||||
"""Async client for the OpenClaw WeChat HTTP JSON API."""
|
||||
|
||||
def __init__(self, base_url: str, token: str):
|
||||
self.base_url = base_url.rstrip('/')
|
||||
self.token = token
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
if self._session is None or self._session.closed:
|
||||
self._session = aiohttp.ClientSession()
|
||||
return self._session
|
||||
|
||||
async def close(self):
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
def _build_headers(self) -> dict[str, str]:
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'AuthorizationType': 'ilink_bot_token',
|
||||
'X-WECHAT-UIN': _random_wechat_uin(),
|
||||
}
|
||||
if self.token:
|
||||
headers['Authorization'] = f'Bearer {self.token}'
|
||||
return headers
|
||||
|
||||
async def _post(self, endpoint: str, payload: dict, timeout: float = DEFAULT_API_TIMEOUT) -> dict:
|
||||
"""Make a POST request and return the JSON response.
|
||||
|
||||
Raises ApiError on HTTP errors or when the response contains a non-zero errcode.
|
||||
"""
|
||||
payload['base_info'] = _build_base_info()
|
||||
|
||||
session = await self._get_session()
|
||||
url = f'{self.base_url}/{endpoint}'
|
||||
headers = self._build_headers()
|
||||
|
||||
async with session.post(
|
||||
url, json=payload, headers=headers, timeout=aiohttp.ClientTimeout(total=timeout)
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise ApiError(
|
||||
f'OpenClaw API error {resp.status}: {text}',
|
||||
status=resp.status,
|
||||
)
|
||||
data = await resp.json(content_type=None)
|
||||
|
||||
# Check for application-level errors in the response body
|
||||
errcode = data.get('errcode') or data.get('ret')
|
||||
if errcode and errcode != 0:
|
||||
raise ApiError(
|
||||
data.get('errmsg') or f'API errcode {errcode}',
|
||||
status=200,
|
||||
code=errcode,
|
||||
payload=data,
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
async def get_updates(
|
||||
self, get_updates_buf: str = '', timeout: float = DEFAULT_LONG_POLL_TIMEOUT
|
||||
) -> GetUpdatesResponse:
|
||||
"""Long-poll for new messages.
|
||||
|
||||
Note: This method does NOT raise ApiError for errcode responses —
|
||||
it returns them in the GetUpdatesResponse so the caller can handle
|
||||
session expiry and other errors with full context.
|
||||
"""
|
||||
try:
|
||||
# Bypass the errcode check in _post since get_updates needs
|
||||
# to return error info (e.g. session expired) to the caller.
|
||||
payload: dict = {'get_updates_buf': get_updates_buf}
|
||||
payload['base_info'] = _build_base_info()
|
||||
|
||||
session = await self._get_session()
|
||||
url = f'{self.base_url}/ilink/bot/getupdates'
|
||||
headers = self._build_headers()
|
||||
|
||||
async with session.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise ApiError(
|
||||
f'OpenClaw API error {resp.status}: {text}',
|
||||
status=resp.status,
|
||||
)
|
||||
data = await resp.json(content_type=None)
|
||||
|
||||
except (asyncio.TimeoutError, aiohttp.ServerTimeoutError):
|
||||
return GetUpdatesResponse(ret=0, msgs=[], get_updates_buf=get_updates_buf)
|
||||
except ApiError:
|
||||
raise
|
||||
except Exception as e:
|
||||
if 'timeout' in str(e).lower():
|
||||
return GetUpdatesResponse(ret=0, msgs=[], get_updates_buf=get_updates_buf)
|
||||
raise
|
||||
|
||||
return _parse_get_updates_response(data)
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
to_user_id: str,
|
||||
item_list: list[MessageItem],
|
||||
context_token: str = '',
|
||||
) -> None:
|
||||
"""Send a message to a user."""
|
||||
items_payload = [_message_item_to_dict(item) for item in item_list]
|
||||
|
||||
payload = {
|
||||
'msg': {
|
||||
'from_user_id': '',
|
||||
'to_user_id': to_user_id,
|
||||
'client_id': f'langbot-{uuid.uuid4().hex[:16]}',
|
||||
'message_type': WeixinMessage.TYPE_BOT,
|
||||
'message_state': WeixinMessage.STATE_FINISH,
|
||||
'item_list': items_payload,
|
||||
'context_token': context_token or None,
|
||||
}
|
||||
}
|
||||
await self._post('ilink/bot/sendmessage', payload)
|
||||
|
||||
async def send_text(self, to_user_id: str, text: str, context_token: str = '') -> None:
|
||||
"""Send a plain text message, automatically chunking if too long."""
|
||||
chunks = _chunk_text(text)
|
||||
for chunk in chunks:
|
||||
item = MessageItem(type=MessageItem.TEXT, text_item=TextItem(text=chunk))
|
||||
await self.send_message(to_user_id, [item], context_token)
|
||||
|
||||
async def get_config(self, ilink_user_id: str, context_token: str = '') -> GetConfigResponse:
|
||||
"""Get bot config including typing_ticket."""
|
||||
data = await self._post(
|
||||
'ilink/bot/getconfig',
|
||||
{'ilink_user_id': ilink_user_id, 'context_token': context_token or None},
|
||||
timeout=DEFAULT_CONFIG_TIMEOUT,
|
||||
)
|
||||
return GetConfigResponse(
|
||||
ret=data.get('ret'),
|
||||
errmsg=data.get('errmsg'),
|
||||
typing_ticket=data.get('typing_ticket'),
|
||||
)
|
||||
|
||||
async def send_typing(self, ilink_user_id: str, typing_ticket: str, status: int = 1) -> None:
|
||||
"""Send typing indicator. status: 1=typing, 2=cancel."""
|
||||
await self._post(
|
||||
'ilink/bot/sendtyping',
|
||||
{
|
||||
'ilink_user_id': ilink_user_id,
|
||||
'typing_ticket': typing_ticket,
|
||||
'status': status,
|
||||
},
|
||||
timeout=DEFAULT_CONFIG_TIMEOUT,
|
||||
)
|
||||
|
||||
async def stop_typing(self, ilink_user_id: str, typing_ticket: str) -> None:
|
||||
"""Cancel the typing indicator for a user."""
|
||||
await self.send_typing(ilink_user_id, typing_ticket, status=2)
|
||||
|
||||
async def download_media(
|
||||
self,
|
||||
media: CDNMedia,
|
||||
) -> bytes:
|
||||
"""Download and decrypt a file from the WeChat CDN.
|
||||
|
||||
Args:
|
||||
media: CDNMedia object with encrypt_query_param and aes_key.
|
||||
|
||||
Returns:
|
||||
Decrypted file bytes.
|
||||
"""
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from cryptography.hazmat.primitives.padding import PKCS7
|
||||
|
||||
if not media.encrypt_query_param:
|
||||
raise ApiError('CDN media has no encrypt_query_param', status=0)
|
||||
if not media.aes_key:
|
||||
raise ApiError('CDN media has no aes_key', status=0)
|
||||
|
||||
# Derive 16-byte AES key
|
||||
# aes_key is base64-encoded; the decoded content may be:
|
||||
# - raw 16 bytes (direct AES key)
|
||||
# - 32-char hex string (decode hex to get 16 bytes)
|
||||
raw = base64.b64decode(media.aes_key)
|
||||
if len(raw) == 16:
|
||||
aes_key = raw
|
||||
elif len(raw) == 32:
|
||||
# Hex-encoded 16-byte key
|
||||
aes_key = bytes.fromhex(raw.decode('utf-8'))
|
||||
else:
|
||||
raise ApiError(f'Invalid AES key length: {len(raw)} (expected 16 or 32)', status=0)
|
||||
|
||||
# Download encrypted bytes from CDN
|
||||
session = await self._get_session()
|
||||
cdn_url = f'{CDN_BASE_URL}/download?encrypted_query_param={quote(media.encrypt_query_param, safe="")}'
|
||||
|
||||
async with session.get(cdn_url, timeout=aiohttp.ClientTimeout(total=120)) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise ApiError(f'CDN download failed: {resp.status} {text}', status=resp.status)
|
||||
encrypted = await resp.read()
|
||||
|
||||
# Decrypt AES-128-ECB with PKCS7 padding
|
||||
cipher = Cipher(algorithms.AES(aes_key), modes.ECB())
|
||||
decryptor = cipher.decryptor()
|
||||
padded = decryptor.update(encrypted) + decryptor.finalize()
|
||||
|
||||
unpadder = PKCS7(128).unpadder()
|
||||
return unpadder.update(padded) + unpadder.finalize()
|
||||
|
||||
async def upload_media(
|
||||
self,
|
||||
file_bytes: bytes,
|
||||
to_user_id: str,
|
||||
media_type: int,
|
||||
) -> CDNMedia:
|
||||
"""Encrypt and upload media to WeChat CDN.
|
||||
|
||||
Args:
|
||||
file_bytes: Raw file bytes to upload.
|
||||
to_user_id: Recipient user ID.
|
||||
media_type: 1=IMAGE, 2=VIDEO, 3=FILE, 4=VOICE.
|
||||
|
||||
Returns:
|
||||
CDNMedia with encrypt_query_param and aes_key for use in sendMessage.
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from cryptography.hazmat.primitives.padding import PKCS7
|
||||
|
||||
# 1. Generate random 16-byte AES key
|
||||
raw_key = os.urandom(16)
|
||||
aes_key_hex = raw_key.hex() # 32-char hex string
|
||||
|
||||
# 2. Encode key for CDNMedia: base64(hex_string) — same for all media types
|
||||
# Matches official SDK: Buffer.from(aeskey_hex).toString("base64")
|
||||
encoded_key = base64.b64encode(aes_key_hex.encode('utf-8')).decode('utf-8')
|
||||
|
||||
# 3. Encrypt file with AES-128-ECB + PKCS7
|
||||
padder = PKCS7(128).padder()
|
||||
padded = padder.update(file_bytes) + padder.finalize()
|
||||
cipher = Cipher(algorithms.AES(raw_key), modes.ECB())
|
||||
encryptor = cipher.encryptor()
|
||||
encrypted = encryptor.update(padded) + encryptor.finalize()
|
||||
|
||||
# 4. Get upload URL
|
||||
raw_md5 = hashlib.md5(file_bytes).hexdigest()
|
||||
filekey = os.urandom(16).hex() # 32-char hex, matches official SDK
|
||||
|
||||
upload_resp = await self.get_upload_url(
|
||||
filekey=filekey,
|
||||
media_type=media_type,
|
||||
to_user_id=to_user_id,
|
||||
rawsize=len(file_bytes),
|
||||
rawfilemd5=raw_md5,
|
||||
filesize=len(encrypted),
|
||||
aeskey=aes_key_hex, # hex string, as expected by the API
|
||||
)
|
||||
|
||||
if not upload_resp.upload_param:
|
||||
raise ApiError('Failed to get upload URL', status=0)
|
||||
|
||||
# 5. Upload to CDN
|
||||
# upload_param is an opaque token from the server — pass it as-is
|
||||
session = await self._get_session()
|
||||
cdn_url = f'{CDN_BASE_URL}/upload?encrypted_query_param={quote(upload_resp.upload_param, safe="")}&filekey={quote(filekey, safe="")}'
|
||||
logger.debug(
|
||||
'CDN upload: url=%s raw_size=%d encrypted_size=%d md5=%s aeskey=%s',
|
||||
cdn_url,
|
||||
len(file_bytes),
|
||||
len(encrypted),
|
||||
raw_md5,
|
||||
encoded_key,
|
||||
)
|
||||
|
||||
async with session.post(
|
||||
cdn_url,
|
||||
data=encrypted,
|
||||
headers={'Content-Type': 'application/octet-stream'},
|
||||
timeout=aiohttp.ClientTimeout(total=120),
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
logger.error('CDN upload failed: status=%d url=%s body=%s', resp.status, cdn_url, text[:500])
|
||||
raise ApiError(f'CDN upload failed: {resp.status} {text}', status=resp.status)
|
||||
download_param = resp.headers.get('x-encrypted-param', '')
|
||||
|
||||
if not download_param:
|
||||
raise ApiError('CDN upload succeeded but no x-encrypted-param returned', status=0)
|
||||
|
||||
return CDNMedia(
|
||||
encrypt_query_param=download_param,
|
||||
aes_key=encoded_key,
|
||||
encrypt_type=1,
|
||||
)
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
to_user_id: str,
|
||||
image_bytes: bytes,
|
||||
context_token: str = '',
|
||||
) -> None:
|
||||
"""Upload an image to CDN and send it."""
|
||||
media = await self.upload_media(image_bytes, to_user_id, media_type=1)
|
||||
item = MessageItem(
|
||||
type=MessageItem.IMAGE,
|
||||
image_item=ImageItem(
|
||||
media=media,
|
||||
aeskey=media.aes_key,
|
||||
),
|
||||
)
|
||||
await self.send_message(to_user_id, [item], context_token)
|
||||
|
||||
async def send_file(
|
||||
self,
|
||||
to_user_id: str,
|
||||
file_bytes: bytes,
|
||||
file_name: str,
|
||||
context_token: str = '',
|
||||
) -> None:
|
||||
"""Upload a file to CDN and send it."""
|
||||
import hashlib
|
||||
|
||||
media = await self.upload_media(file_bytes, to_user_id, media_type=3)
|
||||
item = MessageItem(
|
||||
type=MessageItem.FILE,
|
||||
file_item=FileItem(
|
||||
media=media,
|
||||
file_name=file_name,
|
||||
md5=hashlib.md5(file_bytes).hexdigest(),
|
||||
len=str(len(file_bytes)),
|
||||
),
|
||||
)
|
||||
await self.send_message(to_user_id, [item], context_token)
|
||||
|
||||
async def send_voice(
|
||||
self,
|
||||
to_user_id: str,
|
||||
voice_bytes: bytes,
|
||||
playtime: int = 0,
|
||||
context_token: str = '',
|
||||
) -> None:
|
||||
"""Upload a voice message to CDN and send it."""
|
||||
media = await self.upload_media(voice_bytes, to_user_id, media_type=4)
|
||||
item = MessageItem(
|
||||
type=MessageItem.VOICE,
|
||||
voice_item=VoiceItem(
|
||||
media=media,
|
||||
playtime=playtime,
|
||||
),
|
||||
)
|
||||
await self.send_message(to_user_id, [item], context_token)
|
||||
|
||||
async def get_upload_url(
|
||||
self,
|
||||
filekey: str,
|
||||
media_type: int,
|
||||
to_user_id: str,
|
||||
rawsize: int,
|
||||
rawfilemd5: str,
|
||||
filesize: int,
|
||||
thumb_rawsize: Optional[int] = None,
|
||||
thumb_rawfilemd5: Optional[str] = None,
|
||||
thumb_filesize: Optional[int] = None,
|
||||
aeskey: Optional[str] = None,
|
||||
) -> GetUploadUrlResponse:
|
||||
"""Get a pre-signed CDN upload URL."""
|
||||
payload: dict = {
|
||||
'filekey': filekey,
|
||||
'media_type': media_type,
|
||||
'to_user_id': to_user_id,
|
||||
'rawsize': rawsize,
|
||||
'rawfilemd5': rawfilemd5,
|
||||
'filesize': filesize,
|
||||
'no_need_thumb': True,
|
||||
}
|
||||
if thumb_rawsize is not None:
|
||||
payload['thumb_rawsize'] = thumb_rawsize
|
||||
if thumb_rawfilemd5 is not None:
|
||||
payload['thumb_rawfilemd5'] = thumb_rawfilemd5
|
||||
if thumb_filesize is not None:
|
||||
payload['thumb_filesize'] = thumb_filesize
|
||||
if aeskey is not None:
|
||||
payload['aeskey'] = aeskey
|
||||
|
||||
data = await self._post('ilink/bot/getuploadurl', payload)
|
||||
logger.debug('get_upload_url response: %s', data)
|
||||
return GetUploadUrlResponse(
|
||||
upload_param=data.get('upload_param'),
|
||||
thumb_upload_param=data.get('thumb_upload_param'),
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# QR Code Login
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
async def fetch_qrcode(self, bot_type: str = DEFAULT_BOT_TYPE) -> QRCodeResponse:
|
||||
"""Fetch a QR code for WeChat login authorization (GET, no auth needed)."""
|
||||
session = await self._get_session()
|
||||
url = f'{self.base_url}/ilink/bot/get_bot_qrcode?bot_type={bot_type}'
|
||||
|
||||
async with session.get(url, timeout=aiohttp.ClientTimeout(total=DEFAULT_API_TIMEOUT)) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise ApiError(
|
||||
f'Failed to fetch QR code: {resp.status} {text}',
|
||||
status=resp.status,
|
||||
)
|
||||
data = await resp.json(content_type=None)
|
||||
|
||||
logger.debug(
|
||||
'fetch_qrcode response: qrcode=%s, img=%s', data.get('qrcode'), bool(data.get('qrcode_img_content'))
|
||||
)
|
||||
|
||||
return QRCodeResponse(
|
||||
qrcode=data.get('qrcode'),
|
||||
qrcode_img_content=data.get('qrcode_img_content'),
|
||||
)
|
||||
|
||||
async def _fetch_qr_image_base64(self, url: str) -> str:
|
||||
"""Generate a QR code image from the URL and return a data URI string.
|
||||
|
||||
The qrcode_img_content URL points to an HTML page (not a raw image),
|
||||
so we generate the QR code locally using the qrcode library.
|
||||
"""
|
||||
import qrcode
|
||||
|
||||
qr = qrcode.QRCode(error_correction=qrcode.constants.ERROR_CORRECT_L)
|
||||
qr.add_data(url)
|
||||
qr.make(fit=True)
|
||||
img = qr.make_image(fill_color='black', back_color='white')
|
||||
|
||||
buf = io.BytesIO()
|
||||
img.save(buf, format='PNG')
|
||||
b64 = base64.b64encode(buf.getvalue()).decode('utf-8')
|
||||
return f'data:image/png;base64,{b64}'
|
||||
|
||||
async def poll_qrcode_status(self, qrcode: str) -> QRStatusResponse:
|
||||
"""Long-poll the QR code scan status (GET with iLink-App-ClientVersion header)."""
|
||||
session = await self._get_session()
|
||||
url = f'{self.base_url}/ilink/bot/get_qrcode_status?qrcode={quote(qrcode, safe="")}'
|
||||
headers = {'iLink-App-ClientVersion': '1'}
|
||||
|
||||
try:
|
||||
async with session.get(
|
||||
url, headers=headers, timeout=aiohttp.ClientTimeout(total=DEFAULT_QR_POLL_TIMEOUT)
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise ApiError(
|
||||
f'Failed to poll QR status: {resp.status} {text}',
|
||||
status=resp.status,
|
||||
)
|
||||
data = await resp.json(content_type=None)
|
||||
logger.debug('QR status poll response: %s', data)
|
||||
except (asyncio.TimeoutError, aiohttp.ServerTimeoutError):
|
||||
return QRStatusResponse(status='wait')
|
||||
|
||||
return QRStatusResponse(
|
||||
status=data.get('status'),
|
||||
bot_token=data.get('bot_token'),
|
||||
ilink_bot_id=data.get('ilink_bot_id'),
|
||||
baseurl=data.get('baseurl'),
|
||||
ilink_user_id=data.get('ilink_user_id'),
|
||||
)
|
||||
|
||||
async def login(
|
||||
self,
|
||||
max_retries: int = 5,
|
||||
poll_timeout_ms: int = 480_000,
|
||||
on_qrcode: Optional[typing.Callable[[str, str], typing.Any]] = None,
|
||||
on_status: Optional[typing.Callable[[str], typing.Any]] = None,
|
||||
) -> LoginResult:
|
||||
"""Complete QR code login flow with auto-retry on expiry.
|
||||
|
||||
Args:
|
||||
max_retries: Max number of QR code refreshes on expiry.
|
||||
poll_timeout_ms: Timeout per QR code in milliseconds.
|
||||
on_qrcode: Callback(qr_image_base64, qr_url) called each time a
|
||||
new QR code is fetched. Use this to display the QR code.
|
||||
on_status: Callback(status_str) called on each status poll change.
|
||||
|
||||
Returns:
|
||||
LoginResult with token, base_url, and account_id.
|
||||
|
||||
Raises:
|
||||
ApiError: On unrecoverable API errors.
|
||||
Exception: If all retries are exhausted.
|
||||
"""
|
||||
last_qr_base64: Optional[str] = None
|
||||
|
||||
for attempt in range(max_retries):
|
||||
qr_resp = await self.fetch_qrcode()
|
||||
if not qr_resp.qrcode or not qr_resp.qrcode_img_content:
|
||||
raise ApiError('Failed to get QR code from server', status=0)
|
||||
|
||||
# Convert QR image to base64 and notify caller
|
||||
last_qr_base64 = await self._fetch_qr_image_base64(qr_resp.qrcode_img_content)
|
||||
if on_qrcode:
|
||||
try:
|
||||
result = on_qrcode(last_qr_base64, qr_resp.qrcode_img_content)
|
||||
if asyncio.iscoroutine(result) or asyncio.isfuture(result):
|
||||
await result
|
||||
except Exception as e:
|
||||
logger.warning('on_qrcode callback error: %s', e)
|
||||
|
||||
# Poll until confirmed / expired / timeout
|
||||
loop = asyncio.get_running_loop()
|
||||
deadline = loop.time() + poll_timeout_ms / 1000.0
|
||||
|
||||
while loop.time() < deadline:
|
||||
try:
|
||||
status_resp = await self.poll_qrcode_status(qr_resp.qrcode)
|
||||
except Exception as e:
|
||||
logger.error('Error polling QR status: %s', e)
|
||||
await asyncio.sleep(2)
|
||||
continue
|
||||
|
||||
if on_status:
|
||||
try:
|
||||
cb_result = on_status(status_resp.status or 'unknown')
|
||||
if asyncio.iscoroutine(cb_result) or asyncio.isfuture(cb_result):
|
||||
await cb_result
|
||||
except Exception as e:
|
||||
logger.warning('on_status callback error: %s', e)
|
||||
|
||||
if status_resp.status == 'confirmed' and status_resp.bot_token:
|
||||
new_base_url = status_resp.baseurl or self.base_url
|
||||
# Update this client instance as well
|
||||
self.token = status_resp.bot_token
|
||||
self.base_url = new_base_url.rstrip('/')
|
||||
return LoginResult(
|
||||
token=status_resp.bot_token,
|
||||
base_url=new_base_url,
|
||||
account_id=status_resp.ilink_bot_id or '',
|
||||
qr_image_base64=last_qr_base64,
|
||||
)
|
||||
|
||||
if status_resp.status == 'expired':
|
||||
break # retry with a new QR code
|
||||
|
||||
await asyncio.sleep(1)
|
||||
else:
|
||||
# While-loop ended without break → poll timeout, treat as expired
|
||||
pass
|
||||
|
||||
remaining = max_retries - attempt - 1
|
||||
if remaining > 0:
|
||||
logger.info('QR code expired, refreshing... (%d retries left)', remaining)
|
||||
else:
|
||||
raise ApiError('QR code login failed: max retries exceeded', status=0)
|
||||
|
||||
# Should not reach here, but just in case
|
||||
raise ApiError('QR code login failed', status=0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parsing helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_cdn_media(data: Optional[dict]) -> Optional[CDNMedia]:
|
||||
if not data:
|
||||
return None
|
||||
return CDNMedia(
|
||||
encrypt_query_param=data.get('encrypt_query_param'),
|
||||
aes_key=data.get('aes_key'),
|
||||
encrypt_type=data.get('encrypt_type'),
|
||||
)
|
||||
|
||||
|
||||
def _parse_message_item(data: dict) -> MessageItem:
|
||||
item = MessageItem(
|
||||
type=data.get('type'),
|
||||
create_time_ms=data.get('create_time_ms'),
|
||||
update_time_ms=data.get('update_time_ms'),
|
||||
is_completed=data.get('is_completed'),
|
||||
msg_id=data.get('msg_id'),
|
||||
)
|
||||
|
||||
if data.get('text_item'):
|
||||
item.text_item = TextItem(text=data['text_item'].get('text'))
|
||||
|
||||
if data.get('image_item'):
|
||||
img = data['image_item']
|
||||
item.image_item = ImageItem(
|
||||
media=_parse_cdn_media(img.get('media')),
|
||||
thumb_media=_parse_cdn_media(img.get('thumb_media')),
|
||||
aeskey=img.get('aeskey'),
|
||||
url=img.get('url'),
|
||||
mid_size=img.get('mid_size'),
|
||||
)
|
||||
|
||||
if data.get('voice_item'):
|
||||
v = data['voice_item']
|
||||
item.voice_item = VoiceItem(
|
||||
media=_parse_cdn_media(v.get('media')),
|
||||
encode_type=v.get('encode_type'),
|
||||
playtime=v.get('playtime'),
|
||||
text=v.get('text'),
|
||||
)
|
||||
|
||||
if data.get('file_item'):
|
||||
f = data['file_item']
|
||||
item.file_item = FileItem(
|
||||
media=_parse_cdn_media(f.get('media')),
|
||||
file_name=f.get('file_name'),
|
||||
md5=f.get('md5'),
|
||||
len=f.get('len'),
|
||||
)
|
||||
|
||||
if data.get('video_item'):
|
||||
vid = data['video_item']
|
||||
item.video_item = VideoItem(
|
||||
media=_parse_cdn_media(vid.get('media')),
|
||||
video_size=vid.get('video_size'),
|
||||
play_length=vid.get('play_length'),
|
||||
video_md5=vid.get('video_md5'),
|
||||
thumb_media=_parse_cdn_media(vid.get('thumb_media')),
|
||||
)
|
||||
|
||||
if data.get('ref_msg'):
|
||||
ref = data['ref_msg']
|
||||
item.ref_msg = RefMessage(
|
||||
title=ref.get('title'),
|
||||
message_item=_parse_message_item(ref['message_item']) if ref.get('message_item') else None,
|
||||
)
|
||||
|
||||
return item
|
||||
|
||||
|
||||
def _parse_weixin_message(data: dict) -> WeixinMessage:
|
||||
msg = WeixinMessage(
|
||||
seq=data.get('seq'),
|
||||
message_id=data.get('message_id'),
|
||||
from_user_id=data.get('from_user_id'),
|
||||
to_user_id=data.get('to_user_id'),
|
||||
client_id=data.get('client_id'),
|
||||
create_time_ms=data.get('create_time_ms'),
|
||||
session_id=data.get('session_id'),
|
||||
group_id=data.get('group_id'),
|
||||
message_type=data.get('message_type'),
|
||||
message_state=data.get('message_state'),
|
||||
context_token=data.get('context_token'),
|
||||
)
|
||||
if data.get('item_list'):
|
||||
msg.item_list = [_parse_message_item(item) for item in data['item_list']]
|
||||
return msg
|
||||
|
||||
|
||||
def _parse_get_updates_response(data: dict) -> GetUpdatesResponse:
|
||||
resp = GetUpdatesResponse(
|
||||
ret=data.get('ret'),
|
||||
errcode=data.get('errcode'),
|
||||
errmsg=data.get('errmsg'),
|
||||
get_updates_buf=data.get('get_updates_buf'),
|
||||
longpolling_timeout_ms=data.get('longpolling_timeout_ms'),
|
||||
)
|
||||
if data.get('msgs'):
|
||||
resp.msgs = [_parse_weixin_message(m) for m in data['msgs']]
|
||||
return resp
|
||||
|
||||
|
||||
def _cdn_media_to_dict(media: Optional[CDNMedia]) -> Optional[dict]:
|
||||
if not media:
|
||||
return None
|
||||
d: dict = {}
|
||||
if media.encrypt_query_param is not None:
|
||||
d['encrypt_query_param'] = media.encrypt_query_param
|
||||
if media.aes_key is not None:
|
||||
d['aes_key'] = media.aes_key
|
||||
if media.encrypt_type is not None:
|
||||
d['encrypt_type'] = media.encrypt_type
|
||||
return d or None
|
||||
|
||||
|
||||
def _message_item_to_dict(item: MessageItem) -> dict:
|
||||
d: dict = {'type': item.type}
|
||||
|
||||
if item.text_item:
|
||||
d['text_item'] = {'text': item.text_item.text}
|
||||
|
||||
if item.image_item:
|
||||
img_d: dict = {}
|
||||
if item.image_item.media:
|
||||
img_d['media'] = _cdn_media_to_dict(item.image_item.media)
|
||||
if item.image_item.mid_size is not None:
|
||||
img_d['mid_size'] = item.image_item.mid_size
|
||||
d['image_item'] = img_d
|
||||
|
||||
if item.voice_item:
|
||||
voice_d: dict = {}
|
||||
if item.voice_item.media:
|
||||
voice_d['media'] = _cdn_media_to_dict(item.voice_item.media)
|
||||
if item.voice_item.playtime is not None:
|
||||
voice_d['playtime'] = item.voice_item.playtime
|
||||
d['voice_item'] = voice_d
|
||||
|
||||
if item.file_item:
|
||||
file_d: dict = {}
|
||||
if item.file_item.media:
|
||||
file_d['media'] = _cdn_media_to_dict(item.file_item.media)
|
||||
if item.file_item.file_name:
|
||||
file_d['file_name'] = item.file_item.file_name
|
||||
if item.file_item.len:
|
||||
file_d['len'] = item.file_item.len
|
||||
d['file_item'] = file_d
|
||||
|
||||
if item.video_item:
|
||||
vid_d: dict = {}
|
||||
if item.video_item.media:
|
||||
vid_d['media'] = _cdn_media_to_dict(item.video_item.media)
|
||||
if item.video_item.video_size is not None:
|
||||
vid_d['video_size'] = item.video_item.video_size
|
||||
d['video_item'] = vid_d
|
||||
|
||||
return d
|
||||
200
src/langbot/libs/openclaw_weixin_api/types.py
Normal file
200
src/langbot/libs/openclaw_weixin_api/types.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""Type definitions for the OpenClaw WeChat API, mirroring the upstream protocol."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
SESSION_EXPIRED_ERRCODE = -14
|
||||
|
||||
|
||||
class ApiError(Exception):
|
||||
"""Structured error raised by the OpenClaw WeChat API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
status: int = 0,
|
||||
code: int | None = None,
|
||||
payload: Any = None,
|
||||
):
|
||||
super().__init__(message)
|
||||
self.status = status
|
||||
self.code = code
|
||||
self.payload = payload
|
||||
|
||||
@property
|
||||
def is_session_expired(self) -> bool:
|
||||
return self.code == SESSION_EXPIRED_ERRCODE
|
||||
|
||||
|
||||
@dataclass
|
||||
class CDNMedia:
|
||||
encrypt_query_param: Optional[str] = None
|
||||
aes_key: Optional[str] = None
|
||||
encrypt_type: Optional[int] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextItem:
|
||||
text: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageItem:
|
||||
media: Optional[CDNMedia] = None
|
||||
thumb_media: Optional[CDNMedia] = None
|
||||
aeskey: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
mid_size: Optional[int] = None
|
||||
thumb_size: Optional[int] = None
|
||||
thumb_height: Optional[int] = None
|
||||
thumb_width: Optional[int] = None
|
||||
hd_size: Optional[int] = None
|
||||
_downloaded_bytes: Optional[bytes] = field(default=None, repr=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VoiceItem:
|
||||
media: Optional[CDNMedia] = None
|
||||
encode_type: Optional[int] = None
|
||||
bits_per_sample: Optional[int] = None
|
||||
sample_rate: Optional[int] = None
|
||||
playtime: Optional[int] = None
|
||||
text: Optional[str] = None
|
||||
_downloaded_bytes: Optional[bytes] = field(default=None, repr=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileItem:
|
||||
media: Optional[CDNMedia] = None
|
||||
file_name: Optional[str] = None
|
||||
md5: Optional[str] = None
|
||||
len: Optional[str] = None
|
||||
_downloaded_bytes: Optional[bytes] = field(default=None, repr=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoItem:
|
||||
media: Optional[CDNMedia] = None
|
||||
video_size: Optional[int] = None
|
||||
play_length: Optional[int] = None
|
||||
video_md5: Optional[str] = None
|
||||
thumb_media: Optional[CDNMedia] = None
|
||||
thumb_size: Optional[int] = None
|
||||
thumb_height: Optional[int] = None
|
||||
thumb_width: Optional[int] = None
|
||||
_downloaded_bytes: Optional[bytes] = field(default=None, repr=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RefMessage:
|
||||
message_item: Optional[MessageItem] = None
|
||||
title: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageItem:
|
||||
"""A single content item inside a WeixinMessage."""
|
||||
|
||||
# Item types
|
||||
NONE = 0
|
||||
TEXT = 1
|
||||
IMAGE = 2
|
||||
VOICE = 3
|
||||
FILE = 4
|
||||
VIDEO = 5
|
||||
|
||||
type: Optional[int] = None
|
||||
create_time_ms: Optional[int] = None
|
||||
update_time_ms: Optional[int] = None
|
||||
is_completed: Optional[bool] = None
|
||||
msg_id: Optional[str] = None
|
||||
ref_msg: Optional[RefMessage] = None
|
||||
text_item: Optional[TextItem] = None
|
||||
image_item: Optional[ImageItem] = None
|
||||
voice_item: Optional[VoiceItem] = None
|
||||
file_item: Optional[FileItem] = None
|
||||
video_item: Optional[VideoItem] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class WeixinMessage:
|
||||
"""Unified message from getUpdates or for sendMessage."""
|
||||
|
||||
# Message types
|
||||
TYPE_USER = 1
|
||||
TYPE_BOT = 2
|
||||
|
||||
# Message states
|
||||
STATE_NEW = 0
|
||||
STATE_GENERATING = 1
|
||||
STATE_FINISH = 2
|
||||
|
||||
seq: Optional[int] = None
|
||||
message_id: Optional[int] = None
|
||||
from_user_id: Optional[str] = None
|
||||
to_user_id: Optional[str] = None
|
||||
client_id: Optional[str] = None
|
||||
create_time_ms: Optional[int] = None
|
||||
update_time_ms: Optional[int] = None
|
||||
delete_time_ms: Optional[int] = None
|
||||
session_id: Optional[str] = None
|
||||
group_id: Optional[str] = None
|
||||
message_type: Optional[int] = None
|
||||
message_state: Optional[int] = None
|
||||
item_list: Optional[list[MessageItem]] = None
|
||||
context_token: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetUpdatesResponse:
|
||||
ret: Optional[int] = None
|
||||
errcode: Optional[int] = None
|
||||
errmsg: Optional[str] = None
|
||||
msgs: list[WeixinMessage] = field(default_factory=list)
|
||||
get_updates_buf: Optional[str] = None
|
||||
longpolling_timeout_ms: Optional[int] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetConfigResponse:
|
||||
ret: Optional[int] = None
|
||||
errmsg: Optional[str] = None
|
||||
typing_ticket: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetUploadUrlResponse:
|
||||
upload_param: Optional[str] = None
|
||||
thumb_upload_param: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class QRCodeResponse:
|
||||
"""Response from get_bot_qrcode endpoint."""
|
||||
|
||||
qrcode: Optional[str] = None
|
||||
qrcode_img_content: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class QRStatusResponse:
|
||||
"""Response from get_qrcode_status endpoint."""
|
||||
|
||||
status: Optional[str] = None # "wait" | "scaned" | "confirmed" | "expired"
|
||||
bot_token: Optional[str] = None
|
||||
ilink_bot_id: Optional[str] = None
|
||||
baseurl: Optional[str] = None
|
||||
ilink_user_id: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoginResult:
|
||||
"""Result returned by the login flow."""
|
||||
|
||||
token: str
|
||||
base_url: str
|
||||
account_id: str
|
||||
qr_image_base64: Optional[str] = None # data URI of the last QR code shown
|
||||
@@ -1,8 +1,10 @@
|
||||
import re
|
||||
import time
|
||||
import asyncio
|
||||
from quart import request
|
||||
import httpx
|
||||
from quart import Quart
|
||||
from typing import Callable, Dict, Any
|
||||
from typing import Callable, Dict, Any, Optional
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
from .qqofficialevent import QQOfficialEvent
|
||||
import json
|
||||
@@ -32,6 +34,8 @@ class QQOfficialClient:
|
||||
self.access_token = ''
|
||||
self.access_token_expiry_time = None
|
||||
self.logger = logger
|
||||
self._msg_seq_counter = 0
|
||||
self._token_refresh_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def check_access_token(self):
|
||||
"""检查access_token是否存在"""
|
||||
@@ -50,18 +54,18 @@ class QQOfficialClient:
|
||||
headers = {
|
||||
'content-type': 'application/json',
|
||||
}
|
||||
try:
|
||||
response = await client.post(url, json=params, headers=headers)
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
access_token = response_data.get('access_token')
|
||||
expires_in = int(response_data.get('expires_in', 7200))
|
||||
self.access_token_expiry_time = time.time() + expires_in - 60
|
||||
if access_token:
|
||||
self.access_token = access_token
|
||||
except Exception as e:
|
||||
await self.logger.error(f'获取access_token失败: {response_data}')
|
||||
raise Exception(f'获取access_token失败: {e}')
|
||||
response = await client.post(url, json=params, headers=headers)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f'Failed to get access_token: HTTP {response.status_code} {response.text}')
|
||||
response_data = response.json()
|
||||
access_token = response_data.get('access_token')
|
||||
expires_in = int(response_data.get('expires_in', 7200))
|
||||
self.access_token_expiry_time = time.time() + expires_in - 60
|
||||
if access_token:
|
||||
self.access_token = access_token
|
||||
await self.logger.info(f'access_token obtained, expires_in={expires_in}s')
|
||||
else:
|
||||
raise Exception('Failed to get access_token: no access_token in response')
|
||||
|
||||
async def handle_callback_request(self):
|
||||
"""处理回调请求(独立端口模式,使用全局 request)"""
|
||||
@@ -87,10 +91,10 @@ class QQOfficialClient:
|
||||
try:
|
||||
body = await req.get_data()
|
||||
|
||||
print(f'[QQ Official] Received request, body length: {len(body)}')
|
||||
await self.logger.info(f'Received request, body length: {len(body)}')
|
||||
|
||||
if not body or len(body) == 0:
|
||||
print('[QQ Official] Received empty body, might be health check or GET request')
|
||||
await self.logger.info('Received empty body, might be health check or GET request')
|
||||
return {'code': 0, 'message': 'ok'}, 200
|
||||
|
||||
payload = json.loads(body)
|
||||
@@ -111,7 +115,6 @@ class QQOfficialClient:
|
||||
return {'code': 0, 'message': 'success'}
|
||||
|
||||
except Exception as e:
|
||||
print(f'[QQ Official] ERROR: {traceback.format_exc()}')
|
||||
await self.logger.error(f'Error in handle_callback_request: {traceback.format_exc()}')
|
||||
return {'error': str(e)}, 400
|
||||
|
||||
@@ -139,21 +142,24 @@ class QQOfficialClient:
|
||||
|
||||
async def get_message(self, msg: dict) -> Dict[str, Any]:
|
||||
"""获取消息"""
|
||||
d = msg.get('d', {})
|
||||
if not isinstance(d, dict):
|
||||
return {}
|
||||
message_data = {
|
||||
't': msg.get('t', {}),
|
||||
'user_openid': msg.get('d', {}).get('author', {}).get('user_openid', {}),
|
||||
'timestamp': msg.get('d', {}).get('timestamp', {}),
|
||||
'd_author_id': msg.get('d', {}).get('author', {}).get('id', {}),
|
||||
'content': msg.get('d', {}).get('content', {}),
|
||||
'd_id': msg.get('d', {}).get('id', {}),
|
||||
'user_openid': d.get('author', {}).get('user_openid', {}),
|
||||
'timestamp': d.get('timestamp', {}),
|
||||
'd_author_id': d.get('author', {}).get('id', {}),
|
||||
'content': d.get('content', {}),
|
||||
'd_id': d.get('id', {}),
|
||||
'id': msg.get('id', {}),
|
||||
'channel_id': msg.get('d', {}).get('channel_id', {}),
|
||||
'username': msg.get('d', {}).get('author', {}).get('username', {}),
|
||||
'guild_id': msg.get('d', {}).get('guild_id', {}),
|
||||
'member_openid': msg.get('d', {}).get('author', {}).get('openid', {}),
|
||||
'group_openid': msg.get('d', {}).get('group_openid', {}),
|
||||
'channel_id': d.get('channel_id', {}),
|
||||
'username': d.get('author', {}).get('username', {}),
|
||||
'guild_id': d.get('guild_id', {}),
|
||||
'member_openid': d.get('author', {}).get('openid', {}),
|
||||
'group_openid': d.get('group_openid', {}),
|
||||
}
|
||||
attachments = msg.get('d', {}).get('attachments', [])
|
||||
attachments = d.get('attachments', [])
|
||||
image_attachments = [attachment['url'] for attachment in attachments if await self.is_image(attachment)]
|
||||
image_attachments_type = [
|
||||
attachment['content_type'] for attachment in attachments if await self.is_image(attachment)
|
||||
@@ -192,7 +198,7 @@ class QQOfficialClient:
|
||||
if response.status_code == 200:
|
||||
return
|
||||
else:
|
||||
await self.logger.error(f'发送私聊消息失败: {response_data}')
|
||||
await self.logger.error(f'Failed to send private message: {response_data}')
|
||||
raise ValueError(response)
|
||||
|
||||
async def send_group_text_msg(self, group_openid: str, content: str, msg_id: str):
|
||||
@@ -215,7 +221,7 @@ class QQOfficialClient:
|
||||
if response.status_code == 200:
|
||||
return
|
||||
else:
|
||||
await self.logger.error(f'发送群聊消息失败:{response.json()}')
|
||||
await self.logger.error(f'Failed to send group message: {response.json()}')
|
||||
raise Exception(response.read().decode())
|
||||
|
||||
async def send_channle_group_text_msg(self, channel_id: str, content: str, msg_id: str):
|
||||
@@ -238,7 +244,7 @@ class QQOfficialClient:
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
await self.logger.error(f'发送频道群聊消息失败: {response.json()}')
|
||||
await self.logger.error(f'Failed to send channel group message: {response.json()}')
|
||||
raise Exception(response)
|
||||
|
||||
async def send_channle_private_text_msg(self, guild_id: str, content: str, msg_id: str):
|
||||
@@ -261,9 +267,224 @@ class QQOfficialClient:
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
await self.logger.error(f'发送频道私聊消息失败: {response.json()}')
|
||||
await self.logger.error(f'Failed to send channel private message: {response.json()}')
|
||||
raise Exception(response)
|
||||
|
||||
# ---- 富媒体消息 ----
|
||||
|
||||
# 媒体文件类型
|
||||
MEDIA_TYPE_IMAGE = 1
|
||||
MEDIA_TYPE_VIDEO = 2
|
||||
MEDIA_TYPE_VOICE = 3
|
||||
MEDIA_TYPE_FILE = 4
|
||||
|
||||
async def upload_media(
|
||||
self,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
file_type: int,
|
||||
file_url: str = None,
|
||||
file_data: str = None,
|
||||
file_name: str = None,
|
||||
) -> str:
|
||||
"""上传媒体文件,返回 file_info。
|
||||
|
||||
Args:
|
||||
target_type: 'c2c' | 'group'
|
||||
target_id: 用户 openid 或群 openid
|
||||
file_type: 1=图片, 2=视频, 3=语音, 4=文件
|
||||
file_url: 在线 URL(与 file_data 二选一)
|
||||
file_data: base64 编码的文件数据或 data URL(与 file_url 二选一)
|
||||
file_name: 文件名(file_type=4 时必填)
|
||||
"""
|
||||
if not await self.check_access_token():
|
||||
await self.get_access_token()
|
||||
|
||||
if target_type == 'c2c':
|
||||
url = f'{self.base_url}/v2/users/{target_id}/files'
|
||||
elif target_type == 'group':
|
||||
url = f'{self.base_url}/v2/groups/{target_id}/files'
|
||||
else:
|
||||
raise ValueError(f'Unsupported target_type: {target_type}')
|
||||
|
||||
body = {
|
||||
'file_type': file_type,
|
||||
'srv_send_msg': False,
|
||||
}
|
||||
if file_url:
|
||||
body['url'] = file_url
|
||||
elif file_data:
|
||||
# 处理 data URL 格式: data:image/png;base64,xxxxx
|
||||
if file_data.startswith('data:'):
|
||||
match = re.match(r'^data:[^;]+;base64,(.+)$', file_data, re.DOTALL)
|
||||
if match:
|
||||
body['file_data'] = match.group(1)
|
||||
else:
|
||||
body['file_data'] = file_data
|
||||
else:
|
||||
body['file_data'] = file_data
|
||||
else:
|
||||
raise ValueError('file_url or file_data is required')
|
||||
|
||||
if file_type == self.MEDIA_TYPE_FILE and file_name:
|
||||
body['file_name'] = file_name
|
||||
|
||||
async with httpx.AsyncClient(timeout=120) as client:
|
||||
headers = {
|
||||
'Authorization': f'QQBot {self.access_token}',
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
response = await client.post(url, headers=headers, json=body)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
file_info = data.get('file_info', '')
|
||||
preview = file_info[:80] + '...' if len(file_info) > 80 else file_info
|
||||
await self.logger.info(f'Upload media success, file_info={preview}')
|
||||
return file_info
|
||||
else:
|
||||
raise Exception(f'Failed to upload media: HTTP {response.status_code} {response.text}')
|
||||
|
||||
async def _send_media_msg(
|
||||
self,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
file_info: str,
|
||||
msg_id: str = None,
|
||||
content: str = None,
|
||||
):
|
||||
"""发送富媒体消息(msg_type=7)"""
|
||||
if not await self.check_access_token():
|
||||
await self.get_access_token()
|
||||
|
||||
if target_type == 'c2c':
|
||||
url = f'{self.base_url}/v2/users/{target_id}/messages'
|
||||
elif target_type == 'group':
|
||||
url = f'{self.base_url}/v2/groups/{target_id}/messages'
|
||||
else:
|
||||
raise ValueError(f'Unsupported target_type: {target_type}')
|
||||
|
||||
self._msg_seq_counter += 1
|
||||
msg_seq = self._msg_seq_counter
|
||||
body = {
|
||||
'msg_type': 7,
|
||||
'media': {'file_info': file_info},
|
||||
'msg_seq': msg_seq,
|
||||
}
|
||||
if content:
|
||||
body['content'] = content
|
||||
if msg_id:
|
||||
body['msg_id'] = msg_id
|
||||
|
||||
async with httpx.AsyncClient(timeout=120) as client:
|
||||
headers = {
|
||||
'Authorization': f'QQBot {self.access_token}',
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
await self.logger.info(f'Sending rich media: {json.dumps(body, ensure_ascii=False)[:200]}')
|
||||
response = await client.post(url, headers=headers, json=body)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f'Failed to send rich media message: HTTP {response.status_code} {response.text}')
|
||||
|
||||
async def send_image_msg(
|
||||
self,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
file_url: str = None,
|
||||
file_data: str = None,
|
||||
msg_id: str = None,
|
||||
content: str = None,
|
||||
):
|
||||
"""发送图片消息"""
|
||||
file_info = await self.upload_media(
|
||||
target_type,
|
||||
target_id,
|
||||
self.MEDIA_TYPE_IMAGE,
|
||||
file_url=file_url,
|
||||
file_data=file_data,
|
||||
)
|
||||
await self._send_media_msg(target_type, target_id, file_info, msg_id, content)
|
||||
|
||||
async def send_voice_msg(
|
||||
self,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
file_url: str = None,
|
||||
file_data: str = None,
|
||||
msg_id: str = None,
|
||||
):
|
||||
"""发送语音消息"""
|
||||
file_info = await self.upload_media(
|
||||
target_type,
|
||||
target_id,
|
||||
self.MEDIA_TYPE_VOICE,
|
||||
file_url=file_url,
|
||||
file_data=file_data,
|
||||
)
|
||||
await self._send_media_msg(target_type, target_id, file_info, msg_id)
|
||||
|
||||
async def send_file_msg(
|
||||
self,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
file_url: str = None,
|
||||
file_data: str = None,
|
||||
file_name: str = None,
|
||||
msg_id: str = None,
|
||||
):
|
||||
"""发送文件消息(含视频)"""
|
||||
file_info = await self.upload_media(
|
||||
target_type,
|
||||
target_id,
|
||||
self.MEDIA_TYPE_FILE,
|
||||
file_url=file_url,
|
||||
file_data=file_data,
|
||||
file_name=file_name,
|
||||
)
|
||||
await self._send_media_msg(target_type, target_id, file_info, msg_id)
|
||||
|
||||
async def send_stream_msg(
|
||||
self,
|
||||
user_openid: str,
|
||||
content: str,
|
||||
event_id: str,
|
||||
msg_id: str,
|
||||
msg_seq: int = 1,
|
||||
index: int = 0,
|
||||
stream_msg_id: str = None,
|
||||
input_state: int = 1,
|
||||
):
|
||||
"""发送流式消息(C2C 私聊)。
|
||||
|
||||
Args:
|
||||
input_state: 1=生成中, 10=生成结束
|
||||
"""
|
||||
if not await self.check_access_token():
|
||||
await self.get_access_token()
|
||||
|
||||
url = f'{self.base_url}/v2/users/{user_openid}/stream_messages'
|
||||
body = {
|
||||
'input_mode': 'replace',
|
||||
'input_state': input_state,
|
||||
'content_type': 'markdown',
|
||||
'content_raw': content,
|
||||
'event_id': event_id,
|
||||
'msg_id': msg_id,
|
||||
'msg_seq': msg_seq,
|
||||
'index': index,
|
||||
}
|
||||
if stream_msg_id:
|
||||
body['stream_msg_id'] = stream_msg_id
|
||||
|
||||
async with httpx.AsyncClient(timeout=120) as client:
|
||||
headers = {
|
||||
'Authorization': f'QQBot {self.access_token}',
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
response = await client.post(url, headers=headers, json=body)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f'Failed to send stream message: HTTP {response.status_code} {response.text}')
|
||||
return response.json()
|
||||
|
||||
async def is_token_expired(self):
|
||||
"""检查token是否过期"""
|
||||
if self.access_token_expiry_time is None:
|
||||
@@ -292,3 +513,325 @@ class QQOfficialClient:
|
||||
'signature': signature,
|
||||
}
|
||||
return response
|
||||
|
||||
# ---- WebSocket Gateway ----
|
||||
# Reference: https://bot.q.qq.com/wiki/develop/api-v2/dev-prepare/interface-framework/event-emit.html
|
||||
|
||||
INTENT_GUILDS = 1 << 0
|
||||
INTENT_GUILD_MEMBERS = 1 << 1
|
||||
INTENT_PUBLIC_GUILD_MESSAGES = 1 << 30
|
||||
INTENT_DIRECT_MESSAGE = 1 << 12
|
||||
INTENT_GROUP_AND_C2C = 1 << 25
|
||||
INTENT_INTERACTION = 1 << 26
|
||||
|
||||
FULL_INTENTS = (
|
||||
INTENT_GUILDS
|
||||
| INTENT_GUILD_MEMBERS
|
||||
| INTENT_PUBLIC_GUILD_MESSAGES
|
||||
| INTENT_DIRECT_MESSAGE
|
||||
| INTENT_GROUP_AND_C2C
|
||||
| INTENT_INTERACTION
|
||||
)
|
||||
|
||||
async def get_gateway_url(self) -> str:
|
||||
"""获取 WebSocket 网关地址"""
|
||||
if not await self.check_access_token():
|
||||
await self.get_access_token()
|
||||
|
||||
url = f'{self.base_url}/gateway'
|
||||
async with httpx.AsyncClient() as client:
|
||||
headers = {
|
||||
'Authorization': f'QQBot {self.access_token}',
|
||||
}
|
||||
response = await client.get(url, headers=headers)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
ws_url = data.get('url', '')
|
||||
if not ws_url:
|
||||
raise Exception('Gateway URL is empty')
|
||||
return ws_url
|
||||
else:
|
||||
raise Exception(f'Failed to get Gateway URL: HTTP {response.status_code} {response.text}')
|
||||
|
||||
async def _background_token_refresh(self):
|
||||
"""在 token 到期前主动刷新"""
|
||||
try:
|
||||
while True:
|
||||
if self.access_token_expiry_time:
|
||||
remain = self.access_token_expiry_time - time.time()
|
||||
if remain > 120:
|
||||
await asyncio.sleep(remain - 60)
|
||||
continue
|
||||
self.access_token = ''
|
||||
self.access_token_expiry_time = None
|
||||
if await self.check_access_token():
|
||||
await asyncio.sleep(60)
|
||||
else:
|
||||
await self.get_access_token()
|
||||
await asyncio.sleep(60)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def connect_gateway(
|
||||
self,
|
||||
on_event: Callable[[str, dict], Any],
|
||||
on_ready: Optional[Callable[[], Any]] = None,
|
||||
on_error: Optional[Callable[[Exception], Any]] = None,
|
||||
):
|
||||
"""WebSocket 网关连接,含重连逻辑。持续重连直到达到最大次数或被取消。
|
||||
|
||||
Args:
|
||||
on_event: 收到 op=0 Dispatch 事件时的回调,参数为 (event_type, event_data)
|
||||
on_ready: 连接就绪 (收到 READY) 时的回调
|
||||
on_error: 发生错误时的回调
|
||||
"""
|
||||
import websockets
|
||||
|
||||
session_id = ''
|
||||
last_seq = 0
|
||||
reconnect_attempts = 0
|
||||
max_reconnect_attempts = 100
|
||||
backoff_delays = [1, 2, 5, 10, 30, 60]
|
||||
rate_limit_delay = 60
|
||||
|
||||
# Cancel previous token refresh task if any
|
||||
if self._token_refresh_task and not self._token_refresh_task.done():
|
||||
self._token_refresh_task.cancel()
|
||||
try:
|
||||
await self._token_refresh_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._token_refresh_task = None
|
||||
|
||||
while reconnect_attempts <= max_reconnect_attempts:
|
||||
heartbeat_interval = 45000
|
||||
should_refresh_token = False
|
||||
ws = None
|
||||
heartbeat_task = None
|
||||
|
||||
# Refresh token if needed
|
||||
if should_refresh_token:
|
||||
self.access_token = ''
|
||||
self.access_token_expiry_time = None
|
||||
|
||||
try:
|
||||
ws_url = await self.get_gateway_url()
|
||||
await self.logger.info(f'Gateway URL obtained: {ws_url[:60]}...')
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
await self.logger.error(f'Failed to get gateway URL: {e}')
|
||||
reconnect_attempts += 1
|
||||
if '100017' in error_msg or '频率' in error_msg or 'Too many' in error_msg:
|
||||
delay = rate_limit_delay
|
||||
else:
|
||||
delay = backoff_delays[min(reconnect_attempts - 1, len(backoff_delays) - 1)]
|
||||
await self.logger.info(f'Reconnecting in {delay}s (attempt {reconnect_attempts})')
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
|
||||
try:
|
||||
await self.logger.info('Connecting to WebSocket gateway...')
|
||||
ws = await websockets.connect(ws_url)
|
||||
await self.logger.info('WebSocket connected')
|
||||
except Exception as e:
|
||||
await self.logger.error(f'WebSocket connection failed: {e}')
|
||||
reconnect_attempts += 1
|
||||
delay = backoff_delays[min(reconnect_attempts - 1, len(backoff_delays) - 1)]
|
||||
await self.logger.info(f'Reconnecting in {delay}s (attempt {reconnect_attempts})')
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
|
||||
try:
|
||||
async for raw_msg in ws:
|
||||
try:
|
||||
payload = json.loads(raw_msg)
|
||||
except json.JSONDecodeError:
|
||||
await self.logger.error(f'Failed to parse message: {raw_msg}')
|
||||
continue
|
||||
|
||||
op = payload.get('op')
|
||||
d = payload.get('d', {})
|
||||
s = payload.get('s')
|
||||
t = payload.get('t')
|
||||
|
||||
if not isinstance(d, dict):
|
||||
d = {}
|
||||
|
||||
if op == 10: # Hello
|
||||
heartbeat_interval = d.get('heartbeat_interval', 45000)
|
||||
await self.logger.info(f'Received Hello, heartbeat_interval={heartbeat_interval}ms')
|
||||
|
||||
# Send Identify or Resume
|
||||
if session_id and last_seq > 0:
|
||||
resume_payload = {
|
||||
'op': 6,
|
||||
'd': {
|
||||
'token': f'QQBot {self.access_token}',
|
||||
'session_id': session_id,
|
||||
'seq': last_seq,
|
||||
},
|
||||
}
|
||||
await ws.send(json.dumps(resume_payload))
|
||||
await self.logger.info(f'Sent Resume, session_id={session_id}, seq={last_seq}')
|
||||
else:
|
||||
identify_payload = {
|
||||
'op': 2,
|
||||
'd': {
|
||||
'token': f'QQBot {self.access_token}',
|
||||
'intents': self.FULL_INTENTS,
|
||||
'shard': [0, 1],
|
||||
},
|
||||
}
|
||||
await ws.send(json.dumps(identify_payload))
|
||||
await self.logger.info(f'Sent Identify, intents={self.FULL_INTENTS}')
|
||||
|
||||
# Start heartbeat
|
||||
async def _heartbeat_loop(conn, interval_ms):
|
||||
interval_sec = interval_ms / 1000.0
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(interval_sec)
|
||||
try:
|
||||
hb_payload = {'op': 1, 'd': last_seq}
|
||||
await conn.send(json.dumps(hb_payload))
|
||||
except Exception:
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
heartbeat_task = asyncio.create_task(_heartbeat_loop(ws, heartbeat_interval))
|
||||
|
||||
elif op == 0: # Dispatch
|
||||
if s is not None:
|
||||
last_seq = s
|
||||
|
||||
if t == 'READY':
|
||||
session_id = d.get('session_id', '')
|
||||
reconnect_attempts = 0
|
||||
await self.logger.info(f'READY, session_id={session_id}')
|
||||
if on_ready:
|
||||
try:
|
||||
result = on_ready()
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
except Exception:
|
||||
pass
|
||||
# Track token refresh task to avoid leaks
|
||||
if self._token_refresh_task and not self._token_refresh_task.done():
|
||||
self._token_refresh_task.cancel()
|
||||
try:
|
||||
await self._token_refresh_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._token_refresh_task = asyncio.create_task(self._background_token_refresh())
|
||||
|
||||
elif t == 'RESUMED':
|
||||
reconnect_attempts = 0
|
||||
await self.logger.info('RESUMED')
|
||||
|
||||
else:
|
||||
await self.logger.debug(f'Received event: {t}, seq={s}')
|
||||
if on_event:
|
||||
try:
|
||||
result = on_event(t, d)
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
except Exception:
|
||||
await self.logger.error(f'Error handling event {t}: {traceback.format_exc()}')
|
||||
|
||||
elif op == 11: # Heartbeat ACK
|
||||
pass
|
||||
|
||||
elif op == 7: # Reconnect
|
||||
await self.logger.info('Received Reconnect directive')
|
||||
break
|
||||
|
||||
elif op == 9: # Invalid Session
|
||||
can_resume = d.get('can_resume', False)
|
||||
await self.logger.warning(f'Invalid Session, can_resume={can_resume}')
|
||||
if not can_resume:
|
||||
session_id = ''
|
||||
last_seq = 0
|
||||
should_refresh_token = True
|
||||
break
|
||||
|
||||
# Connection closed normally (end of async for)
|
||||
try:
|
||||
close_code = ws.close_code
|
||||
close_reason = ws.close_reason or ''
|
||||
except Exception:
|
||||
close_code = None
|
||||
close_reason = ''
|
||||
await self.logger.info(f'Connection closed, code={close_code}, reason={close_reason}')
|
||||
|
||||
if close_code == 4004:
|
||||
should_refresh_token = True
|
||||
elif close_code in (4006, 4007, 4009):
|
||||
session_id = ''
|
||||
last_seq = 0
|
||||
should_refresh_token = True
|
||||
elif close_code == 4008:
|
||||
reconnect_attempts += 1
|
||||
delay = rate_limit_delay
|
||||
await self.logger.info(
|
||||
f'Rate limited, waiting {delay}s before reconnect (attempt {reconnect_attempts})'
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
elif close_code in (4914, 4915):
|
||||
err = Exception(f'Bot disconnected/banned (close_code={close_code})')
|
||||
if on_error:
|
||||
await self._safe_callback(on_error, err)
|
||||
return
|
||||
elif close_code in (4900, 4901, 4902, 4903, 4904, 4905, 4906, 4907, 4908, 4909, 4910, 4911, 4912, 4913):
|
||||
session_id = ''
|
||||
last_seq = 0
|
||||
|
||||
if close_code == 1000:
|
||||
return
|
||||
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
await self.logger.error(f'Unexpected error in WebSocket loop: {traceback.format_exc()}')
|
||||
finally:
|
||||
if heartbeat_task:
|
||||
heartbeat_task.cancel()
|
||||
try:
|
||||
await heartbeat_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if ws:
|
||||
try:
|
||||
await ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# If we reach here, we need to reconnect
|
||||
reconnect_attempts += 1
|
||||
if reconnect_attempts > max_reconnect_attempts:
|
||||
await self.logger.error(f'Max reconnect attempts ({max_reconnect_attempts}) reached, stopping')
|
||||
if on_error:
|
||||
await self._safe_callback(on_error, Exception('Max reconnect attempts reached'))
|
||||
return
|
||||
delay = backoff_delays[min(reconnect_attempts - 1, len(backoff_delays) - 1)]
|
||||
await self.logger.info(f'Reconnecting in {delay}s (attempt {reconnect_attempts})')
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
async def _safe_callback(self, callback, *args):
|
||||
"""Safely invoke a callback, handling both sync and async functions."""
|
||||
try:
|
||||
result = callback(*args)
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def connect_gateway_loop(
|
||||
self,
|
||||
on_event: Callable[[str, dict], Any],
|
||||
on_ready: Optional[Callable[[], Any]] = None,
|
||||
on_error: Optional[Callable[[Exception], Any]] = None,
|
||||
):
|
||||
"""持续重连的网关循环。"""
|
||||
await self.connect_gateway(on_event, on_ready, on_error)
|
||||
|
||||
@@ -6,7 +6,8 @@ import traceback
|
||||
import uuid
|
||||
import xml.etree.ElementTree as ET
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Optional
|
||||
import re
|
||||
from typing import Any, Callable, Optional, Tuple
|
||||
from urllib.parse import unquote
|
||||
|
||||
import httpx
|
||||
@@ -63,16 +64,25 @@ class StreamSession:
|
||||
# 缓存最近一次片段,处理重试或超时兜底
|
||||
last_chunk: Optional[StreamChunk] = None
|
||||
|
||||
# 反馈 ID,用于接收用户点赞/点踩反馈
|
||||
feedback_id: Optional[str] = None
|
||||
|
||||
|
||||
class StreamSessionManager:
|
||||
"""管理 stream 会话的生命周期,并负责队列的生产消费。"""
|
||||
|
||||
# Sessions with registered feedback_ids use a longer TTL to survive the
|
||||
# full like → cancel → dislike feedback flow. Must align with the adapter's
|
||||
# _stream_to_monitoring_msg TTL (wecombot.py).
|
||||
_FEEDBACK_SESSION_TTL = 600 # 10 minutes
|
||||
|
||||
def __init__(self, logger: EventLogger, ttl: int = 60) -> None:
|
||||
self.logger = logger
|
||||
|
||||
self.ttl = ttl # 超时时间(秒),超过该时间未被访问的会话会被清理由 cleanup
|
||||
self._sessions: dict[str, StreamSession] = {} # stream_id -> StreamSession 映射
|
||||
self._msg_index: dict[str, str] = {} # msgid -> stream_id 映射,便于流水线根据消息 ID 找到会话
|
||||
self._feedback_index: dict[str, str] = {} # feedback_id -> stream_id 映射
|
||||
|
||||
def get_stream_id_by_msg(self, msg_id: str) -> Optional[str]:
|
||||
if not msg_id:
|
||||
@@ -82,6 +92,32 @@ class StreamSessionManager:
|
||||
def get_session(self, stream_id: str) -> Optional[StreamSession]:
|
||||
return self._sessions.get(stream_id)
|
||||
|
||||
def get_session_by_feedback_id(self, feedback_id: str) -> Optional[StreamSession]:
|
||||
"""根据 feedback_id 查找会话。
|
||||
|
||||
Args:
|
||||
feedback_id: 企业微信反馈事件中的反馈 ID。
|
||||
|
||||
Returns:
|
||||
Optional[StreamSession]: 找到的会话实例,未找到返回 None。
|
||||
"""
|
||||
if not feedback_id:
|
||||
return None
|
||||
stream_id = self._feedback_index.get(feedback_id)
|
||||
if stream_id:
|
||||
return self._sessions.get(stream_id)
|
||||
return None
|
||||
|
||||
def register_feedback_id(self, stream_id: str, feedback_id: str) -> None:
|
||||
"""注册 feedback_id 与 stream_id 的映射。
|
||||
|
||||
Args:
|
||||
stream_id: 企业微信流式会话 ID。
|
||||
feedback_id: 反馈 ID。
|
||||
"""
|
||||
if feedback_id and stream_id:
|
||||
self._feedback_index[feedback_id] = stream_id
|
||||
|
||||
def create_or_get(self, msg_json: dict[str, Any]) -> tuple[StreamSession, bool]:
|
||||
"""根据企业微信回调创建或获取会话。
|
||||
|
||||
@@ -183,11 +219,17 @@ class StreamSessionManager:
|
||||
session.last_access = time.time()
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""定期清理过期会话,防止队列与映射无上限累积。"""
|
||||
"""定期清理过期会话,防止队列与映射无上限累积。
|
||||
|
||||
已注册 feedback_id 的会话使用更长的 TTL,确保用户在点赞/取消/点踩流程中
|
||||
不会因为 session 被提前清除而丢失上下文信息。
|
||||
"""
|
||||
now = time.time()
|
||||
expired: list[str] = []
|
||||
for stream_id, session in self._sessions.items():
|
||||
if now - session.last_access > self.ttl:
|
||||
# Sessions with registered feedback_ids use a longer TTL
|
||||
effective_ttl = self._FEEDBACK_SESSION_TTL if session.feedback_id else self.ttl
|
||||
if now - session.last_access > effective_ttl:
|
||||
expired.append(stream_id)
|
||||
|
||||
for stream_id in expired:
|
||||
@@ -197,6 +239,488 @@ class StreamSessionManager:
|
||||
msg_id = session.msg_id
|
||||
if msg_id and self._msg_index.get(msg_id) == stream_id:
|
||||
self._msg_index.pop(msg_id, None)
|
||||
# Clean up feedback index for expired sessions
|
||||
if session.feedback_id:
|
||||
self._feedback_index.pop(session.feedback_id, None)
|
||||
|
||||
|
||||
def _decrypt_file(encrypted_data: bytes, aes_key_str: str) -> bytes:
|
||||
"""Decrypt AES-256-CBC encrypted file data.
|
||||
|
||||
Aligned with the official WeCom AI Bot Python SDK (crypto_utils.py).
|
||||
|
||||
Args:
|
||||
encrypted_data: The raw encrypted bytes.
|
||||
aes_key_str: Base64-encoded AES key (may lack padding).
|
||||
|
||||
Returns:
|
||||
Decrypted bytes with PKCS#7 padding removed.
|
||||
"""
|
||||
if not encrypted_data:
|
||||
raise ValueError('encrypted_data is empty')
|
||||
if not aes_key_str:
|
||||
raise ValueError('aes_key is empty')
|
||||
|
||||
# Python's base64.b64decode requires proper padding (length % 4 == 0).
|
||||
# Node.js Buffer.from tolerates missing '=', so we must pad manually.
|
||||
remainder = len(aes_key_str) % 4
|
||||
if remainder != 0:
|
||||
aes_key_str = aes_key_str + '=' * (4 - remainder)
|
||||
key = base64.b64decode(aes_key_str)
|
||||
|
||||
iv = key[:16]
|
||||
|
||||
cipher = AES.new(key, AES.MODE_CBC, iv)
|
||||
|
||||
# Ensure encrypted data is aligned to AES block size (16 bytes).
|
||||
# Node.js setAutoPadding(false) silently handles unaligned data,
|
||||
# but PyCryptodome will raise an error.
|
||||
block_size = 16
|
||||
data_remainder = len(encrypted_data) % block_size
|
||||
if data_remainder != 0:
|
||||
encrypted_data = encrypted_data + b'\x00' * (block_size - data_remainder)
|
||||
|
||||
decrypted = cipher.decrypt(encrypted_data)
|
||||
|
||||
# Remove PKCS#7 padding with validation
|
||||
if len(decrypted) == 0:
|
||||
raise ValueError('Decrypted data is empty')
|
||||
|
||||
pad_len = decrypted[-1]
|
||||
if pad_len < 1 or pad_len > 32 or pad_len > len(decrypted):
|
||||
raise ValueError(f'Invalid PKCS#7 padding value: {pad_len}')
|
||||
|
||||
# Verify all padding bytes are consistent
|
||||
for i in range(len(decrypted) - pad_len, len(decrypted)):
|
||||
if decrypted[i] != pad_len:
|
||||
raise ValueError('Invalid PKCS#7 padding: padding bytes mismatch')
|
||||
|
||||
return decrypted[: len(decrypted) - pad_len]
|
||||
|
||||
|
||||
def _extract_filename(content_disposition: str) -> Optional[str]:
|
||||
"""Extract filename from a Content-Disposition header value."""
|
||||
if not content_disposition:
|
||||
return None
|
||||
# RFC 5987: filename*=UTF-8''xxx
|
||||
utf8_match = re.search(r"filename\*=UTF-8''([^;\s]+)", content_disposition, re.IGNORECASE)
|
||||
if utf8_match:
|
||||
return unquote(utf8_match.group(1))
|
||||
# Standard: filename="xxx" or filename=xxx
|
||||
match = re.search(r'filename="?([^";\s]+)"?', content_disposition, re.IGNORECASE)
|
||||
if match:
|
||||
return unquote(match.group(1))
|
||||
return None
|
||||
|
||||
|
||||
def _bytes_to_data_uri(data: bytes) -> str:
|
||||
"""Convert raw bytes to a data URI with auto-detected MIME type."""
|
||||
if data.startswith(b'\xff\xd8'):
|
||||
mime_type = 'image/jpeg'
|
||||
elif data.startswith(b'\x89PNG'):
|
||||
mime_type = 'image/png'
|
||||
elif data.startswith((b'GIF87a', b'GIF89a')):
|
||||
mime_type = 'image/gif'
|
||||
elif data.startswith(b'BM'):
|
||||
mime_type = 'image/bmp'
|
||||
elif data.startswith(b'II*\x00') or data.startswith(b'MM\x00*'):
|
||||
mime_type = 'image/tiff'
|
||||
elif data[:4] == b'%PDF':
|
||||
mime_type = 'application/pdf'
|
||||
elif data[:4] == b'PK\x03\x04':
|
||||
mime_type = 'application/zip'
|
||||
else:
|
||||
mime_type = 'application/octet-stream'
|
||||
|
||||
base64_str = base64.b64encode(data).decode('utf-8')
|
||||
return f'data:{mime_type};base64,{base64_str}'
|
||||
|
||||
|
||||
async def download_encrypted_file(
|
||||
download_url: str, aes_key: str, logger: EventLogger
|
||||
) -> Tuple[Optional[bytes], Optional[str]]:
|
||||
"""Download an AES-encrypted file from WeChat Work and decrypt it.
|
||||
|
||||
Args:
|
||||
download_url: The encrypted file download URL.
|
||||
aes_key: The AES key for decryption (base64-encoded, per-message aeskey
|
||||
or platform EncodingAESKey).
|
||||
logger: Logger instance.
|
||||
|
||||
Returns:
|
||||
A tuple of (decrypted_bytes, filename) or (None, None) on failure.
|
||||
"""
|
||||
if not download_url:
|
||||
return None, None
|
||||
if not aes_key:
|
||||
await logger.error('download_encrypted_file: aes_key is empty, cannot decrypt')
|
||||
return None, None
|
||||
|
||||
filename: Optional[str] = None
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(download_url)
|
||||
if response.status_code != 200:
|
||||
await logger.error(f'Failed to download file (HTTP {response.status_code}): {response.text[:200]}')
|
||||
return None, None
|
||||
encrypted_bytes = response.content
|
||||
filename = _extract_filename(response.headers.get('content-disposition', ''))
|
||||
except Exception:
|
||||
await logger.error(f'Failed to download file: {traceback.format_exc()}')
|
||||
return None, None
|
||||
|
||||
try:
|
||||
decrypted = _decrypt_file(encrypted_bytes, aes_key)
|
||||
return decrypted, filename
|
||||
except Exception:
|
||||
await logger.error(f'Failed to decrypt file: {traceback.format_exc()}')
|
||||
return None, None
|
||||
|
||||
|
||||
async def parse_wecom_bot_message(
|
||||
msg_json: dict[str, Any], encoding_aes_key: str, logger: EventLogger
|
||||
) -> dict[str, Any]:
|
||||
"""Parse a decrypted WeChat Work AI Bot message JSON into a unified message dict.
|
||||
|
||||
This is the shared message parsing logic used by both webhook and WebSocket modes.
|
||||
|
||||
Args:
|
||||
msg_json: The decrypted message JSON from WeChat Work.
|
||||
encoding_aes_key: AES key for file decryption.
|
||||
logger: Logger instance.
|
||||
|
||||
Returns:
|
||||
A dict suitable for constructing a WecomBotEvent.
|
||||
"""
|
||||
message_data: dict[str, Any] = {}
|
||||
|
||||
msg_type = msg_json.get('msgtype', '')
|
||||
if msg_type:
|
||||
message_data['msgtype'] = msg_type
|
||||
|
||||
if msg_json.get('chattype', '') == 'single':
|
||||
message_data['type'] = 'single'
|
||||
elif msg_json.get('chattype', '') == 'group':
|
||||
message_data['type'] = 'group'
|
||||
|
||||
max_inline_file_size = 5 * 1024 * 1024
|
||||
|
||||
async def _safe_download(url: str, per_msg_aeskey: str = '') -> Tuple[Optional[bytes], Optional[str]]:
|
||||
"""Download and decrypt a file, preferring per-message aeskey over platform key."""
|
||||
if not url:
|
||||
return None, None
|
||||
key = per_msg_aeskey or encoding_aes_key
|
||||
if not key:
|
||||
await logger.warning('No AES key available for file decryption, skipping download')
|
||||
return None, None
|
||||
return await download_encrypted_file(url, key, logger)
|
||||
|
||||
async def _safe_download_as_data_uri(url: str, per_msg_aeskey: str = '') -> Optional[str]:
|
||||
"""Download, decrypt, and convert to data URI for backward compatibility."""
|
||||
data, _filename = await _safe_download(url, per_msg_aeskey)
|
||||
if data:
|
||||
return _bytes_to_data_uri(data)
|
||||
return None
|
||||
|
||||
if msg_type == 'text':
|
||||
message_data['content'] = msg_json.get('text', {}).get('content')
|
||||
elif msg_type == 'markdown':
|
||||
message_data['content'] = msg_json.get('markdown', {}).get('content') or msg_json.get('text', {}).get(
|
||||
'content', ''
|
||||
)
|
||||
elif msg_type == 'image':
|
||||
image_info = msg_json.get('image', {})
|
||||
picurl = image_info.get('url', '')
|
||||
per_msg_aeskey = image_info.get('aeskey', '')
|
||||
base64_data = await _safe_download_as_data_uri(picurl, per_msg_aeskey)
|
||||
if base64_data:
|
||||
message_data['picurl'] = base64_data
|
||||
message_data['images'] = [base64_data]
|
||||
elif msg_type == 'voice':
|
||||
voice_info = msg_json.get('voice', {}) or {}
|
||||
download_url = voice_info.get('url')
|
||||
per_msg_aeskey = voice_info.get('aeskey', '')
|
||||
message_data['voice'] = {
|
||||
'url': download_url,
|
||||
'md5sum': voice_info.get('md5sum') or voice_info.get('md5'),
|
||||
'filesize': voice_info.get('filesize') or voice_info.get('size'),
|
||||
'sdkfileid': voice_info.get('sdkfileid') or voice_info.get('fileid'),
|
||||
}
|
||||
if voice_info.get('content'):
|
||||
message_data['content'] = voice_info.get('content')
|
||||
# if (message_data['voice'].get('filesize') or 0) <= max_inline_file_size:
|
||||
# voice_base64 = await _safe_download_as_data_uri(download_url, per_msg_aeskey)
|
||||
# if voice_base64:
|
||||
# message_data['voice']['base64'] = voice_base64
|
||||
elif msg_type == 'video':
|
||||
video_info = msg_json.get('video', {}) or {}
|
||||
download_url = video_info.get('url')
|
||||
per_msg_aeskey = video_info.get('aeskey', '')
|
||||
video_data = {
|
||||
'url': download_url,
|
||||
'filesize': video_info.get('filesize') or video_info.get('size'),
|
||||
'sdkfileid': video_info.get('sdkfileid') or video_info.get('fileid'),
|
||||
'md5sum': video_info.get('md5sum') or video_info.get('md5'),
|
||||
'filename': video_info.get('filename') or video_info.get('name'),
|
||||
}
|
||||
# if (video_data.get('filesize') or 0) <= max_inline_file_size:
|
||||
# video_base64 = await _safe_download_as_data_uri(download_url, per_msg_aeskey)
|
||||
# if video_base64:
|
||||
# video_data['base64'] = video_base64
|
||||
# 应为需要解密,但是目前暂时不能下载到内部进行解密,所以先将下载链接拼接aeskey返回给用户,由插件去处理该链接的下载和解密逻辑
|
||||
video_data['download_url'] = download_url + f'?aeskey={per_msg_aeskey}'
|
||||
message_data['video'] = video_data
|
||||
elif msg_type == 'file':
|
||||
file_info = msg_json.get('file', {}) or {}
|
||||
download_url = file_info.get('url') or file_info.get('fileurl')
|
||||
per_msg_aeskey = file_info.get('aeskey', '')
|
||||
file_data = {
|
||||
'filename': file_info.get('filename') or file_info.get('name'),
|
||||
'filesize': file_info.get('filesize') or file_info.get('size'),
|
||||
'md5sum': file_info.get('md5sum') or file_info.get('md5'),
|
||||
'sdkfileid': file_info.get('sdkfileid') or file_info.get('fileid'),
|
||||
'download_url': download_url,
|
||||
'extra': file_info,
|
||||
}
|
||||
# if (file_data.get('filesize') or 0) <= max_inline_file_size:
|
||||
# file_bytes, dl_filename = await _safe_download(download_url, per_msg_aeskey)
|
||||
# if file_bytes:
|
||||
# file_data['base64'] = _bytes_to_data_uri(file_bytes)
|
||||
# if dl_filename and not file_data.get('filename'):
|
||||
# file_data['filename'] = dl_filename
|
||||
|
||||
# 应为需要解密,但是目前暂时不能下载到内部进行解密,所以先将下载链接拼接aeskey返回给用户,由插件去处理该链接的下载和解密逻辑
|
||||
file_data['download_url'] = download_url + f'?aeskey={per_msg_aeskey}'
|
||||
message_data['file'] = file_data
|
||||
elif msg_type == 'link':
|
||||
message_data['link'] = msg_json.get('link', {})
|
||||
if not message_data.get('content'):
|
||||
title = message_data['link'].get('title', '')
|
||||
desc = message_data['link'].get('description') or message_data['link'].get('digest', '')
|
||||
message_data['content'] = '\n'.join(filter(None, [title, desc]))
|
||||
elif msg_type == 'mixed':
|
||||
items = msg_json.get('mixed', {}).get('msg_item', [])
|
||||
texts = []
|
||||
images = []
|
||||
files = []
|
||||
voices = []
|
||||
videos = []
|
||||
links = []
|
||||
for item in items:
|
||||
item_type = item.get('msgtype')
|
||||
if item_type == 'text':
|
||||
texts.append(item.get('text', {}).get('content', ''))
|
||||
elif item_type == 'image':
|
||||
img_info = item.get('image', {})
|
||||
img_url = img_info.get('url')
|
||||
img_aeskey = img_info.get('aeskey', '')
|
||||
base64_data = await _safe_download_as_data_uri(img_url, img_aeskey)
|
||||
if base64_data:
|
||||
images.append(base64_data)
|
||||
elif item_type == 'file':
|
||||
file_info = item.get('file', {}) or {}
|
||||
download_url = file_info.get('url') or file_info.get('fileurl')
|
||||
item_aeskey = file_info.get('aeskey', '')
|
||||
file_data = {
|
||||
'filename': file_info.get('filename') or file_info.get('name'),
|
||||
'filesize': file_info.get('filesize') or file_info.get('size'),
|
||||
'md5sum': file_info.get('md5sum') or file_info.get('md5'),
|
||||
'sdkfileid': file_info.get('sdkfileid') or file_info.get('fileid'),
|
||||
'download_url': download_url,
|
||||
'extra': file_info,
|
||||
}
|
||||
if (file_data.get('filesize') or 0) <= max_inline_file_size:
|
||||
file_bytes, dl_filename = await _safe_download(download_url, item_aeskey)
|
||||
if file_bytes:
|
||||
file_data['base64'] = _bytes_to_data_uri(file_bytes)
|
||||
if dl_filename and not file_data.get('filename'):
|
||||
file_data['filename'] = dl_filename
|
||||
files.append(file_data)
|
||||
elif item_type == 'voice':
|
||||
voice_info = item.get('voice', {}) or {}
|
||||
download_url = voice_info.get('url')
|
||||
item_aeskey = voice_info.get('aeskey', '')
|
||||
voice_data = {
|
||||
'url': download_url,
|
||||
'md5sum': voice_info.get('md5sum') or voice_info.get('md5'),
|
||||
'filesize': voice_info.get('filesize') or voice_info.get('size'),
|
||||
'sdkfileid': voice_info.get('sdkfileid') or voice_info.get('fileid'),
|
||||
}
|
||||
if voice_info.get('content'):
|
||||
texts.append(voice_info.get('content'))
|
||||
if (voice_data.get('filesize') or 0) <= max_inline_file_size:
|
||||
voice_base64 = await _safe_download_as_data_uri(download_url, item_aeskey)
|
||||
if voice_base64:
|
||||
voice_data['base64'] = voice_base64
|
||||
voices.append(voice_data)
|
||||
elif item_type == 'video':
|
||||
video_info = item.get('video', {}) or {}
|
||||
download_url = video_info.get('url')
|
||||
item_aeskey = video_info.get('aeskey', '')
|
||||
video_data = {
|
||||
'url': download_url,
|
||||
'filesize': video_info.get('filesize') or video_info.get('size'),
|
||||
'sdkfileid': video_info.get('sdkfileid') or video_info.get('fileid'),
|
||||
'md5sum': video_info.get('md5sum') or video_info.get('md5'),
|
||||
'filename': video_info.get('filename') or video_info.get('name'),
|
||||
}
|
||||
if (video_data.get('filesize') or 0) <= max_inline_file_size:
|
||||
video_base64 = await _safe_download_as_data_uri(download_url, item_aeskey)
|
||||
if video_base64:
|
||||
video_data['base64'] = video_base64
|
||||
videos.append(video_data)
|
||||
elif item_type == 'link':
|
||||
links.append(item.get('link', {}))
|
||||
|
||||
if texts:
|
||||
message_data['content'] = ' '.join(texts)
|
||||
if images:
|
||||
message_data['images'] = images
|
||||
message_data['picurl'] = images[0]
|
||||
if files:
|
||||
message_data['files'] = files
|
||||
message_data['file'] = files[0]
|
||||
if voices:
|
||||
message_data['voices'] = voices
|
||||
message_data['voice'] = voices[0]
|
||||
if videos:
|
||||
message_data['videos'] = videos
|
||||
message_data['video'] = videos[0]
|
||||
if links:
|
||||
message_data['link'] = links[0]
|
||||
if items:
|
||||
message_data['attachments'] = items
|
||||
else:
|
||||
message_data['raw_msg'] = msg_json
|
||||
|
||||
from_info = msg_json.get('from', {})
|
||||
message_data['userid'] = from_info.get('userid', '')
|
||||
message_data['username'] = from_info.get('alias', '') or from_info.get('name', '') or from_info.get('userid', '')
|
||||
|
||||
if msg_json.get('chattype', '') == 'group':
|
||||
message_data['chatid'] = msg_json.get('chatid', '')
|
||||
message_data['chatname'] = msg_json.get('chatname', '') or msg_json.get('chatid', '')
|
||||
|
||||
message_data['msgid'] = msg_json.get('msgid', '')
|
||||
|
||||
if msg_json.get('aibotid'):
|
||||
message_data['aibotid'] = msg_json.get('aibotid', '')
|
||||
|
||||
# Handle quote (referenced message) - important for group chat file references
|
||||
quote_info = msg_json.get('quote')
|
||||
if quote_info:
|
||||
quote_data: dict[str, Any] = {}
|
||||
quote_type = quote_info.get('msgtype', '')
|
||||
quote_data['msgtype'] = quote_type
|
||||
|
||||
if quote_type == 'text':
|
||||
quote_data['content'] = quote_info.get('text', {}).get('content', '')
|
||||
elif quote_type == 'image':
|
||||
img_info = quote_info.get('image', {})
|
||||
img_url = img_info.get('url', '')
|
||||
img_aeskey = img_info.get('aeskey', '')
|
||||
base64_data = await _safe_download_as_data_uri(img_url, img_aeskey)
|
||||
if base64_data:
|
||||
quote_data['picurl'] = base64_data
|
||||
quote_data['images'] = [base64_data]
|
||||
elif quote_type == 'file':
|
||||
file_info = quote_info.get('file', {}) or {}
|
||||
download_url = file_info.get('url') or file_info.get('fileurl')
|
||||
item_aeskey = file_info.get('aeskey', '')
|
||||
file_data = {
|
||||
'filename': file_info.get('filename') or file_info.get('name'),
|
||||
'filesize': file_info.get('filesize') or file_info.get('size'),
|
||||
'md5sum': file_info.get('md5sum') or file_info.get('md5'),
|
||||
'sdkfileid': file_info.get('sdkfileid') or file_info.get('fileid'),
|
||||
'download_url': download_url,
|
||||
'extra': file_info,
|
||||
}
|
||||
# Same as private chat: append aeskey to download_url for plugin processing
|
||||
if download_url and item_aeskey:
|
||||
file_data['download_url'] = download_url + f'?aeskey={item_aeskey}'
|
||||
quote_data['file'] = file_data
|
||||
elif quote_type == 'voice':
|
||||
voice_info = quote_info.get('voice', {}) or {}
|
||||
download_url = voice_info.get('url')
|
||||
item_aeskey = voice_info.get('aeskey', '')
|
||||
voice_data = {
|
||||
'url': download_url,
|
||||
'md5sum': voice_info.get('md5sum') or voice_info.get('md5'),
|
||||
'filesize': voice_info.get('filesize') or voice_info.get('size'),
|
||||
'sdkfileid': voice_info.get('sdkfileid') or voice_info.get('fileid'),
|
||||
}
|
||||
if voice_info.get('content'):
|
||||
quote_data['content'] = voice_info.get('content')
|
||||
# Same as private chat: append aeskey to url for plugin processing
|
||||
if download_url and item_aeskey:
|
||||
voice_data['url'] = download_url + f'?aeskey={item_aeskey}'
|
||||
quote_data['voice'] = voice_data
|
||||
elif quote_type == 'video':
|
||||
video_info = quote_info.get('video', {}) or {}
|
||||
download_url = video_info.get('url')
|
||||
item_aeskey = video_info.get('aeskey', '')
|
||||
video_data = {
|
||||
'url': download_url,
|
||||
'filesize': video_info.get('filesize') or video_info.get('size'),
|
||||
'sdkfileid': video_info.get('sdkfileid') or video_info.get('fileid'),
|
||||
'md5sum': video_info.get('md5sum') or video_info.get('md5'),
|
||||
'filename': video_info.get('filename') or video_info.get('name'),
|
||||
}
|
||||
# Same as private chat: append aeskey to download_url for plugin processing
|
||||
if download_url and item_aeskey:
|
||||
video_data['download_url'] = download_url + f'?aeskey={item_aeskey}'
|
||||
quote_data['video'] = video_data
|
||||
elif quote_type == 'link':
|
||||
quote_data['link'] = quote_info.get('link', {})
|
||||
link = quote_data['link']
|
||||
title = link.get('title', '')
|
||||
desc = link.get('description') or link.get('digest', '')
|
||||
quote_data['content'] = '\n'.join(filter(None, [title, desc]))
|
||||
elif quote_type == 'mixed':
|
||||
# Handle mixed type in quote (text + images + files etc.)
|
||||
items = quote_info.get('mixed', {}).get('msg_item', [])
|
||||
texts = []
|
||||
images = []
|
||||
files = []
|
||||
for item in items:
|
||||
item_type = item.get('msgtype')
|
||||
if item_type == 'text':
|
||||
texts.append(item.get('text', {}).get('content', ''))
|
||||
elif item_type == 'image':
|
||||
img_info = item.get('image', {})
|
||||
img_url = img_info.get('url')
|
||||
img_aeskey = img_info.get('aeskey', '')
|
||||
base64_data = await _safe_download_as_data_uri(img_url, img_aeskey)
|
||||
if base64_data:
|
||||
images.append(base64_data)
|
||||
elif item_type == 'file':
|
||||
file_info = item.get('file', {}) or {}
|
||||
download_url = file_info.get('url') or file_info.get('fileurl')
|
||||
item_aeskey = file_info.get('aeskey', '')
|
||||
file_data = {
|
||||
'filename': file_info.get('filename') or file_info.get('name'),
|
||||
'filesize': file_info.get('filesize') or file_info.get('size'),
|
||||
'md5sum': file_info.get('md5sum') or file_info.get('md5'),
|
||||
'sdkfileid': file_info.get('sdkfileid') or file_info.get('fileid'),
|
||||
'download_url': download_url,
|
||||
'extra': file_info,
|
||||
}
|
||||
# Same as private chat: append aeskey to download_url for plugin processing
|
||||
if download_url and item_aeskey:
|
||||
file_data['download_url'] = download_url + f'?aeskey={item_aeskey}'
|
||||
files.append(file_data)
|
||||
if texts:
|
||||
quote_data['content'] = ' '.join(texts)
|
||||
if images:
|
||||
quote_data['images'] = images
|
||||
quote_data['picurl'] = images[0]
|
||||
if files:
|
||||
quote_data['files'] = files
|
||||
quote_data['file'] = files[0]
|
||||
|
||||
message_data['quote'] = quote_data
|
||||
|
||||
return message_data
|
||||
|
||||
|
||||
class WecomBotClient:
|
||||
@@ -236,14 +760,27 @@ class WecomBotClient:
|
||||
self.stream_sessions = StreamSessionManager(logger=logger)
|
||||
self.stream_poll_timeout = 0.5
|
||||
|
||||
self._feedback_callback: Optional[Callable] = None
|
||||
|
||||
def set_feedback_callback(self, callback: Callable) -> None:
|
||||
"""设置反馈回调函数。
|
||||
|
||||
Args:
|
||||
callback: 反馈回调函数,签名: async def callback(feedback_id, feedback_type, feedback_content, inaccurate_reasons, session)
|
||||
"""
|
||||
self._feedback_callback = callback
|
||||
|
||||
@staticmethod
|
||||
def _build_stream_payload(stream_id: str, content: str, finish: bool) -> dict[str, Any]:
|
||||
def _build_stream_payload(
|
||||
stream_id: str, content: str, finish: bool, feedback_id: Optional[str] = None
|
||||
) -> dict[str, Any]:
|
||||
"""按照企业微信协议拼装返回报文。
|
||||
|
||||
Args:
|
||||
stream_id: 企业微信会话 ID。
|
||||
content: 推送的文本内容。
|
||||
finish: 是否为最终片段。
|
||||
feedback_id: 反馈 ID,用于接收用户点赞/点踩反馈。
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: 可直接加密返回的 payload。
|
||||
@@ -251,13 +788,16 @@ class WecomBotClient:
|
||||
Example:
|
||||
组装 `{'msgtype': 'stream', 'stream': {'id': 'sid', ...}}` 结构。
|
||||
"""
|
||||
stream_payload = {
|
||||
'id': stream_id,
|
||||
'finish': finish,
|
||||
'content': content,
|
||||
}
|
||||
if feedback_id:
|
||||
stream_payload['feedback'] = {'id': feedback_id}
|
||||
return {
|
||||
'msgtype': 'stream',
|
||||
'stream': {
|
||||
'id': stream_id,
|
||||
'finish': finish,
|
||||
'content': content,
|
||||
},
|
||||
'stream': stream_payload,
|
||||
}
|
||||
|
||||
async def _encrypt_and_reply(self, payload: dict[str, Any], nonce: str) -> tuple[Response, int]:
|
||||
@@ -313,9 +853,14 @@ class WecomBotClient:
|
||||
"""
|
||||
session, is_new = self.stream_sessions.create_or_get(msg_json)
|
||||
|
||||
feedback_id = str(uuid.uuid4())
|
||||
session.feedback_id = feedback_id
|
||||
self.stream_sessions.register_feedback_id(session.stream_id, feedback_id)
|
||||
|
||||
message_data = await self.get_message(msg_json)
|
||||
if message_data:
|
||||
message_data['stream_id'] = session.stream_id
|
||||
message_data['feedback_id'] = feedback_id
|
||||
try:
|
||||
event = wecombotevent.WecomBotEvent(message_data)
|
||||
except Exception:
|
||||
@@ -324,7 +869,7 @@ class WecomBotClient:
|
||||
if is_new:
|
||||
asyncio.create_task(self._dispatch_event(event))
|
||||
|
||||
payload = self._build_stream_payload(session.stream_id, '', False)
|
||||
payload = self._build_stream_payload(session.stream_id, '', False, feedback_id)
|
||||
return await self._encrypt_and_reply(payload, nonce)
|
||||
|
||||
async def _handle_post_followup_response(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]:
|
||||
@@ -449,202 +994,83 @@ class WecomBotClient:
|
||||
|
||||
msg_json = json.loads(decrypted_xml)
|
||||
|
||||
event = msg_json.get('event', {})
|
||||
event_type = event.get('eventtype', '')
|
||||
|
||||
if event_type == 'feedback_event':
|
||||
return await self._handle_feedback_event(msg_json, nonce)
|
||||
|
||||
if msg_json.get('msgtype') == 'stream':
|
||||
return await self._handle_post_followup_response(msg_json, nonce)
|
||||
|
||||
return await self._handle_post_initial_response(msg_json, nonce)
|
||||
|
||||
async def get_message(self, msg_json):
|
||||
message_data = {}
|
||||
async def _handle_feedback_event(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]:
|
||||
"""处理企业微信用户反馈事件(点赞/点踩)。
|
||||
|
||||
msg_type = msg_json.get('msgtype', '')
|
||||
if msg_type:
|
||||
message_data['msgtype'] = msg_type
|
||||
Args:
|
||||
msg_json: 解密后的企业微信反馈事件 JSON。
|
||||
nonce: 企业微信回调参数 nonce。
|
||||
|
||||
if msg_json.get('chattype', '') == 'single':
|
||||
message_data['type'] = 'single'
|
||||
elif msg_json.get('chattype', '') == 'group':
|
||||
message_data['type'] = 'group'
|
||||
Returns:
|
||||
Tuple[Response, int]: Quart Response 及状态码。
|
||||
|
||||
max_inline_file_size = 5 * 1024 * 1024 # avoid decoding very large payloads by default
|
||||
Note:
|
||||
企业微信协议要求:反馈事件目前仅支持回复空包。
|
||||
"""
|
||||
try:
|
||||
feedback_event = msg_json.get('event', {}).get('feedback_event', {})
|
||||
feedback_id = feedback_event.get('id', '')
|
||||
feedback_type = feedback_event.get('type', 0)
|
||||
feedback_content = feedback_event.get('content', '')
|
||||
inaccurate_reasons = feedback_event.get('inaccurate_reason_list', [])
|
||||
|
||||
async def _safe_download(url: str):
|
||||
if not url:
|
||||
return None
|
||||
return await self.download_url_to_base64(url, self.EnCodingAESKey)
|
||||
|
||||
if msg_type == 'text':
|
||||
message_data['content'] = msg_json.get('text', {}).get('content')
|
||||
elif msg_type == 'markdown':
|
||||
message_data['content'] = msg_json.get('markdown', {}).get('content') or msg_json.get('text', {}).get(
|
||||
'content', ''
|
||||
await self.logger.info(
|
||||
f'收到用户反馈事件: feedback_id={feedback_id}, type={feedback_type}, '
|
||||
f'content={feedback_content}, reasons={inaccurate_reasons}'
|
||||
)
|
||||
elif msg_type == 'image':
|
||||
picurl = msg_json.get('image', {}).get('url', '')
|
||||
base64_data = await _safe_download(picurl)
|
||||
if base64_data:
|
||||
message_data['picurl'] = base64_data
|
||||
message_data['images'] = [base64_data]
|
||||
elif msg_type == 'voice':
|
||||
voice_info = msg_json.get('voice', {}) or {}
|
||||
download_url = voice_info.get('url')
|
||||
message_data['voice'] = {
|
||||
'url': download_url,
|
||||
'md5sum': voice_info.get('md5sum') or voice_info.get('md5'),
|
||||
'filesize': voice_info.get('filesize') or voice_info.get('size'),
|
||||
'sdkfileid': voice_info.get('sdkfileid') or voice_info.get('fileid'),
|
||||
}
|
||||
# 企业微信智能转写文本(如果已有)直接复用,避免重复转写
|
||||
if voice_info.get('content'):
|
||||
message_data['content'] = voice_info.get('content')
|
||||
if (message_data['voice'].get('filesize') or 0) <= max_inline_file_size:
|
||||
voice_base64 = await _safe_download(download_url)
|
||||
if voice_base64:
|
||||
message_data['voice']['base64'] = voice_base64
|
||||
elif msg_type == 'video':
|
||||
video_info = msg_json.get('video', {}) or {}
|
||||
download_url = video_info.get('url')
|
||||
video_data = {
|
||||
'url': download_url,
|
||||
'filesize': video_info.get('filesize') or video_info.get('size'),
|
||||
'sdkfileid': video_info.get('sdkfileid') or video_info.get('fileid'),
|
||||
'md5sum': video_info.get('md5sum') or video_info.get('md5'),
|
||||
'filename': video_info.get('filename') or video_info.get('name'),
|
||||
}
|
||||
if (video_data.get('filesize') or 0) <= max_inline_file_size:
|
||||
video_base64 = await _safe_download(download_url)
|
||||
if video_base64:
|
||||
video_data['base64'] = video_base64
|
||||
message_data['video'] = video_data
|
||||
elif msg_type == 'file':
|
||||
file_info = msg_json.get('file', {}) or {}
|
||||
download_url = file_info.get('url') or file_info.get('fileurl')
|
||||
file_data = {
|
||||
'filename': file_info.get('filename') or file_info.get('name'),
|
||||
'filesize': file_info.get('filesize') or file_info.get('size'),
|
||||
'md5sum': file_info.get('md5sum') or file_info.get('md5'),
|
||||
'sdkfileid': file_info.get('sdkfileid') or file_info.get('fileid'),
|
||||
'download_url': download_url,
|
||||
'extra': file_info,
|
||||
}
|
||||
if (file_data.get('filesize') or 0) <= max_inline_file_size:
|
||||
file_base64 = await _safe_download(download_url)
|
||||
if file_base64:
|
||||
file_data['base64'] = file_base64
|
||||
message_data['file'] = file_data
|
||||
elif msg_type == 'link':
|
||||
message_data['link'] = msg_json.get('link', {})
|
||||
if not message_data.get('content'):
|
||||
title = message_data['link'].get('title', '')
|
||||
desc = message_data['link'].get('description') or message_data['link'].get('digest', '')
|
||||
message_data['content'] = '\n'.join(filter(None, [title, desc]))
|
||||
elif msg_type == 'mixed':
|
||||
items = msg_json.get('mixed', {}).get('msg_item', [])
|
||||
texts = []
|
||||
images = []
|
||||
files = []
|
||||
voices = []
|
||||
videos = []
|
||||
links = []
|
||||
for item in items:
|
||||
item_type = item.get('msgtype')
|
||||
if item_type == 'text':
|
||||
texts.append(item.get('text', {}).get('content', ''))
|
||||
elif item_type == 'image':
|
||||
img_url = item.get('image', {}).get('url')
|
||||
base64_data = await _safe_download(img_url)
|
||||
if base64_data:
|
||||
images.append(base64_data)
|
||||
elif item_type == 'file':
|
||||
file_info = item.get('file', {}) or {}
|
||||
download_url = file_info.get('url') or file_info.get('fileurl')
|
||||
file_data = {
|
||||
'filename': file_info.get('filename') or file_info.get('name'),
|
||||
'filesize': file_info.get('filesize') or file_info.get('size'),
|
||||
'md5sum': file_info.get('md5sum') or file_info.get('md5'),
|
||||
'sdkfileid': file_info.get('sdkfileid') or file_info.get('fileid'),
|
||||
'download_url': download_url,
|
||||
'extra': file_info,
|
||||
}
|
||||
if (file_data.get('filesize') or 0) <= max_inline_file_size:
|
||||
file_base64 = await _safe_download(download_url)
|
||||
if file_base64:
|
||||
file_data['base64'] = file_base64
|
||||
files.append(file_data)
|
||||
elif item_type == 'voice':
|
||||
voice_info = item.get('voice', {}) or {}
|
||||
download_url = voice_info.get('url')
|
||||
voice_data = {
|
||||
'url': download_url,
|
||||
'md5sum': voice_info.get('md5sum') or voice_info.get('md5'),
|
||||
'filesize': voice_info.get('filesize') or voice_info.get('size'),
|
||||
'sdkfileid': voice_info.get('sdkfileid') or voice_info.get('fileid'),
|
||||
}
|
||||
if voice_info.get('content'):
|
||||
texts.append(voice_info.get('content'))
|
||||
if (voice_data.get('filesize') or 0) <= max_inline_file_size:
|
||||
voice_base64 = await _safe_download(download_url)
|
||||
if voice_base64:
|
||||
voice_data['base64'] = voice_base64
|
||||
voices.append(voice_data)
|
||||
elif item_type == 'video':
|
||||
video_info = item.get('video', {}) or {}
|
||||
download_url = video_info.get('url')
|
||||
video_data = {
|
||||
'url': download_url,
|
||||
'filesize': video_info.get('filesize') or video_info.get('size'),
|
||||
'sdkfileid': video_info.get('sdkfileid') or video_info.get('fileid'),
|
||||
'md5sum': video_info.get('md5sum') or video_info.get('md5'),
|
||||
'filename': video_info.get('filename') or video_info.get('name'),
|
||||
}
|
||||
if (video_data.get('filesize') or 0) <= max_inline_file_size:
|
||||
video_base64 = await _safe_download(download_url)
|
||||
if video_base64:
|
||||
video_data['base64'] = video_base64
|
||||
videos.append(video_data)
|
||||
elif item_type == 'link':
|
||||
links.append(item.get('link', {}))
|
||||
|
||||
if texts:
|
||||
message_data['content'] = ' '.join(texts) # 拼接所有 text
|
||||
if images:
|
||||
message_data['images'] = images
|
||||
message_data['picurl'] = images[0] # 只保留第一个 image
|
||||
if files:
|
||||
message_data['files'] = files
|
||||
message_data['file'] = files[0]
|
||||
if voices:
|
||||
message_data['voices'] = voices
|
||||
message_data['voice'] = voices[0]
|
||||
if videos:
|
||||
message_data['videos'] = videos
|
||||
message_data['video'] = videos[0]
|
||||
if links:
|
||||
message_data['link'] = links[0]
|
||||
if items:
|
||||
message_data['attachments'] = items
|
||||
else:
|
||||
message_data['raw_msg'] = msg_json
|
||||
session = self.stream_sessions.get_session_by_feedback_id(feedback_id)
|
||||
|
||||
# Extract user information
|
||||
from_info = msg_json.get('from', {})
|
||||
message_data['userid'] = from_info.get('userid', '')
|
||||
message_data['username'] = (
|
||||
from_info.get('alias', '') or from_info.get('name', '') or from_info.get('userid', '')
|
||||
)
|
||||
if session:
|
||||
await self.logger.info(
|
||||
f'反馈关联到会话: stream_id={session.stream_id}, msg_id={session.msg_id}, user_id={session.user_id}'
|
||||
)
|
||||
else:
|
||||
await self.logger.warning(f'未找到 feedback_id={feedback_id} 对应的会话,仍将记录反馈')
|
||||
|
||||
# Extract chat/group information
|
||||
if msg_json.get('chattype', '') == 'group':
|
||||
message_data['chatid'] = msg_json.get('chatid', '')
|
||||
# Try to get group name if available
|
||||
message_data['chatname'] = msg_json.get('chatname', '') or msg_json.get('chatid', '')
|
||||
# Dispatch feedback event regardless of session availability
|
||||
for handler in self._message_handlers.get('feedback', []):
|
||||
try:
|
||||
await handler(
|
||||
feedback_id=feedback_id,
|
||||
feedback_type=feedback_type,
|
||||
feedback_content=feedback_content,
|
||||
inaccurate_reasons=inaccurate_reasons,
|
||||
session=session,
|
||||
)
|
||||
except Exception:
|
||||
await self.logger.error(traceback.format_exc())
|
||||
|
||||
message_data['msgid'] = msg_json.get('msgid', '')
|
||||
if self._feedback_callback:
|
||||
try:
|
||||
await self._feedback_callback(
|
||||
feedback_id=feedback_id,
|
||||
feedback_type=feedback_type,
|
||||
feedback_content=feedback_content,
|
||||
inaccurate_reasons=inaccurate_reasons,
|
||||
session=session,
|
||||
)
|
||||
except Exception:
|
||||
await self.logger.error(traceback.format_exc())
|
||||
|
||||
if msg_json.get('aibotid'):
|
||||
message_data['aibotid'] = msg_json.get('aibotid', '')
|
||||
except Exception:
|
||||
await self.logger.error(traceback.format_exc())
|
||||
|
||||
return message_data
|
||||
return await self._encrypt_and_reply({}, nonce)
|
||||
|
||||
async def get_message(self, msg_json):
|
||||
return await parse_wecom_bot_message(msg_json, self.EnCodingAESKey, self.logger)
|
||||
|
||||
async def _handle_message(self, event: wecombotevent.WecomBotEvent):
|
||||
"""
|
||||
@@ -711,40 +1137,20 @@ class WecomBotClient:
|
||||
|
||||
return decorator
|
||||
|
||||
def on_feedback(self):
|
||||
def decorator(func: Callable):
|
||||
if 'feedback' not in self._message_handlers:
|
||||
self._message_handlers['feedback'] = []
|
||||
self._message_handlers['feedback'].append(func)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
async def download_url_to_base64(self, download_url, encoding_aes_key):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(download_url)
|
||||
if response.status_code != 200:
|
||||
await self.logger.error(f'failed to get file: {response.text}')
|
||||
return None
|
||||
|
||||
encrypted_bytes = response.content
|
||||
|
||||
aes_key = base64.b64decode(encoding_aes_key + '=') # base64 补齐
|
||||
iv = aes_key[:16]
|
||||
|
||||
cipher = AES.new(aes_key, AES.MODE_CBC, iv)
|
||||
decrypted = cipher.decrypt(encrypted_bytes)
|
||||
|
||||
pad_len = decrypted[-1]
|
||||
decrypted = decrypted[:-pad_len]
|
||||
|
||||
if decrypted.startswith(b'\xff\xd8'): # JPEG
|
||||
mime_type = 'image/jpeg'
|
||||
elif decrypted.startswith(b'\x89PNG'): # PNG
|
||||
mime_type = 'image/png'
|
||||
elif decrypted.startswith((b'GIF87a', b'GIF89a')): # GIF
|
||||
mime_type = 'image/gif'
|
||||
elif decrypted.startswith(b'BM'): # BMP
|
||||
mime_type = 'image/bmp'
|
||||
elif decrypted.startswith(b'II*\x00') or decrypted.startswith(b'MM\x00*'): # TIFF
|
||||
mime_type = 'image/tiff'
|
||||
else:
|
||||
mime_type = 'application/octet-stream'
|
||||
|
||||
# 转 base64
|
||||
base64_str = base64.b64encode(decrypted).decode('utf-8')
|
||||
return f'data:{mime_type};base64,{base64_str}'
|
||||
data, _filename = await download_encrypted_file(download_url, encoding_aes_key, self.logger)
|
||||
if data:
|
||||
return _bytes_to_data_uri(data)
|
||||
return None
|
||||
|
||||
async def run_task(self, host: str, port: int, *args, **kwargs):
|
||||
"""
|
||||
|
||||
@@ -133,3 +133,24 @@ class WecomBotEvent(dict):
|
||||
AI Bot ID
|
||||
"""
|
||||
return self.get('aibotid', '')
|
||||
|
||||
@property
|
||||
def feedback_id(self) -> str:
|
||||
"""
|
||||
反馈 ID,用于关联用户点赞/点踩反馈
|
||||
"""
|
||||
return self.get('feedback_id', '')
|
||||
|
||||
@property
|
||||
def stream_id(self) -> str:
|
||||
"""
|
||||
流式消息 ID
|
||||
"""
|
||||
return self.get('stream_id', '')
|
||||
|
||||
@property
|
||||
def quote(self):
|
||||
"""
|
||||
引用消息信息(群聊中用户引用其他消息时返回)
|
||||
"""
|
||||
return self.get('quote', {})
|
||||
|
||||
683
src/langbot/libs/wecom_ai_bot_api/ws_client.py
Normal file
683
src/langbot/libs/wecom_ai_bot_api/ws_client.py
Normal file
@@ -0,0 +1,683 @@
|
||||
"""WeChat Work AI Bot WebSocket long connection client.
|
||||
|
||||
Implements the WebSocket protocol for receiving messages and sending replies
|
||||
via a persistent connection to wss://openws.work.weixin.qq.com, as an
|
||||
alternative to the HTTP callback (webhook) mode.
|
||||
|
||||
Protocol reference: https://developer.work.weixin.qq.com/document/path/101463
|
||||
Official Node.js SDK: https://github.com/WecomTeam/aibot-node-sdk
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import secrets
|
||||
import time
|
||||
import traceback
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
from langbot.libs.wecom_ai_bot_api import wecombotevent
|
||||
from langbot.libs.wecom_ai_bot_api.api import parse_wecom_bot_message, StreamSession
|
||||
from langbot.pkg.platform.logger import EventLogger
|
||||
|
||||
DEFAULT_WS_URL = 'wss://openws.work.weixin.qq.com'
|
||||
|
||||
# WebSocket frame command constants
|
||||
CMD_SUBSCRIBE = 'aibot_subscribe'
|
||||
CMD_HEARTBEAT = 'ping'
|
||||
CMD_MSG_CALLBACK = 'aibot_msg_callback'
|
||||
CMD_EVENT_CALLBACK = 'aibot_event_callback'
|
||||
CMD_RESPOND_MSG = 'aibot_respond_msg'
|
||||
CMD_RESPOND_WELCOME = 'aibot_respond_welcome_msg'
|
||||
CMD_RESPOND_UPDATE = 'aibot_respond_update_msg'
|
||||
CMD_SEND_MSG = 'aibot_send_msg'
|
||||
|
||||
|
||||
def _generate_req_id(prefix: str) -> str:
|
||||
"""Generate a unique request ID in the format: {prefix}_{timestamp}_{random}."""
|
||||
ts = int(time.time() * 1000)
|
||||
rand = secrets.token_hex(4)
|
||||
return f'{prefix}_{ts}_{rand}'
|
||||
|
||||
|
||||
class WecomBotWsClient:
|
||||
"""WeChat Work AI Bot WebSocket long connection client.
|
||||
|
||||
Provides message receiving, streaming reply, proactive message sending,
|
||||
and event callback handling over a persistent WebSocket connection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bot_id: str,
|
||||
secret: str,
|
||||
logger: EventLogger,
|
||||
encoding_aes_key: str = '',
|
||||
ws_url: str = DEFAULT_WS_URL,
|
||||
heartbeat_interval: float = 30.0,
|
||||
max_reconnect_attempts: int = -1,
|
||||
reconnect_base_delay: float = 1.0,
|
||||
reconnect_max_delay: float = 30.0,
|
||||
):
|
||||
self.bot_id = bot_id
|
||||
self.secret = secret
|
||||
self.logger = logger
|
||||
self.encoding_aes_key = encoding_aes_key
|
||||
self.ws_url = ws_url
|
||||
self.heartbeat_interval = heartbeat_interval
|
||||
self.max_reconnect_attempts = max_reconnect_attempts
|
||||
self.reconnect_base_delay = reconnect_base_delay
|
||||
self.reconnect_max_delay = reconnect_max_delay
|
||||
|
||||
self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self._running = False
|
||||
self._heartbeat_task: Optional[asyncio.Task] = None
|
||||
self._missed_pong_count = 0
|
||||
self._max_missed_pong = 2
|
||||
self._reconnect_attempts = 0
|
||||
|
||||
# Message handler registry (same pattern as WecomBotClient)
|
||||
self._message_handlers: dict[str, list[Callable]] = {}
|
||||
# Message deduplication
|
||||
self._msg_id_map: dict[str, int] = {}
|
||||
|
||||
# Pending ACK futures: req_id -> Future[dict]
|
||||
self._pending_acks: dict[str, asyncio.Future] = {}
|
||||
# Per-req_id serial reply queues
|
||||
self._reply_queues: dict[str, asyncio.Queue] = {}
|
||||
self._reply_workers: dict[str, asyncio.Task] = {}
|
||||
self._reply_ack_timeout = 5.0
|
||||
|
||||
# Stream ID tracking for WebSocket mode
|
||||
self._stream_ids: dict[str, str] = {} # msg_id -> req_id|stream_id
|
||||
# Dedup: skip sending when content hasn't changed
|
||||
self._stream_last_content: dict[str, str] = {} # msg_id -> last content sent
|
||||
# Stream session info for feedback tracking
|
||||
self._stream_sessions: dict[str, dict] = {} # msg_id -> session info
|
||||
# Feedback tracking: feedback_id -> session info
|
||||
self._feedback_sessions: dict[str, dict] = {} # feedback_id -> {msg_id, user_id, chat_id, stream_id, req_id}
|
||||
# msg_id -> feedback_id (for associating feedback with message)
|
||||
self._msg_feedback_ids: dict[str, str] = {} # msg_id -> feedback_id
|
||||
|
||||
# ── Public API ──────────────────────────────────────────────────
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to WebSocket server with automatic reconnection.
|
||||
|
||||
This method blocks until disconnect() is called or max reconnect
|
||||
attempts are exhausted.
|
||||
"""
|
||||
self._running = True
|
||||
self._reconnect_attempts = 0
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
await self._connect_once()
|
||||
except Exception:
|
||||
if not self._running:
|
||||
break
|
||||
await self.logger.error(f'WebSocket connection error: {traceback.format_exc()}')
|
||||
|
||||
if not self._running:
|
||||
break
|
||||
|
||||
# Reconnect with exponential backoff
|
||||
if self.max_reconnect_attempts != -1 and self._reconnect_attempts >= self.max_reconnect_attempts:
|
||||
await self.logger.error(f'Max reconnect attempts reached ({self.max_reconnect_attempts}), giving up')
|
||||
break
|
||||
|
||||
self._reconnect_attempts += 1
|
||||
delay = min(
|
||||
self.reconnect_base_delay * (2 ** (self._reconnect_attempts - 1)),
|
||||
self.reconnect_max_delay,
|
||||
)
|
||||
await self.logger.info(f'Reconnecting in {delay:.1f}s (attempt {self._reconnect_attempts})...')
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
async def disconnect(self):
|
||||
"""Gracefully disconnect from the WebSocket server."""
|
||||
self._running = False
|
||||
if self._heartbeat_task and not self._heartbeat_task.done():
|
||||
self._heartbeat_task.cancel()
|
||||
for task in self._reply_workers.values():
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
if self._ws and not self._ws.closed:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
def on_message(self, msg_type: str) -> Callable:
|
||||
"""Decorator to register a message handler.
|
||||
|
||||
Same interface as WecomBotClient.on_message for compatibility.
|
||||
|
||||
Args:
|
||||
msg_type: 'single', 'group', or specific message type.
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[[wecombotevent.WecomBotEvent], Any]):
|
||||
if msg_type not in self._message_handlers:
|
||||
self._message_handlers[msg_type] = []
|
||||
self._message_handlers[msg_type].append(func)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def on_feedback(self) -> Callable:
|
||||
"""Decorator to register a feedback event handler.
|
||||
|
||||
Same interface as WecomBotClient.on_feedback for compatibility.
|
||||
"""
|
||||
|
||||
def decorator(func: Callable):
|
||||
if 'feedback' not in self._message_handlers:
|
||||
self._message_handlers['feedback'] = []
|
||||
self._message_handlers['feedback'].append(func)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
async def reply_stream(
|
||||
self,
|
||||
req_id: str,
|
||||
stream_id: str,
|
||||
content: str,
|
||||
finish: bool = False,
|
||||
feedback_id: str = '',
|
||||
) -> Optional[dict]:
|
||||
"""Send a streaming reply frame.
|
||||
|
||||
Args:
|
||||
req_id: The req_id from the original message frame (must be passed through).
|
||||
stream_id: The stream ID for this streaming session.
|
||||
content: The content to send (supports Markdown).
|
||||
finish: Whether this is the final chunk.
|
||||
feedback_id: Optional feedback ID for receiving user feedback (like/dislike).
|
||||
|
||||
Returns:
|
||||
The ACK frame dict, or None on failure.
|
||||
"""
|
||||
stream_payload = {
|
||||
'id': stream_id,
|
||||
'finish': finish,
|
||||
'content': content,
|
||||
}
|
||||
if feedback_id:
|
||||
stream_payload['feedback'] = {'id': feedback_id}
|
||||
|
||||
body = {
|
||||
'msgtype': 'stream',
|
||||
'stream': stream_payload,
|
||||
}
|
||||
return await self._send_reply(req_id, body)
|
||||
|
||||
async def reply_text(self, req_id: str, content: str) -> Optional[dict]:
|
||||
"""Send a non-streaming text reply.
|
||||
|
||||
Args:
|
||||
req_id: The req_id from the original message frame.
|
||||
content: The text content to reply.
|
||||
|
||||
Returns:
|
||||
The ACK frame dict, or None on failure.
|
||||
"""
|
||||
body = {
|
||||
'msgtype': 'markdown',
|
||||
'markdown': {
|
||||
'content': content,
|
||||
},
|
||||
}
|
||||
return await self._send_reply(req_id, body)
|
||||
|
||||
async def send_message(self, chat_id: str, content: str, msgtype: str = 'markdown') -> Optional[dict]:
|
||||
"""Proactively send a message to a specified chat.
|
||||
|
||||
Args:
|
||||
chat_id: The chat ID (userid for single chat, chatid for group chat).
|
||||
content: The message content.
|
||||
msgtype: Message type, 'markdown' by default.
|
||||
|
||||
Returns:
|
||||
The ACK frame dict, or None on failure.
|
||||
"""
|
||||
req_id = _generate_req_id(CMD_SEND_MSG)
|
||||
body: dict[str, Any] = {
|
||||
'chatid': chat_id,
|
||||
'msgtype': msgtype,
|
||||
}
|
||||
if msgtype == 'markdown':
|
||||
body['markdown'] = {'content': content}
|
||||
elif msgtype == 'text':
|
||||
body['text'] = {'content': content}
|
||||
return await self._send_reply(req_id, body, cmd=CMD_SEND_MSG)
|
||||
|
||||
async def push_stream_chunk(self, msg_id: str, content: str, is_final: bool = False) -> bool:
|
||||
"""Push a streaming chunk for a given message ID.
|
||||
|
||||
Compatible interface with WecomBotClient.push_stream_chunk.
|
||||
|
||||
Args:
|
||||
msg_id: The original message ID.
|
||||
content: The cumulative content from the pipeline.
|
||||
is_final: Whether this is the final chunk.
|
||||
|
||||
Returns:
|
||||
True if the stream session exists and chunk was sent.
|
||||
"""
|
||||
key = self._stream_ids.get(msg_id)
|
||||
if not key:
|
||||
return False
|
||||
req_id, stream_id = key.split('|', 1)
|
||||
try:
|
||||
# Skip sending if content hasn't changed (e.g. during tool call argument streaming)
|
||||
if not is_final and content == self._stream_last_content.get(msg_id):
|
||||
return True
|
||||
|
||||
# Generate feedback_id for final chunk
|
||||
feedback_id = ''
|
||||
if is_final:
|
||||
feedback_id = _generate_req_id('feedback')
|
||||
self._msg_feedback_ids[msg_id] = feedback_id
|
||||
# Store session info for feedback tracking
|
||||
session_info = self._stream_sessions.get(msg_id)
|
||||
if session_info:
|
||||
self._feedback_sessions[feedback_id] = session_info
|
||||
|
||||
await self.reply_stream(req_id, stream_id, content, finish=is_final, feedback_id=feedback_id)
|
||||
self._stream_last_content[msg_id] = content
|
||||
if is_final:
|
||||
self._stream_ids.pop(msg_id, None)
|
||||
self._stream_last_content.pop(msg_id, None)
|
||||
self._stream_sessions.pop(msg_id, None)
|
||||
return True
|
||||
except Exception:
|
||||
await self.logger.error(f'Failed to push stream chunk: {traceback.format_exc()}')
|
||||
return False
|
||||
|
||||
async def set_message(self, msg_id: str, content: str):
|
||||
"""Fallback: send content as a final stream chunk or direct reply.
|
||||
|
||||
Compatible interface with WecomBotClient.set_message.
|
||||
"""
|
||||
handled = await self.push_stream_chunk(msg_id, content, is_final=True)
|
||||
if not handled:
|
||||
await self.logger.warning(f'No active stream for msg_id={msg_id}, message dropped')
|
||||
|
||||
# ── Connection lifecycle ────────────────────────────────────────
|
||||
|
||||
async def _connect_once(self):
|
||||
"""Establish a single WebSocket connection, authenticate, and listen."""
|
||||
await self.logger.info(f'Connecting to {self.ws_url}...')
|
||||
|
||||
self._session = aiohttp.ClientSession()
|
||||
try:
|
||||
self._ws = await self._session.ws_connect(self.ws_url)
|
||||
self._missed_pong_count = 0
|
||||
self._reconnect_attempts = 0
|
||||
await self.logger.info('WebSocket connected, sending auth...')
|
||||
|
||||
await self._send_auth()
|
||||
|
||||
# Wait for auth response
|
||||
auth_ok = await self._wait_for_auth()
|
||||
if not auth_ok:
|
||||
await self.logger.error('Authentication failed')
|
||||
return
|
||||
|
||||
await self.logger.info('Authenticated successfully')
|
||||
|
||||
# Start heartbeat
|
||||
self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
|
||||
|
||||
try:
|
||||
await self._listen_loop()
|
||||
finally:
|
||||
if self._heartbeat_task and not self._heartbeat_task.done():
|
||||
self._heartbeat_task.cancel()
|
||||
self._clear_pending_acks('Connection closed')
|
||||
finally:
|
||||
if self._ws and not self._ws.closed:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
async def _send_auth(self):
|
||||
"""Send the authentication frame."""
|
||||
frame = {
|
||||
'cmd': CMD_SUBSCRIBE,
|
||||
'headers': {'req_id': _generate_req_id(CMD_SUBSCRIBE)},
|
||||
'body': {
|
||||
'bot_id': self.bot_id,
|
||||
'secret': self.secret,
|
||||
},
|
||||
}
|
||||
await self._send_frame(frame)
|
||||
|
||||
async def _wait_for_auth(self) -> bool:
|
||||
"""Wait for and validate the authentication response."""
|
||||
try:
|
||||
msg = await asyncio.wait_for(self._ws.receive(), timeout=10.0)
|
||||
if msg.type in (aiohttp.WSMsgType.TEXT,):
|
||||
frame = json.loads(msg.data)
|
||||
req_id = frame.get('headers', {}).get('req_id', '')
|
||||
if req_id.startswith(CMD_SUBSCRIBE) and frame.get('errcode') == 0:
|
||||
return True
|
||||
await self.logger.error(f'Auth response: errcode={frame.get("errcode")}, errmsg={frame.get("errmsg")}')
|
||||
return False
|
||||
elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
|
||||
await self.logger.error(f'WebSocket closed during auth: {msg.type}')
|
||||
return False
|
||||
await self.logger.error(f'Unexpected message type during auth: {msg.type}')
|
||||
return False
|
||||
except asyncio.TimeoutError:
|
||||
await self.logger.error('Auth response timeout')
|
||||
return False
|
||||
|
||||
async def _heartbeat_loop(self):
|
||||
"""Periodically send heartbeat pings."""
|
||||
try:
|
||||
while self._running and self._ws and not self._ws.closed:
|
||||
await asyncio.sleep(self.heartbeat_interval)
|
||||
if not self._running or not self._ws or self._ws.closed:
|
||||
break
|
||||
|
||||
if self._missed_pong_count >= self._max_missed_pong:
|
||||
await self.logger.warning(
|
||||
f'No heartbeat ack for {self._missed_pong_count} consecutive pings, connection considered dead'
|
||||
)
|
||||
await self._ws.close()
|
||||
break
|
||||
|
||||
self._missed_pong_count += 1
|
||||
frame = {
|
||||
'cmd': CMD_HEARTBEAT,
|
||||
'headers': {'req_id': _generate_req_id(CMD_HEARTBEAT)},
|
||||
}
|
||||
try:
|
||||
await self._send_frame(frame)
|
||||
except Exception:
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def _listen_loop(self):
|
||||
"""Listen for incoming WebSocket frames and dispatch them."""
|
||||
async for msg in self._ws:
|
||||
if not self._running:
|
||||
break
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
try:
|
||||
frame = json.loads(msg.data)
|
||||
await self._handle_frame(frame)
|
||||
except json.JSONDecodeError:
|
||||
await self.logger.error(f'Failed to parse WebSocket message: {str(msg.data)[:200]}')
|
||||
except Exception:
|
||||
await self.logger.error(f'Error handling frame: {traceback.format_exc()}')
|
||||
elif msg.type == aiohttp.WSMsgType.BINARY:
|
||||
try:
|
||||
frame = json.loads(msg.data)
|
||||
await self._handle_frame(frame)
|
||||
except Exception:
|
||||
await self.logger.error(f'Error handling binary frame: {traceback.format_exc()}')
|
||||
elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
|
||||
await self.logger.warning(f'WebSocket connection closed: {msg.type}')
|
||||
break
|
||||
|
||||
# ── Frame handling ──────────────────────────────────────────────
|
||||
|
||||
async def _handle_frame(self, frame: dict):
|
||||
"""Route an incoming frame to the appropriate handler."""
|
||||
cmd = frame.get('cmd', '')
|
||||
|
||||
# Message push
|
||||
if cmd == CMD_MSG_CALLBACK:
|
||||
asyncio.create_task(self._handle_message_callback(frame))
|
||||
return
|
||||
|
||||
# Event push
|
||||
if cmd == CMD_EVENT_CALLBACK:
|
||||
asyncio.create_task(self._handle_event_callback(frame))
|
||||
return
|
||||
|
||||
# No cmd → response/ACK frame, dispatch by req_id prefix
|
||||
req_id = frame.get('headers', {}).get('req_id', '')
|
||||
|
||||
# Check pending ACKs first
|
||||
if req_id in self._pending_acks:
|
||||
future = self._pending_acks.pop(req_id)
|
||||
if not future.done():
|
||||
future.set_result(frame)
|
||||
return
|
||||
|
||||
# Heartbeat response
|
||||
if req_id.startswith(CMD_HEARTBEAT):
|
||||
if frame.get('errcode') == 0:
|
||||
self._missed_pong_count = 0
|
||||
return
|
||||
|
||||
# Unknown frame
|
||||
await self.logger.warning(f'Unknown frame: {json.dumps(frame, ensure_ascii=False)[:200]}')
|
||||
|
||||
async def _handle_message_callback(self, frame: dict):
|
||||
"""Handle an incoming message callback frame."""
|
||||
try:
|
||||
body = frame.get('body', {})
|
||||
req_id = frame.get('headers', {}).get('req_id', '')
|
||||
|
||||
# Parse message using shared logic
|
||||
message_data = await parse_wecom_bot_message(body, self.encoding_aes_key, self.logger)
|
||||
if not message_data:
|
||||
return
|
||||
|
||||
# Generate stream_id for this message and store the mapping
|
||||
stream_id = _generate_req_id('stream')
|
||||
msg_id = message_data.get('msgid', '')
|
||||
if msg_id:
|
||||
self._stream_ids[msg_id] = f'{req_id}|{stream_id}'
|
||||
# Store session info for feedback tracking
|
||||
self._stream_sessions[msg_id] = {
|
||||
'req_id': req_id,
|
||||
'stream_id': stream_id,
|
||||
'msg_id': msg_id,
|
||||
'user_id': message_data.get('userid', ''),
|
||||
'chat_id': message_data.get('chatid', ''),
|
||||
'chat_type': message_data.get('type', 'single'),
|
||||
}
|
||||
message_data['stream_id'] = stream_id
|
||||
message_data['req_id'] = req_id
|
||||
|
||||
event = wecombotevent.WecomBotEvent(message_data)
|
||||
await self._dispatch_event(event)
|
||||
except Exception:
|
||||
await self.logger.error(f'Error in message callback: {traceback.format_exc()}')
|
||||
|
||||
async def _handle_event_callback(self, frame: dict):
|
||||
"""Handle an incoming event callback frame (enter_chat, template_card_event, feedback_event, disconnected_event)."""
|
||||
try:
|
||||
body = frame.get('body', {})
|
||||
req_id = frame.get('headers', {}).get('req_id', '')
|
||||
|
||||
event_info = body.get('event', {})
|
||||
event_type = event_info.get('eventtype', '')
|
||||
|
||||
message_data = {
|
||||
'msgtype': 'event',
|
||||
'type': body.get('chattype', 'single'),
|
||||
'event': event_info,
|
||||
'eventtype': event_type,
|
||||
'msgid': body.get('msgid', ''),
|
||||
'aibotid': body.get('aibotid', ''),
|
||||
'req_id': req_id,
|
||||
}
|
||||
|
||||
from_info = body.get('from', {})
|
||||
message_data['userid'] = from_info.get('userid', '')
|
||||
message_data['username'] = from_info.get('alias', '') or from_info.get('userid', '')
|
||||
|
||||
if body.get('chatid'):
|
||||
message_data['chatid'] = body.get('chatid', '')
|
||||
|
||||
if event_type == 'feedback_event':
|
||||
feedback_event = event_info.get('feedback_event', {})
|
||||
feedback_id = feedback_event.get('id', '')
|
||||
feedback_type = feedback_event.get('type', 0)
|
||||
feedback_content = feedback_event.get('content', '')
|
||||
inaccurate_reasons = feedback_event.get('inaccurate_reason_list', [])
|
||||
|
||||
await self.logger.info(
|
||||
f'收到用户反馈事件: feedback_id={feedback_id}, type={feedback_type}, '
|
||||
f'content={feedback_content}, reasons={inaccurate_reasons}'
|
||||
)
|
||||
|
||||
# Look up session by feedback_id
|
||||
session_info = self._feedback_sessions.get(feedback_id)
|
||||
session = None
|
||||
if session_info:
|
||||
session = StreamSession(
|
||||
stream_id=session_info.get('stream_id', ''),
|
||||
msg_id=session_info.get('msg_id', ''),
|
||||
chat_id=session_info.get('chat_id') or None,
|
||||
user_id=session_info.get('user_id') or None,
|
||||
feedback_id=feedback_id,
|
||||
)
|
||||
await self.logger.info(
|
||||
f'反馈关联到会话: stream_id={session.stream_id}, msg_id={session.msg_id}, user_id={session.user_id}'
|
||||
)
|
||||
else:
|
||||
await self.logger.warning(f'未找到 feedback_id={feedback_id} 对应的会话')
|
||||
|
||||
for handler in self._message_handlers.get('feedback', []):
|
||||
try:
|
||||
await handler(
|
||||
feedback_id=feedback_id,
|
||||
feedback_type=feedback_type,
|
||||
feedback_content=feedback_content,
|
||||
inaccurate_reasons=inaccurate_reasons,
|
||||
session=session,
|
||||
)
|
||||
except Exception:
|
||||
await self.logger.error(f'Error in feedback handler: {traceback.format_exc()}')
|
||||
return
|
||||
|
||||
event = wecombotevent.WecomBotEvent(message_data)
|
||||
|
||||
if event_type in self._message_handlers:
|
||||
for handler in self._message_handlers[event_type]:
|
||||
await handler(event)
|
||||
|
||||
if 'event' in self._message_handlers:
|
||||
for handler in self._message_handlers['event']:
|
||||
await handler(event)
|
||||
|
||||
except Exception:
|
||||
await self.logger.error(f'Error in event callback: {traceback.format_exc()}')
|
||||
|
||||
async def _dispatch_event(self, event: wecombotevent.WecomBotEvent):
|
||||
"""Dispatch a message event to registered handlers with deduplication."""
|
||||
try:
|
||||
message_id = event.message_id
|
||||
if message_id in self._msg_id_map:
|
||||
self._msg_id_map[message_id] += 1
|
||||
return
|
||||
self._msg_id_map[message_id] = 1
|
||||
|
||||
msg_type = event.type
|
||||
if msg_type in self._message_handlers:
|
||||
for handler in self._message_handlers[msg_type]:
|
||||
await handler(event)
|
||||
except Exception:
|
||||
await self.logger.error(f'Error dispatching event: {traceback.format_exc()}')
|
||||
|
||||
# ── Reply sending with serial queue ─────────────────────────────
|
||||
|
||||
async def _send_reply(
|
||||
self,
|
||||
req_id: str,
|
||||
body: dict,
|
||||
cmd: str = CMD_RESPOND_MSG,
|
||||
) -> Optional[dict]:
|
||||
"""Send a reply frame and wait for ACK.
|
||||
|
||||
Replies with the same req_id are serialized to maintain ordering.
|
||||
"""
|
||||
if not self._ws or self._ws.closed:
|
||||
return None
|
||||
|
||||
frame = {
|
||||
'cmd': cmd,
|
||||
'headers': {'req_id': req_id},
|
||||
'body': body,
|
||||
}
|
||||
|
||||
# Ensure serial delivery per req_id
|
||||
if req_id not in self._reply_queues:
|
||||
self._reply_queues[req_id] = asyncio.Queue()
|
||||
self._reply_workers[req_id] = asyncio.create_task(self._reply_queue_worker(req_id))
|
||||
|
||||
future: asyncio.Future = asyncio.get_event_loop().create_future()
|
||||
await self._reply_queues[req_id].put((frame, future))
|
||||
return await future
|
||||
|
||||
async def _reply_queue_worker(self, req_id: str):
|
||||
"""Process reply queue items serially for a given req_id."""
|
||||
queue = self._reply_queues[req_id]
|
||||
try:
|
||||
while self._running:
|
||||
try:
|
||||
frame, future = await asyncio.wait_for(queue.get(), timeout=60.0)
|
||||
except asyncio.TimeoutError:
|
||||
# Queue idle, clean up worker
|
||||
break
|
||||
|
||||
try:
|
||||
ack = await self._send_and_wait_ack(frame)
|
||||
if not future.done():
|
||||
future.set_result(ack)
|
||||
except Exception as e:
|
||||
if not future.done():
|
||||
future.set_exception(e)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
self._reply_queues.pop(req_id, None)
|
||||
self._reply_workers.pop(req_id, None)
|
||||
|
||||
async def _send_and_wait_ack(self, frame: dict) -> Optional[dict]:
|
||||
"""Send a frame and wait for the corresponding ACK."""
|
||||
req_id = frame['headers']['req_id']
|
||||
ack_future: asyncio.Future = asyncio.get_event_loop().create_future()
|
||||
self._pending_acks[req_id] = ack_future
|
||||
|
||||
try:
|
||||
await self._send_frame(frame)
|
||||
result = await asyncio.wait_for(ack_future, timeout=self._reply_ack_timeout)
|
||||
if result.get('errcode', 0) != 0:
|
||||
await self.logger.warning(
|
||||
f'Reply ACK error: errcode={result.get("errcode")}, errmsg={result.get("errmsg")}'
|
||||
)
|
||||
return result
|
||||
except asyncio.TimeoutError:
|
||||
self._pending_acks.pop(req_id, None)
|
||||
await self.logger.warning(f'Reply ACK timeout ({self._reply_ack_timeout}s) for req_id={req_id}')
|
||||
return None
|
||||
|
||||
async def _send_frame(self, frame: dict):
|
||||
"""Send a JSON frame over the WebSocket connection."""
|
||||
if self._ws and not self._ws.closed:
|
||||
await self._ws.send_str(json.dumps(frame, ensure_ascii=False))
|
||||
|
||||
def _clear_pending_acks(self, reason: str):
|
||||
"""Reject all pending ACK futures on disconnection."""
|
||||
for req_id, future in self._pending_acks.items():
|
||||
if not future.done():
|
||||
future.set_exception(ConnectionError(reason))
|
||||
self._pending_acks.clear()
|
||||
@@ -4,6 +4,7 @@ import base64
|
||||
import binascii
|
||||
import httpx
|
||||
import traceback
|
||||
from urllib.parse import quote
|
||||
from quart import Quart
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Callable, Dict, Any
|
||||
@@ -67,6 +68,31 @@ class WecomClient:
|
||||
await self.logger.error(f'获取accesstoken失败:{response.json()}')
|
||||
raise Exception(f'未获取access token: {data}')
|
||||
|
||||
async def get_user_info(self, userid: str) -> dict:
|
||||
"""
|
||||
Get user information by user ID using the application secret.
|
||||
|
||||
Args:
|
||||
userid: The user ID to look up.
|
||||
|
||||
Returns:
|
||||
dict: User information including 'name' field.
|
||||
"""
|
||||
if not await self.check_access_token():
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
|
||||
url = self.base_url + '/user/get?access_token=' + self.access_token + '&userid=' + quote(userid)
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url)
|
||||
data = response.json()
|
||||
if data.get('errcode') == 40014 or data.get('errcode') == 42001:
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
return await self.get_user_info(userid)
|
||||
if data.get('errcode', 0) != 0:
|
||||
await self.logger.error(f'获取用户信息失败:{data}')
|
||||
return {}
|
||||
return data
|
||||
|
||||
async def get_users(self):
|
||||
if not self.check_access_token_for_contacts():
|
||||
self.access_token_for_contacts = await self.get_access_token(self.secret_for_contacts)
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import Callable
|
||||
from .wecomcsevent import WecomCSEvent
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import aiofiles
|
||||
import time
|
||||
|
||||
|
||||
class WecomCSClient:
|
||||
@@ -34,6 +35,10 @@ class WecomCSClient:
|
||||
self.unified_mode = unified_mode
|
||||
self.app = Quart(__name__)
|
||||
|
||||
# Customer info cache: {external_userid: (info_dict, timestamp)}
|
||||
self._customer_cache: dict[str, tuple[dict, float]] = {}
|
||||
self._cache_ttl = 60 # Cache TTL in seconds (1 minute)
|
||||
|
||||
# 只有在非统一模式下才注册独立路由
|
||||
if not self.unified_mode:
|
||||
self.app.add_url_rule(
|
||||
@@ -378,3 +383,53 @@ class WecomCSClient:
|
||||
async def get_media_id(self, image: platform_message.Image):
|
||||
media_id = await self.upload_to_work(image=image)
|
||||
return media_id
|
||||
|
||||
async def get_customer_info(self, external_userid: str) -> dict | None:
|
||||
"""
|
||||
Get customer information by external_userid with caching.
|
||||
|
||||
Uses a 1-minute cache to avoid repeated API calls for the same user.
|
||||
|
||||
Args:
|
||||
external_userid: The external user ID of the customer.
|
||||
|
||||
Returns:
|
||||
Customer info dict with 'nickname', 'avatar', etc., or None if not found.
|
||||
"""
|
||||
# Check cache first
|
||||
current_time = time.time()
|
||||
if external_userid in self._customer_cache:
|
||||
cached_info, cached_time = self._customer_cache[external_userid]
|
||||
if current_time - cached_time < self._cache_ttl:
|
||||
return cached_info
|
||||
|
||||
# Cache miss or expired, fetch from API
|
||||
if not await self.check_access_token():
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
|
||||
url = f'{self.base_url}/kf/customer/batchget?access_token={self.access_token}'
|
||||
|
||||
payload = {
|
||||
'external_userid_list': [external_userid],
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(url, json=payload)
|
||||
data = response.json()
|
||||
|
||||
if data.get('errcode') in [40014, 42001]:
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
return await self.get_customer_info(external_userid)
|
||||
|
||||
if data.get('errcode', 0) != 0:
|
||||
if self.logger:
|
||||
await self.logger.warning(f'Failed to get customer info: {data}')
|
||||
return None
|
||||
|
||||
customer_list = data.get('customer_list', [])
|
||||
if customer_list:
|
||||
customer_info = customer_list[0]
|
||||
# Store in cache
|
||||
self._customer_cache[external_userid] = (customer_info, current_time)
|
||||
return customer_info
|
||||
return None
|
||||
|
||||
@@ -13,9 +13,9 @@ from .. import group
|
||||
@group.group_class('files', '/api/v1/files')
|
||||
class FilesRouterGroup(group.RouterGroup):
|
||||
async def initialize(self) -> None:
|
||||
@self.route('/image/<image_key>', methods=['GET'], auth_type=group.AuthType.NONE)
|
||||
@self.route('/image/<path:image_key>', methods=['GET'], auth_type=group.AuthType.NONE)
|
||||
async def _(image_key: str) -> quart.Response:
|
||||
if '/' in image_key or '\\' in image_key:
|
||||
if '..' in image_key or '\\' in image_key:
|
||||
return quart.Response(status=404)
|
||||
|
||||
if not await self.ap.storage_mgr.storage_provider.exists(image_key):
|
||||
|
||||
@@ -456,6 +456,31 @@ class MonitoringRouterGroup(group.RouterGroup):
|
||||
'platform',
|
||||
'user_id',
|
||||
]
|
||||
elif export_type == 'feedback':
|
||||
data = await self.ap.monitoring_service.export_feedback(
|
||||
bot_ids=bot_ids if bot_ids else None,
|
||||
pipeline_ids=pipeline_ids if pipeline_ids else None,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
limit=limit,
|
||||
)
|
||||
headers = [
|
||||
'id',
|
||||
'timestamp',
|
||||
'feedback_id',
|
||||
'feedback_type',
|
||||
'feedback_content',
|
||||
'inaccurate_reasons',
|
||||
'bot_id',
|
||||
'bot_name',
|
||||
'pipeline_id',
|
||||
'pipeline_name',
|
||||
'session_id',
|
||||
'message_id',
|
||||
'stream_id',
|
||||
'user_id',
|
||||
'platform',
|
||||
]
|
||||
else:
|
||||
return self.error(message=f'Invalid export type: {export_type}', code=400)
|
||||
|
||||
@@ -486,3 +511,63 @@ class MonitoringRouterGroup(group.RouterGroup):
|
||||
)
|
||||
|
||||
return response, 200
|
||||
|
||||
@self.route('/feedback/stats', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def get_feedback_stats() -> str:
|
||||
"""Get feedback statistics"""
|
||||
# Parse query parameters
|
||||
bot_ids = quart.request.args.getlist('botId')
|
||||
pipeline_ids = quart.request.args.getlist('pipelineId')
|
||||
start_time_str = quart.request.args.get('startTime')
|
||||
end_time_str = quart.request.args.get('endTime')
|
||||
|
||||
# Parse datetime
|
||||
start_time = parse_iso_datetime(start_time_str)
|
||||
end_time = parse_iso_datetime(end_time_str)
|
||||
|
||||
stats = await self.ap.monitoring_service.get_feedback_stats(
|
||||
bot_ids=bot_ids if bot_ids else None,
|
||||
pipeline_ids=pipeline_ids if pipeline_ids else None,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
return self.success(data=stats)
|
||||
|
||||
@self.route('/feedback', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def get_feedback() -> str:
|
||||
"""Get feedback list"""
|
||||
# Parse query parameters
|
||||
bot_ids = quart.request.args.getlist('botId')
|
||||
pipeline_ids = quart.request.args.getlist('pipelineId')
|
||||
feedback_type_str = quart.request.args.get('feedbackType')
|
||||
start_time_str = quart.request.args.get('startTime')
|
||||
end_time_str = quart.request.args.get('endTime')
|
||||
limit = int(quart.request.args.get('limit', 100))
|
||||
offset = int(quart.request.args.get('offset', 0))
|
||||
|
||||
# Parse datetime
|
||||
start_time = parse_iso_datetime(start_time_str)
|
||||
end_time = parse_iso_datetime(end_time_str)
|
||||
|
||||
# Parse feedback type
|
||||
feedback_type = int(feedback_type_str) if feedback_type_str else None
|
||||
|
||||
feedback_list, total = await self.ap.monitoring_service.get_feedback_list(
|
||||
bot_ids=bot_ids if bot_ids else None,
|
||||
pipeline_ids=pipeline_ids if pipeline_ids else None,
|
||||
feedback_type=feedback_type,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'feedback': feedback_list,
|
||||
'total': total,
|
||||
'limit': limit,
|
||||
'offset': offset,
|
||||
}
|
||||
)
|
||||
|
||||
384
src/langbot/pkg/api/http/controller/groups/pipelines/embed.py
Normal file
384
src/langbot/pkg/api/http/controller/groups/pipelines/embed.py
Normal file
@@ -0,0 +1,384 @@
|
||||
"""Embed widget routes - serve embeddable chat widget for external websites.
|
||||
|
||||
All user-facing URLs are keyed by **bot_uuid** (not pipeline_uuid) so that
|
||||
internal pipeline identifiers are never exposed to end-users. Each handler
|
||||
resolves the bot_uuid to the owning ``web_page_bot`` RuntimeBot and extracts
|
||||
the bound pipeline_uuid for internal routing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
import hmac
|
||||
import hashlib
|
||||
import time
|
||||
import re
|
||||
import httpx
|
||||
|
||||
import quart
|
||||
|
||||
from ... import group
|
||||
from ......utils import paths
|
||||
from ......platform.sources.websocket_manager import ws_connection_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cache the widget template content
|
||||
_widget_template_cache: str | None = None
|
||||
_logo_bytes_cache: bytes | None = None
|
||||
|
||||
|
||||
def _is_valid_uuid(s: str) -> bool:
|
||||
return bool(re.match(r'^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$', s))
|
||||
|
||||
|
||||
def _get_widget_template() -> str:
|
||||
"""Load and cache the widget JS template."""
|
||||
global _widget_template_cache
|
||||
if _widget_template_cache is None:
|
||||
template_path = paths.get_resource_path('templates/embed/widget.js')
|
||||
with open(template_path, 'r', encoding='utf-8') as f:
|
||||
_widget_template_cache = f.read()
|
||||
return _widget_template_cache
|
||||
|
||||
|
||||
def _get_logo_bytes() -> bytes:
|
||||
"""Load and cache the logo image."""
|
||||
global _logo_bytes_cache
|
||||
if _logo_bytes_cache is None:
|
||||
logo_path = paths.get_resource_path('templates/embed/logo.webp')
|
||||
with open(logo_path, 'rb') as f:
|
||||
_logo_bytes_cache = f.read()
|
||||
return _logo_bytes_cache
|
||||
|
||||
|
||||
@group.group_class('embed', '/api/v1/embed')
|
||||
class EmbedRouterGroup(group.RouterGroup):
|
||||
# -- helpers -------------------------------------------------------------
|
||||
|
||||
def _resolve_bot(self, bot_uuid: str):
|
||||
"""Resolve *bot_uuid* to ``(runtime_bot, pipeline_uuid)``.
|
||||
|
||||
Returns ``(None, None)`` when the bot does not exist, is not a
|
||||
``web_page_bot``, is disabled, or has no pipeline bound.
|
||||
"""
|
||||
for bot in self.ap.platform_mgr.bots:
|
||||
if (
|
||||
bot.bot_entity.uuid == bot_uuid
|
||||
and bot.bot_entity.adapter == 'web_page_bot'
|
||||
and bot.bot_entity.enable
|
||||
and bot.bot_entity.use_pipeline_uuid
|
||||
):
|
||||
return bot, bot.bot_entity.use_pipeline_uuid
|
||||
return None, None
|
||||
|
||||
def _get_bot_config(self, bot_uuid: str) -> dict:
|
||||
for bot in self.ap.platform_mgr.bots:
|
||||
if bot.bot_entity.uuid == bot_uuid and bot.bot_entity.adapter == 'web_page_bot':
|
||||
return bot.bot_entity.adapter_config
|
||||
return {}
|
||||
|
||||
async def _verify_session_token(self, request, bot_uuid: str) -> bool:
|
||||
config = self._get_bot_config(bot_uuid)
|
||||
secret = config.get('turnstile_secret_key', '')
|
||||
if not secret:
|
||||
return True
|
||||
auth_header = request.headers.get('Authorization', '')
|
||||
if not auth_header.startswith('Bearer '):
|
||||
return False
|
||||
token = auth_header[7:]
|
||||
try:
|
||||
ts_str, mac = token.split('.', 1)
|
||||
ts = float(ts_str)
|
||||
if time.time() - ts > 86400:
|
||||
return False
|
||||
expected_mac = hmac.new(secret.encode(), f'{ts_str}'.encode(), hashlib.sha256).hexdigest()
|
||||
return hmac.compare_digest(mac, expected_mac)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# -- routes --------------------------------------------------------------
|
||||
|
||||
async def initialize(self) -> None:
|
||||
@self.route('/<bot_uuid>/turnstile/verify', methods=['POST'], auth_type=group.AuthType.NONE)
|
||||
async def verify_turnstile(bot_uuid: str) -> str:
|
||||
if not _is_valid_uuid(bot_uuid):
|
||||
return self.http_status(400, -1, 'Invalid bot_uuid format')
|
||||
runtime_bot, pipeline_uuid = self._resolve_bot(bot_uuid)
|
||||
if runtime_bot is None:
|
||||
return self.http_status(404, -1, 'Bot not found or not available')
|
||||
try:
|
||||
data = await quart.request.get_json()
|
||||
token = data.get('token')
|
||||
if not token:
|
||||
return self.http_status(400, -1, 'Token is required')
|
||||
|
||||
config = self._get_bot_config(bot_uuid)
|
||||
secret = config.get('turnstile_secret_key', '')
|
||||
if not secret:
|
||||
ts = time.time()
|
||||
return self.success(data={'token': f'{ts}.dummy'})
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
'https://challenges.cloudflare.com/turnstile/v0/siteverify',
|
||||
data={'secret': secret, 'response': token},
|
||||
)
|
||||
result = resp.json()
|
||||
|
||||
if not result.get('success'):
|
||||
return self.http_status(403, -1, 'Turnstile verification failed')
|
||||
|
||||
ts = time.time()
|
||||
mac = hmac.new(secret.encode(), f'{ts}'.encode(), hashlib.sha256).hexdigest()
|
||||
session_token = f'{ts}.{mac}'
|
||||
|
||||
return self.success(data={'token': session_token})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'Turnstile verify failed: {e}', exc_info=True)
|
||||
return self.http_status(500, -1, 'Internal server error')
|
||||
|
||||
@self.route('/<bot_uuid>/widget.js', methods=['GET'], auth_type=group.AuthType.NONE)
|
||||
async def serve_widget(bot_uuid: str) -> quart.Response:
|
||||
"""Serve the embed widget JavaScript with injected configuration."""
|
||||
if not _is_valid_uuid(bot_uuid):
|
||||
return self.http_status(400, -1, 'Invalid bot_uuid format')
|
||||
runtime_bot, pipeline_uuid = self._resolve_bot(bot_uuid)
|
||||
if runtime_bot is None:
|
||||
return quart.Response(
|
||||
'// Bot not found or not available', status=404, content_type='application/javascript'
|
||||
)
|
||||
try:
|
||||
template = _get_widget_template()
|
||||
except FileNotFoundError:
|
||||
return quart.Response('// Widget template not found', status=404, content_type='application/javascript')
|
||||
|
||||
base_url = quart.request.host_url.rstrip('/')
|
||||
webhook_prefix = self.ap.instance_config.data.get('api', {}).get('webhook_prefix', '')
|
||||
if webhook_prefix:
|
||||
base_url = webhook_prefix.rstrip('/')
|
||||
|
||||
if not re.match(r'^https?://[a-zA-Z0-9._:/-]+$', base_url):
|
||||
base_url = quart.request.host_url.rstrip('/')
|
||||
|
||||
config = self._get_bot_config(bot_uuid)
|
||||
site_key = config.get('turnstile_site_key', '')
|
||||
locale = config.get('language', 'en_US') or 'en_US'
|
||||
bubble_icon = config.get('bubble_icon', 'logo') or 'logo'
|
||||
widget_js = template.replace('__LANGBOT_TURNSTILE_SITE_KEY__', site_key)
|
||||
widget_js = widget_js.replace('__LANGBOT_BOT_UUID__', bot_uuid)
|
||||
widget_js = widget_js.replace('__LANGBOT_BASE_URL__', base_url)
|
||||
widget_js = widget_js.replace('__LANGBOT_LOCALE__', locale)
|
||||
widget_js = widget_js.replace('__LANGBOT_BUBBLE_ICON__', bubble_icon)
|
||||
|
||||
response = quart.Response(widget_js, content_type='application/javascript; charset=utf-8')
|
||||
response.headers['Cache-Control'] = 'public, max-age=300'
|
||||
return response
|
||||
|
||||
@self.route('/logo', methods=['GET'], auth_type=group.AuthType.NONE)
|
||||
async def serve_logo() -> quart.Response:
|
||||
"""Serve the LangBot logo for the embed widget."""
|
||||
try:
|
||||
logo_data = _get_logo_bytes()
|
||||
except FileNotFoundError:
|
||||
return quart.Response('', status=404)
|
||||
|
||||
response = quart.Response(logo_data, content_type='image/webp')
|
||||
response.headers['Cache-Control'] = 'public, max-age=86400'
|
||||
return response
|
||||
|
||||
@self.route('/<bot_uuid>/messages/<session_type>', methods=['GET'], auth_type=group.AuthType.NONE)
|
||||
async def get_embed_messages(bot_uuid: str, session_type: str) -> str:
|
||||
if not _is_valid_uuid(bot_uuid):
|
||||
return self.http_status(400, -1, 'Invalid bot_uuid format')
|
||||
runtime_bot, pipeline_uuid = self._resolve_bot(bot_uuid)
|
||||
if runtime_bot is None:
|
||||
return self.http_status(404, -1, 'Bot not found or not available')
|
||||
if not await self._verify_session_token(quart.request, bot_uuid):
|
||||
return self.http_status(403, -1, 'Unauthorized or session expired')
|
||||
try:
|
||||
if session_type not in ['person', 'group']:
|
||||
return self.http_status(400, -1, 'session_type must be person or group')
|
||||
|
||||
websocket_adapter = self.ap.platform_mgr.websocket_proxy_bot.adapter
|
||||
if not websocket_adapter:
|
||||
return self.http_status(404, -1, 'WebSocket adapter not found')
|
||||
|
||||
messages = websocket_adapter.get_websocket_messages(pipeline_uuid, session_type)
|
||||
return self.success(data={'messages': messages})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'Failed to get embed messages: {e}', exc_info=True)
|
||||
return self.http_status(500, -1, 'Internal server error')
|
||||
|
||||
@self.route('/<bot_uuid>/reset/<session_type>', methods=['POST'], auth_type=group.AuthType.NONE)
|
||||
async def reset_embed_session(bot_uuid: str, session_type: str) -> str:
|
||||
if not _is_valid_uuid(bot_uuid):
|
||||
return self.http_status(400, -1, 'Invalid bot_uuid format')
|
||||
runtime_bot, pipeline_uuid = self._resolve_bot(bot_uuid)
|
||||
if runtime_bot is None:
|
||||
return self.http_status(404, -1, 'Bot not found or not available')
|
||||
if not await self._verify_session_token(quart.request, bot_uuid):
|
||||
return self.http_status(403, -1, 'Unauthorized or session expired')
|
||||
try:
|
||||
if session_type not in ['person', 'group']:
|
||||
return self.http_status(400, -1, 'session_type must be person or group')
|
||||
|
||||
websocket_adapter = self.ap.platform_mgr.websocket_proxy_bot.adapter
|
||||
if not websocket_adapter:
|
||||
return self.http_status(404, -1, 'WebSocket adapter not found')
|
||||
|
||||
websocket_adapter.reset_session(pipeline_uuid, session_type)
|
||||
return self.success(data={'message': 'Session reset successfully'})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'Failed to reset embed session: {e}', exc_info=True)
|
||||
return self.http_status(500, -1, 'Internal server error')
|
||||
|
||||
@self.route('/<bot_uuid>/feedback', methods=['POST'], auth_type=group.AuthType.NONE)
|
||||
async def submit_feedback(bot_uuid: str) -> str:
|
||||
if not _is_valid_uuid(bot_uuid):
|
||||
return self.http_status(400, -1, 'Invalid bot_uuid format')
|
||||
runtime_bot, pipeline_uuid = self._resolve_bot(bot_uuid)
|
||||
if runtime_bot is None:
|
||||
return self.http_status(404, -1, 'Bot not found or not available')
|
||||
if not await self._verify_session_token(quart.request, bot_uuid):
|
||||
return self.http_status(403, -1, 'Unauthorized or session expired')
|
||||
try:
|
||||
data = await quart.request.get_json()
|
||||
message_id = data.get('message_id', '')
|
||||
feedback_type = data.get('feedback_type')
|
||||
|
||||
if feedback_type not in (1, 2, 3):
|
||||
return self.http_status(400, -1, 'feedback_type must be 1 (like), 2 (dislike), or 3 (cancel)')
|
||||
|
||||
feedback_id = f'embed_{uuid.uuid4().hex[:12]}'
|
||||
|
||||
await self.ap.monitoring_service.record_feedback(
|
||||
feedback_id=feedback_id,
|
||||
feedback_type=feedback_type,
|
||||
bot_id=runtime_bot.bot_entity.uuid,
|
||||
bot_name=runtime_bot.bot_entity.name or bot_uuid,
|
||||
pipeline_id=pipeline_uuid,
|
||||
message_id=str(message_id),
|
||||
platform='web_page_bot',
|
||||
)
|
||||
|
||||
return self.success(data={'feedback_id': feedback_id})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'Failed to record feedback: {e}', exc_info=True)
|
||||
return self.http_status(500, -1, 'Internal server error')
|
||||
|
||||
# -- Embed WebSocket endpoint ----------------------------------------
|
||||
|
||||
@self.quart_app.websocket(self.path + '/<bot_uuid>/ws/connect')
|
||||
async def embed_websocket_connect(bot_uuid: str):
|
||||
"""WebSocket connection for embed widget, keyed by bot_uuid."""
|
||||
if not _is_valid_uuid(bot_uuid):
|
||||
await quart.websocket.send(json.dumps({'type': 'error', 'message': 'Invalid bot_uuid format'}))
|
||||
return
|
||||
|
||||
runtime_bot, pipeline_uuid = self._resolve_bot(bot_uuid)
|
||||
if runtime_bot is None:
|
||||
await quart.websocket.send(json.dumps({'type': 'error', 'message': 'Bot not found or not available'}))
|
||||
return
|
||||
|
||||
session_type = quart.websocket.args.get('session_type', 'person')
|
||||
if session_type not in ['person', 'group']:
|
||||
await quart.websocket.send(
|
||||
json.dumps({'type': 'error', 'message': 'session_type must be person or group'})
|
||||
)
|
||||
return
|
||||
|
||||
websocket_adapter = self.ap.platform_mgr.websocket_proxy_bot.adapter
|
||||
if not websocket_adapter:
|
||||
await quart.websocket.send(json.dumps({'type': 'error', 'message': 'WebSocket adapter not found'}))
|
||||
return
|
||||
|
||||
try:
|
||||
connection = await ws_connection_manager.add_connection(
|
||||
websocket=quart.websocket._get_current_object(),
|
||||
pipeline_uuid=pipeline_uuid,
|
||||
session_type=session_type,
|
||||
metadata={'user_agent': quart.websocket.headers.get('User-Agent', '')},
|
||||
)
|
||||
|
||||
await quart.websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
'type': 'connected',
|
||||
'connection_id': connection.connection_id,
|
||||
'bot_uuid': bot_uuid,
|
||||
'session_type': session_type,
|
||||
'timestamp': connection.created_at.isoformat(),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f'Embed WebSocket connected: {connection.connection_id} '
|
||||
f'(bot={bot_uuid}, pipeline={pipeline_uuid}, session_type={session_type})'
|
||||
)
|
||||
|
||||
receive_task = asyncio.create_task(self._handle_receive(connection, websocket_adapter, runtime_bot))
|
||||
send_task = asyncio.create_task(self._handle_send(connection))
|
||||
|
||||
try:
|
||||
await asyncio.gather(receive_task, send_task)
|
||||
except Exception as e:
|
||||
logger.error(f'Embed WebSocket task error: {e}')
|
||||
finally:
|
||||
await ws_connection_manager.remove_connection(connection.connection_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'Embed WebSocket connection error: {e}', exc_info=True)
|
||||
try:
|
||||
await quart.websocket.send(json.dumps({'type': 'error', 'message': 'Internal server error'}))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# -- WebSocket receive/send helpers --------------------------------------
|
||||
|
||||
async def _handle_receive(self, connection, websocket_adapter, owner_bot):
|
||||
try:
|
||||
while connection.is_active:
|
||||
message = await quart.websocket.receive()
|
||||
await ws_connection_manager.update_activity(connection.connection_id)
|
||||
|
||||
try:
|
||||
data = json.loads(message)
|
||||
message_type = data.get('type', 'message')
|
||||
|
||||
if message_type == 'ping':
|
||||
await connection.send_queue.put(
|
||||
{'type': 'pong', 'timestamp': datetime.datetime.now().isoformat()}
|
||||
)
|
||||
elif message_type == 'message':
|
||||
await websocket_adapter.handle_websocket_message(connection, data, owner_bot=owner_bot)
|
||||
elif message_type == 'disconnect':
|
||||
break
|
||||
|
||||
except json.JSONDecodeError:
|
||||
await connection.send_queue.put({'type': 'error', 'message': 'Invalid JSON format'})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'Embed receive error: {e}', exc_info=True)
|
||||
finally:
|
||||
connection.is_active = False
|
||||
|
||||
async def _handle_send(self, connection):
|
||||
try:
|
||||
while connection.is_active:
|
||||
try:
|
||||
message = await asyncio.wait_for(connection.send_queue.get(), timeout=1.0)
|
||||
await quart.websocket.send(json.dumps(message))
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f'Embed send error: {e}', exc_info=True)
|
||||
finally:
|
||||
connection.is_active = False
|
||||
@@ -43,6 +43,9 @@ class WebSocketChatRouterGroup(group.RouterGroup):
|
||||
await quart.websocket.send(json.dumps({'type': 'error', 'message': 'WebSocket adapter not found'}))
|
||||
return
|
||||
|
||||
# Find the owning bot for this pipeline (e.g. a web_page_bot)
|
||||
owner_bot = self._find_owner_bot(pipeline_uuid)
|
||||
|
||||
# 注册连接
|
||||
connection = await ws_connection_manager.add_connection(
|
||||
websocket=quart.websocket._get_current_object(),
|
||||
@@ -70,7 +73,7 @@ class WebSocketChatRouterGroup(group.RouterGroup):
|
||||
)
|
||||
|
||||
# 创建接收和发送任务
|
||||
receive_task = asyncio.create_task(self._handle_receive(connection, websocket_adapter))
|
||||
receive_task = asyncio.create_task(self._handle_receive(connection, websocket_adapter, owner_bot))
|
||||
send_task = asyncio.create_task(self._handle_send(connection))
|
||||
|
||||
# 等待任务完成
|
||||
@@ -178,7 +181,14 @@ class WebSocketChatRouterGroup(group.RouterGroup):
|
||||
except Exception as e:
|
||||
return self.http_status(500, -1, f'Internal server error: {str(e)}')
|
||||
|
||||
async def _handle_receive(self, connection, websocket_adapter):
|
||||
def _find_owner_bot(self, pipeline_uuid: str):
|
||||
"""Find a user-created bot (e.g. web_page_bot) that owns this pipeline."""
|
||||
for bot in self.ap.platform_mgr.bots:
|
||||
if bot.bot_entity.adapter == 'web_page_bot' and bot.bot_entity.use_pipeline_uuid == pipeline_uuid:
|
||||
return bot
|
||||
return None
|
||||
|
||||
async def _handle_receive(self, connection, websocket_adapter, owner_bot=None):
|
||||
"""处理接收消息的任务"""
|
||||
try:
|
||||
while connection.is_active:
|
||||
@@ -203,7 +213,7 @@ class WebSocketChatRouterGroup(group.RouterGroup):
|
||||
logger.debug(f'收到消息: {data} from {connection.connection_id}')
|
||||
|
||||
# 处理消息(不等待响应,响应会通过broadcast异步发送)
|
||||
await websocket_adapter.handle_websocket_message(connection, data)
|
||||
await websocket_adapter.handle_websocket_message(connection, data, owner_bot=owner_bot)
|
||||
|
||||
elif message_type == 'disconnect':
|
||||
# 客户端主动断开
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import quart
|
||||
import mimetypes
|
||||
import asyncio
|
||||
from ... import group
|
||||
from langbot.pkg.utils import importutil
|
||||
|
||||
@@ -35,3 +36,640 @@ class AdaptersRouterGroup(group.RouterGroup):
|
||||
return quart.Response(
|
||||
importutil.read_resource_file_bytes(icon_path), mimetype=mimetypes.guess_type(icon_path)[0]
|
||||
)
|
||||
|
||||
# In-memory session store for active registrations
|
||||
_create_app_sessions: dict = {}
|
||||
_SESSION_TTL = 900 # 15 minutes
|
||||
|
||||
def _cleanup_expired_sessions():
|
||||
"""Remove sessions that have exceeded their TTL."""
|
||||
import time
|
||||
|
||||
now = time.time()
|
||||
expired = [sid for sid, s in _create_app_sessions.items() if now - s.get('created_at', 0) > _SESSION_TTL]
|
||||
for sid in expired:
|
||||
session = _create_app_sessions.pop(sid, None)
|
||||
if session and session.get('task') and not session['task'].done():
|
||||
session['task'].cancel()
|
||||
|
||||
@self.route('/lark/create-app', methods=['POST'])
|
||||
async def _() -> str:
|
||||
"""Start Feishu one-click app registration. Returns session_id + QR code URL."""
|
||||
import uuid
|
||||
import time
|
||||
import lark_oapi as lark
|
||||
from lark_oapi.scene.registration.errors import AppAccessDeniedError, AppExpiredError
|
||||
|
||||
_cleanup_expired_sessions()
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
session = {
|
||||
'status': 'pending',
|
||||
'qr_url': None,
|
||||
'expire_at': None,
|
||||
'app_id': None,
|
||||
'app_secret': None,
|
||||
'error': None,
|
||||
'created_at': time.time(),
|
||||
}
|
||||
_create_app_sessions[session_id] = session
|
||||
|
||||
def on_qr_code(info):
|
||||
# May be called from a background thread by the SDK;
|
||||
# use call_soon_threadsafe to safely update session state.
|
||||
def _update():
|
||||
session['qr_url'] = info['url']
|
||||
session['expire_at'] = time.time() + 600 # 10 minutes
|
||||
session['status'] = 'waiting'
|
||||
|
||||
loop.call_soon_threadsafe(_update)
|
||||
|
||||
async def run_registration():
|
||||
try:
|
||||
result = await lark.aregister_app(
|
||||
on_qr_code=on_qr_code,
|
||||
source='langbot',
|
||||
)
|
||||
session['status'] = 'success'
|
||||
session['app_id'] = result['client_id']
|
||||
session['app_secret'] = result['client_secret']
|
||||
except AppAccessDeniedError:
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'User denied authorization'
|
||||
except AppExpiredError:
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'QR code expired'
|
||||
except Exception as e:
|
||||
session['status'] = 'error'
|
||||
session['error'] = str(e)
|
||||
|
||||
task = asyncio.create_task(run_registration())
|
||||
session['task'] = task
|
||||
|
||||
# Wait for QR code to be ready (max 10 seconds)
|
||||
for _ in range(20):
|
||||
if session['qr_url']:
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
if not session['qr_url']:
|
||||
task.cancel()
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'Timeout waiting for QR code'
|
||||
return self.http_status(504, -1, 'Timeout waiting for QR code')
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'session_id': session_id,
|
||||
'qr_url': session['qr_url'],
|
||||
'expire_at': session['expire_at'],
|
||||
}
|
||||
)
|
||||
|
||||
@self.route('/lark/create-app/status/<session_id>', methods=['GET'])
|
||||
async def _(session_id: str) -> str:
|
||||
"""Poll registration status."""
|
||||
session = _create_app_sessions.get(session_id)
|
||||
if not session:
|
||||
return self.http_status(404, -1, 'Session not found')
|
||||
|
||||
data = {'status': session['status']}
|
||||
|
||||
if session['status'] == 'success':
|
||||
data['app_id'] = session['app_id']
|
||||
data['app_secret'] = session['app_secret']
|
||||
_create_app_sessions.pop(session_id, None)
|
||||
elif session['status'] == 'error':
|
||||
data['error'] = session['error']
|
||||
_create_app_sessions.pop(session_id, None)
|
||||
|
||||
return self.success(data=data)
|
||||
|
||||
@self.route('/lark/create-app/<session_id>', methods=['DELETE'])
|
||||
async def _(session_id: str) -> str:
|
||||
"""Cancel and clean up a registration session."""
|
||||
session = _create_app_sessions.pop(session_id, None)
|
||||
if session and session.get('task') and not session['task'].done():
|
||||
session['task'].cancel()
|
||||
return self.success(data={})
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# WeChat QR Code Login
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
_weixin_login_sessions: dict = {}
|
||||
_WEIXIN_SESSION_TTL = 600 # 10 minutes (3 retries × 3 min QR validity)
|
||||
|
||||
def _cleanup_expired_weixin_sessions():
|
||||
import time
|
||||
|
||||
now = time.time()
|
||||
expired = [
|
||||
sid for sid, s in _weixin_login_sessions.items() if now - s.get('created_at', 0) > _WEIXIN_SESSION_TTL
|
||||
]
|
||||
for sid in expired:
|
||||
session = _weixin_login_sessions.pop(sid, None)
|
||||
if session and session.get('task') and not session['task'].done():
|
||||
session['task'].cancel()
|
||||
|
||||
@self.route('/weixin/login', methods=['POST'])
|
||||
async def _() -> str:
|
||||
"""Start WeChat QR code login. Returns session_id + QR code data URL."""
|
||||
import uuid
|
||||
import time
|
||||
import io
|
||||
import base64
|
||||
|
||||
from langbot.libs.openclaw_weixin_api.client import OpenClawWeixinClient, DEFAULT_BASE_URL
|
||||
|
||||
_cleanup_expired_weixin_sessions()
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
session = {
|
||||
'status': 'pending',
|
||||
'qr_data_url': None,
|
||||
'expire_at': None,
|
||||
'token': None,
|
||||
'base_url': None,
|
||||
'account_id': None,
|
||||
'error': None,
|
||||
'created_at': time.time(),
|
||||
}
|
||||
_weixin_login_sessions[session_id] = session
|
||||
|
||||
client = OpenClawWeixinClient(
|
||||
base_url=DEFAULT_BASE_URL,
|
||||
token='',
|
||||
)
|
||||
|
||||
async def run_login():
|
||||
try:
|
||||
import qrcode as qr_lib
|
||||
|
||||
for _attempt in range(3):
|
||||
qr_resp = await client.fetch_qrcode()
|
||||
if not qr_resp.qrcode or not qr_resp.qrcode_img_content:
|
||||
raise Exception('Failed to get QR code from server')
|
||||
|
||||
# Generate QR code image locally
|
||||
qr = qr_lib.QRCode(error_correction=qr_lib.constants.ERROR_CORRECT_L)
|
||||
qr.add_data(qr_resp.qrcode_img_content)
|
||||
qr.make(fit=True)
|
||||
img = qr.make_image(fill_color='black', back_color='white')
|
||||
buf = io.BytesIO()
|
||||
img.save(buf, format='PNG')
|
||||
b64 = base64.b64encode(buf.getvalue()).decode('utf-8')
|
||||
data_url = f'data:image/png;base64,{b64}'
|
||||
|
||||
def _update_qr():
|
||||
session['qr_data_url'] = data_url
|
||||
session['expire_at'] = time.time() + 480 # 8 minutes
|
||||
session['status'] = 'waiting'
|
||||
|
||||
loop.call_soon_threadsafe(_update_qr)
|
||||
|
||||
# Poll for scan status
|
||||
deadline = loop.time() + 180
|
||||
while loop.time() < deadline:
|
||||
try:
|
||||
status_resp = await client.poll_qrcode_status(qr_resp.qrcode)
|
||||
except Exception:
|
||||
await asyncio.sleep(2)
|
||||
continue
|
||||
|
||||
if status_resp.status == 'confirmed' and status_resp.bot_token:
|
||||
session['status'] = 'success'
|
||||
session['token'] = status_resp.bot_token
|
||||
session['base_url'] = status_resp.baseurl or client.base_url
|
||||
session['account_id'] = status_resp.ilink_bot_id or ''
|
||||
return
|
||||
|
||||
if status_resp.status == 'expired':
|
||||
break # retry with new QR code
|
||||
|
||||
await asyncio.sleep(1)
|
||||
else:
|
||||
pass # timeout, retry
|
||||
|
||||
# All retries exhausted
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'QR code login failed: max retries exceeded'
|
||||
|
||||
except Exception as e:
|
||||
session['status'] = 'error'
|
||||
session['error'] = str(e)
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
task = asyncio.create_task(run_login())
|
||||
session['task'] = task
|
||||
|
||||
# Wait for QR code to be ready (max 10 seconds)
|
||||
for _ in range(20):
|
||||
if session['qr_data_url']:
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
if not session['qr_data_url']:
|
||||
task.cancel()
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'Timeout waiting for QR code'
|
||||
return self.http_status(504, -1, 'Timeout waiting for QR code')
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'session_id': session_id,
|
||||
'qr_data_url': session['qr_data_url'],
|
||||
'expire_at': session['expire_at'],
|
||||
}
|
||||
)
|
||||
|
||||
@self.route('/weixin/login/status/<session_id>', methods=['GET'])
|
||||
async def _(session_id: str) -> str:
|
||||
"""Poll WeChat login status."""
|
||||
session = _weixin_login_sessions.get(session_id)
|
||||
if not session:
|
||||
return self.http_status(404, -1, 'Session not found')
|
||||
|
||||
data = {'status': session['status']}
|
||||
|
||||
if session['status'] == 'success':
|
||||
data['token'] = session['token']
|
||||
data['base_url'] = session['base_url']
|
||||
data['account_id'] = session['account_id']
|
||||
_weixin_login_sessions.pop(session_id, None)
|
||||
elif session['status'] == 'error':
|
||||
data['error'] = session['error']
|
||||
_weixin_login_sessions.pop(session_id, None)
|
||||
|
||||
return self.success(data=data)
|
||||
|
||||
@self.route('/weixin/login/<session_id>', methods=['DELETE'])
|
||||
async def _(session_id: str) -> str:
|
||||
"""Cancel and clean up a WeChat login session."""
|
||||
session = _weixin_login_sessions.pop(session_id, None)
|
||||
if session and session.get('task') and not session['task'].done():
|
||||
session['task'].cancel()
|
||||
return self.success(data={})
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# DingTalk Device Flow QR Code Login
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
_dingtalk_sessions: dict = {}
|
||||
_DINGTALK_SESSION_TTL = 600 # 10 minutes (QR code validity window)
|
||||
|
||||
def _cleanup_expired_dingtalk_sessions():
|
||||
import time
|
||||
|
||||
now = time.time()
|
||||
expired = [
|
||||
sid for sid, s in _dingtalk_sessions.items() if now - s.get('created_at', 0) > _DINGTALK_SESSION_TTL
|
||||
]
|
||||
for sid in expired:
|
||||
session = _dingtalk_sessions.pop(sid, None)
|
||||
if session and session.get('task') and not session['task'].done():
|
||||
session['task'].cancel()
|
||||
|
||||
@self.route('/dingtalk/create-app', methods=['POST'])
|
||||
async def _() -> str:
|
||||
"""Start DingTalk one-click app creation via Device Flow. Returns session_id + QR code URL."""
|
||||
import uuid
|
||||
import time
|
||||
import aiohttp
|
||||
|
||||
DINGTALK_BASE_URL = 'https://oapi.dingtalk.com'
|
||||
|
||||
_cleanup_expired_dingtalk_sessions()
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
session = {
|
||||
'status': 'pending',
|
||||
'qr_url': None,
|
||||
'expire_at': None,
|
||||
'client_id': None,
|
||||
'client_secret': None,
|
||||
'error': None,
|
||||
'created_at': time.time(),
|
||||
'device_code': None,
|
||||
'interval': 5,
|
||||
}
|
||||
_dingtalk_sessions[session_id] = session
|
||||
|
||||
async def run_device_flow():
|
||||
try:
|
||||
timeout = aiohttp.ClientTimeout(total=10)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as http:
|
||||
# Step 1: Init — get nonce
|
||||
async with http.post(
|
||||
f'{DINGTALK_BASE_URL}/app/registration/init',
|
||||
json={'source': 'langbot'},
|
||||
) as resp:
|
||||
try:
|
||||
data = await resp.json()
|
||||
except (aiohttp.ContentTypeError, ValueError):
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'Invalid response from DingTalk service'
|
||||
return
|
||||
if data.get('errcode', -1) != 0:
|
||||
session['status'] = 'error'
|
||||
session['error'] = data.get('errmsg', 'Failed to init')
|
||||
return
|
||||
nonce = data['nonce']
|
||||
|
||||
# Step 2: Begin — get device_code + QR URL
|
||||
async with http.post(
|
||||
f'{DINGTALK_BASE_URL}/app/registration/begin',
|
||||
json={'nonce': nonce},
|
||||
) as resp:
|
||||
try:
|
||||
data = await resp.json()
|
||||
except (aiohttp.ContentTypeError, ValueError):
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'Invalid response from DingTalk service'
|
||||
return
|
||||
if data.get('errcode', -1) != 0:
|
||||
session['status'] = 'error'
|
||||
session['error'] = data.get('errmsg', 'Failed to begin authorization')
|
||||
return
|
||||
|
||||
device_code = data['device_code']
|
||||
verification_uri_complete = data.get('verification_uri_complete', '')
|
||||
expires_in = data.get('expires_in', 7200)
|
||||
interval = data.get('interval', 5)
|
||||
|
||||
session['device_code'] = device_code
|
||||
session['interval'] = interval
|
||||
session['qr_url'] = verification_uri_complete
|
||||
session['expire_at'] = time.time() + 600 # QR code valid for ~10 min
|
||||
session['status'] = 'waiting'
|
||||
|
||||
# Step 3: Poll for authorization result
|
||||
deadline = time.time() + expires_in
|
||||
while time.time() < deadline:
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
async with http.post(
|
||||
f'{DINGTALK_BASE_URL}/app/registration/poll',
|
||||
json={'device_code': device_code},
|
||||
) as poll_resp:
|
||||
try:
|
||||
poll_data = await poll_resp.json()
|
||||
except (aiohttp.ContentTypeError, ValueError):
|
||||
continue
|
||||
|
||||
if poll_data.get('errcode', -1) != 0:
|
||||
session['status'] = 'error'
|
||||
session['error'] = poll_data.get('errmsg', 'Poll failed')
|
||||
return
|
||||
|
||||
status = poll_data.get('status', '')
|
||||
|
||||
if status == 'SUCCESS':
|
||||
session['status'] = 'success'
|
||||
session['client_id'] = poll_data.get('client_id', '')
|
||||
session['client_secret'] = poll_data.get('client_secret', '')
|
||||
return
|
||||
elif status == 'FAIL':
|
||||
session['status'] = 'error'
|
||||
session['error'] = poll_data.get('fail_reason', 'Authorization failed')
|
||||
return
|
||||
elif status == 'EXPIRED':
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'QR code expired'
|
||||
return
|
||||
# status == 'WAITING': continue polling
|
||||
|
||||
# Timeout
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'QR code expired'
|
||||
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
session['status'] = 'error'
|
||||
session['error'] = str(e)
|
||||
|
||||
task = asyncio.create_task(run_device_flow())
|
||||
session['task'] = task
|
||||
|
||||
# Wait for QR code to be ready (max 10 seconds)
|
||||
for _ in range(20):
|
||||
if session['qr_url'] or session['error']:
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
if session['error']:
|
||||
task.cancel()
|
||||
return self.http_status(502, -1, session['error'])
|
||||
|
||||
if not session['qr_url']:
|
||||
task.cancel()
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'Timeout waiting for QR code'
|
||||
return self.http_status(504, -1, 'Timeout waiting for QR code')
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'session_id': session_id,
|
||||
'qr_url': session['qr_url'],
|
||||
'expire_at': session['expire_at'],
|
||||
}
|
||||
)
|
||||
|
||||
@self.route('/dingtalk/create-app/status/<session_id>', methods=['GET'])
|
||||
async def _(session_id: str) -> str:
|
||||
"""Poll DingTalk Device Flow status."""
|
||||
_cleanup_expired_dingtalk_sessions()
|
||||
session = _dingtalk_sessions.get(session_id)
|
||||
if not session:
|
||||
return self.http_status(404, -1, 'Session not found')
|
||||
|
||||
data = {'status': session['status']}
|
||||
|
||||
if session['status'] == 'success':
|
||||
data['client_id'] = session['client_id']
|
||||
data['client_secret'] = session['client_secret']
|
||||
_dingtalk_sessions.pop(session_id, None)
|
||||
elif session['status'] == 'error':
|
||||
data['error'] = session['error']
|
||||
_dingtalk_sessions.pop(session_id, None)
|
||||
|
||||
return self.success(data=data)
|
||||
|
||||
@self.route('/dingtalk/create-app/<session_id>', methods=['DELETE'])
|
||||
async def _(session_id: str) -> str:
|
||||
"""Cancel and clean up a DingTalk Device Flow session."""
|
||||
session = _dingtalk_sessions.pop(session_id, None)
|
||||
if session and session.get('task') and not session['task'].done():
|
||||
session['task'].cancel()
|
||||
return self.success(data={})
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# WeComBot QR Code One-Click Create
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
_wecombot_sessions: dict = {}
|
||||
_WECOMBOT_SESSION_TTL = 300 # 5 minutes (WeCom QR validity window)
|
||||
|
||||
def _cleanup_expired_wecombot_sessions():
|
||||
import time
|
||||
|
||||
now = time.time()
|
||||
expired = [
|
||||
sid for sid, s in _wecombot_sessions.items() if now - s.get('created_at', 0) > _WECOMBOT_SESSION_TTL
|
||||
]
|
||||
for sid in expired:
|
||||
session = _wecombot_sessions.pop(sid, None)
|
||||
if session and session.get('task') and not session['task'].done():
|
||||
session['task'].cancel()
|
||||
|
||||
@self.route('/wecombot/create-bot', methods=['POST'])
|
||||
async def _() -> str:
|
||||
"""Start WeComBot one-click creation via QR code. Returns session_id + QR code URL."""
|
||||
import uuid
|
||||
import time
|
||||
import aiohttp
|
||||
|
||||
WECOM_QC_GENERATE_URL = 'https://work.weixin.qq.com/ai/qc/generate'
|
||||
WECOM_QC_QUERY_URL = 'https://work.weixin.qq.com/ai/qc/query_result'
|
||||
|
||||
_cleanup_expired_wecombot_sessions()
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
session = {
|
||||
'status': 'pending',
|
||||
'qr_url': None,
|
||||
'expire_at': None,
|
||||
'botid': None,
|
||||
'secret': None,
|
||||
'error': None,
|
||||
'created_at': time.time(),
|
||||
'scode': None,
|
||||
'task': None,
|
||||
}
|
||||
_wecombot_sessions[session_id] = session
|
||||
|
||||
async def run_qr_flow():
|
||||
try:
|
||||
timeout = aiohttp.ClientTimeout(total=10)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as http:
|
||||
# Step 1: Generate QR code
|
||||
async with http.get(
|
||||
f'{WECOM_QC_GENERATE_URL}?source=langbot&plat=0',
|
||||
) as resp:
|
||||
try:
|
||||
data = await resp.json()
|
||||
except (aiohttp.ContentTypeError, ValueError):
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'Invalid response from WeCom service'
|
||||
return
|
||||
if not data.get('data', {}).get('scode') or not data.get('data', {}).get('auth_url'):
|
||||
session['status'] = 'error'
|
||||
session['error'] = data.get('errmsg', 'Failed to generate QR code')
|
||||
return
|
||||
|
||||
scode = data['data']['scode']
|
||||
auth_url = data['data']['auth_url']
|
||||
|
||||
session['scode'] = scode
|
||||
session['qr_url'] = auth_url
|
||||
session['expire_at'] = time.time() + _WECOMBOT_SESSION_TTL
|
||||
session['status'] = 'waiting'
|
||||
|
||||
# Step 2: Poll for scan result
|
||||
deadline = time.time() + _WECOMBOT_SESSION_TTL
|
||||
while time.time() < deadline:
|
||||
await asyncio.sleep(3)
|
||||
|
||||
async with http.get(
|
||||
f'{WECOM_QC_QUERY_URL}?scode={scode}',
|
||||
) as poll_resp:
|
||||
try:
|
||||
poll_data = await poll_resp.json()
|
||||
except (aiohttp.ContentTypeError, ValueError):
|
||||
continue
|
||||
|
||||
status = poll_data.get('data', {}).get('status', '')
|
||||
if status == 'success':
|
||||
bot_info = poll_data.get('data', {}).get('bot_info', {})
|
||||
if bot_info.get('botid') and bot_info.get('secret'):
|
||||
session['status'] = 'success'
|
||||
session['botid'] = bot_info['botid']
|
||||
session['secret'] = bot_info['secret']
|
||||
return
|
||||
else:
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'Scan succeeded but bot info is incomplete'
|
||||
return
|
||||
|
||||
# Timeout
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'QR code expired'
|
||||
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
session['status'] = 'error'
|
||||
session['error'] = str(e)
|
||||
|
||||
task = asyncio.create_task(run_qr_flow())
|
||||
session['task'] = task
|
||||
|
||||
# Wait for QR code to be ready (max 10 seconds)
|
||||
for _ in range(20):
|
||||
if session['qr_url'] or session['error']:
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
if session['error']:
|
||||
task.cancel()
|
||||
return self.http_status(502, -1, session['error'])
|
||||
|
||||
if not session['qr_url']:
|
||||
task.cancel()
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'Timeout waiting for QR code'
|
||||
return self.http_status(504, -1, 'Timeout waiting for QR code')
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'session_id': session_id,
|
||||
'qr_url': session['qr_url'],
|
||||
'expire_at': session['expire_at'],
|
||||
}
|
||||
)
|
||||
|
||||
@self.route('/wecombot/create-bot/status/<session_id>', methods=['GET'])
|
||||
async def _(session_id: str) -> str:
|
||||
"""Poll WeComBot creation status."""
|
||||
_cleanup_expired_wecombot_sessions()
|
||||
session = _wecombot_sessions.get(session_id)
|
||||
if not session:
|
||||
return self.http_status(404, -1, 'Session not found')
|
||||
|
||||
data = {'status': session['status']}
|
||||
|
||||
if session['status'] == 'success':
|
||||
data['botid'] = session['botid']
|
||||
data['secret'] = session['secret']
|
||||
_wecombot_sessions.pop(session_id, None)
|
||||
elif session['status'] == 'error':
|
||||
data['error'] = session['error']
|
||||
_wecombot_sessions.pop(session_id, None)
|
||||
|
||||
return self.success(data=data)
|
||||
|
||||
@self.route('/wecombot/create-bot/<session_id>', methods=['DELETE'])
|
||||
async def _(session_id: str) -> str:
|
||||
"""Cancel and clean up a WeComBot creation session."""
|
||||
session = _wecombot_sessions.pop(session_id, None)
|
||||
if session and session.get('task') and not session['task'].done():
|
||||
session['task'].cancel()
|
||||
return self.success(data={})
|
||||
|
||||
@@ -6,11 +6,50 @@ import re
|
||||
import httpx
|
||||
import uuid
|
||||
import os
|
||||
import posixpath
|
||||
import sqlalchemy
|
||||
|
||||
from .....core import taskmgr
|
||||
from .....entity.persistence import plugin as persistence_plugin
|
||||
from .. import group
|
||||
from langbot_plugin.runtime.plugin.mgr import PluginInstallSource
|
||||
|
||||
# Resolve the built-in page SDK JS from the langbot_plugin package
|
||||
_PAGE_SDK_PATH = None
|
||||
try:
|
||||
import langbot_plugin.assets as _assets_pkg
|
||||
|
||||
_candidate = os.path.join(os.path.dirname(_assets_pkg.__file__), 'langbot-page-sdk.js')
|
||||
if os.path.exists(_candidate):
|
||||
_PAGE_SDK_PATH = _candidate
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _normalize_plugin_asset_path(filepath: str) -> str | None:
|
||||
filepath = filepath.replace('\\', '/')
|
||||
if filepath.startswith('/'):
|
||||
return None
|
||||
|
||||
normalized = posixpath.normpath(filepath)
|
||||
if normalized == '.' or normalized.startswith('../') or normalized == '..':
|
||||
return None
|
||||
|
||||
if normalized.startswith('components/pages/'):
|
||||
return normalized
|
||||
|
||||
return f'assets/{normalized}'
|
||||
|
||||
|
||||
def _get_request_origin() -> str:
|
||||
"""Return the public request origin, respecting reverse-proxy headers."""
|
||||
forwarded_proto = quart.request.headers.get('X-Forwarded-Proto', '').split(',')[0].strip()
|
||||
forwarded_host = quart.request.headers.get('X-Forwarded-Host', '').split(',')[0].strip()
|
||||
|
||||
scheme = forwarded_proto or quart.request.scheme
|
||||
host = forwarded_host or quart.request.host
|
||||
return f'{scheme}://{host}'
|
||||
|
||||
|
||||
@group.group_class('plugins', '/api/v1/plugins')
|
||||
class PluginsRouterGroup(group.RouterGroup):
|
||||
@@ -27,6 +66,15 @@ class PluginsRouterGroup(group.RouterGroup):
|
||||
return None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
@self.route('/_sdk/page-sdk.js', methods=['GET'], auth_type=group.AuthType.NONE)
|
||||
async def _() -> quart.Response:
|
||||
"""Serve the built-in LangBot page SDK JavaScript."""
|
||||
if _PAGE_SDK_PATH and os.path.exists(_PAGE_SDK_PATH):
|
||||
with open(_PAGE_SDK_PATH, 'r') as f:
|
||||
content = f.read()
|
||||
return quart.Response(content, mimetype='application/javascript')
|
||||
return quart.Response('// SDK not found', status=404, mimetype='application/javascript')
|
||||
|
||||
@self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _() -> str:
|
||||
plugins = await self.ap.plugin_connector.list_plugins()
|
||||
@@ -102,7 +150,15 @@ class PluginsRouterGroup(group.RouterGroup):
|
||||
return self.http_status(404, -1, 'plugin not found')
|
||||
|
||||
if quart.request.method == 'GET':
|
||||
return self.success(data={'config': plugin['plugin_config']})
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_plugin.PluginSetting.config)
|
||||
.where(persistence_plugin.PluginSetting.plugin_author == author)
|
||||
.where(persistence_plugin.PluginSetting.plugin_name == plugin_name)
|
||||
)
|
||||
persisted_config = result.scalar_one_or_none()
|
||||
|
||||
config = persisted_config if persisted_config is not None else plugin['plugin_config']
|
||||
return self.success(data={'config': config})
|
||||
elif quart.request.method == 'PUT':
|
||||
data = await quart.request.json
|
||||
|
||||
@@ -135,15 +191,62 @@ class PluginsRouterGroup(group.RouterGroup):
|
||||
return quart.Response(icon_data, mimetype=mime_type)
|
||||
|
||||
@self.route(
|
||||
'/<author>/<plugin_name>/assets/<filepath>',
|
||||
'/<author>/<plugin_name>/assets/<path:filepath>',
|
||||
methods=['GET'],
|
||||
auth_type=group.AuthType.NONE,
|
||||
)
|
||||
async def _(author: str, plugin_name: str, filepath: str) -> quart.Response:
|
||||
asset_data = await self.ap.plugin_connector.get_plugin_assets(author, plugin_name, filepath)
|
||||
asset_path = _normalize_plugin_asset_path(filepath)
|
||||
if asset_path is None:
|
||||
return quart.Response('Asset not found', status=404)
|
||||
|
||||
asset_data = await self.ap.plugin_connector.get_plugin_assets(author, plugin_name, asset_path)
|
||||
if not asset_data.get('asset_base64'):
|
||||
return quart.Response('Asset not found', status=404)
|
||||
asset_bytes = base64.b64decode(asset_data['asset_base64'])
|
||||
mime_type = asset_data['mime_type']
|
||||
return quart.Response(asset_bytes, mimetype=mime_type)
|
||||
resp = quart.Response(asset_bytes, mimetype=mime_type)
|
||||
# CSP for HTML pages served to sandboxed iframes (opaque origin).
|
||||
# 'self' doesn't work in sandboxed iframes — use actual server origin.
|
||||
if mime_type and mime_type.startswith('text/html'):
|
||||
origin = _get_request_origin()
|
||||
resp.headers['Content-Security-Policy'] = (
|
||||
f'default-src {origin}; '
|
||||
f"script-src {origin} 'unsafe-inline'; "
|
||||
f"style-src {origin} 'unsafe-inline'; "
|
||||
f'img-src {origin} data:; '
|
||||
f'connect-src {origin}; '
|
||||
"frame-src 'none'; "
|
||||
"object-src 'none'"
|
||||
)
|
||||
return resp
|
||||
|
||||
@self.route(
|
||||
'/<author>/<plugin_name>/page-api',
|
||||
methods=['POST'],
|
||||
auth_type=group.AuthType.USER_TOKEN_OR_API_KEY,
|
||||
)
|
||||
async def _(author: str, plugin_name: str) -> str:
|
||||
"""Forward a page API request to the plugin."""
|
||||
data = await quart.request.json
|
||||
if not isinstance(data, dict):
|
||||
return self.http_status(400, -1, 'invalid request body')
|
||||
|
||||
page_id = data.get('page_id', '')
|
||||
endpoint = data.get('endpoint', '')
|
||||
method = data.get('method', 'POST')
|
||||
body = data.get('body')
|
||||
if not isinstance(page_id, str) or not isinstance(endpoint, str) or not isinstance(method, str):
|
||||
return self.http_status(400, -1, 'invalid page api request')
|
||||
if not endpoint.startswith('/') or '..' in endpoint:
|
||||
return self.http_status(400, -1, 'invalid endpoint')
|
||||
|
||||
result = await self.ap.plugin_connector.handle_page_api(
|
||||
author, plugin_name, page_id, endpoint, method.upper(), body
|
||||
)
|
||||
if result.get('error'):
|
||||
return self.http_status(400, -1, result['error'])
|
||||
return self.success(data=result.get('data'))
|
||||
|
||||
@self.route('/github/releases', methods=['POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _() -> str:
|
||||
@@ -265,6 +368,8 @@ class PluginsRouterGroup(group.RouterGroup):
|
||||
return self.http_status(400, -1, 'Missing asset_url parameter')
|
||||
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
ctx.metadata['plugin_name'] = f'{owner}/{repo}'
|
||||
ctx.metadata['install_source'] = 'github'
|
||||
install_info = {
|
||||
'asset_url': asset_url,
|
||||
'owner': owner,
|
||||
@@ -295,12 +400,17 @@ class PluginsRouterGroup(group.RouterGroup):
|
||||
|
||||
data = await quart.request.json
|
||||
|
||||
plugin_author = data.get('plugin_author', '')
|
||||
plugin_name = data.get('plugin_name', '')
|
||||
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
ctx.metadata['plugin_name'] = f'{plugin_author}/{plugin_name}'
|
||||
ctx.metadata['install_source'] = 'marketplace'
|
||||
wrapper = self.ap.task_mgr.create_user_task(
|
||||
self.ap.plugin_connector.install_plugin(PluginInstallSource.MARKETPLACE, data, task_context=ctx),
|
||||
kind='plugin-operation',
|
||||
name='plugin-install-marketplace',
|
||||
label=f'Installing plugin from marketplace ...{data}',
|
||||
label=f'Installing plugin from marketplace {plugin_author}/{plugin_name}',
|
||||
context=ctx,
|
||||
)
|
||||
|
||||
@@ -323,11 +433,13 @@ class PluginsRouterGroup(group.RouterGroup):
|
||||
}
|
||||
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
ctx.metadata['plugin_name'] = file.filename or 'local plugin'
|
||||
ctx.metadata['install_source'] = 'local'
|
||||
wrapper = self.ap.task_mgr.create_user_task(
|
||||
self.ap.plugin_connector.install_plugin(PluginInstallSource.LOCAL, data, task_context=ctx),
|
||||
kind='plugin-operation',
|
||||
name='plugin-install-local',
|
||||
label=f'Installing plugin from local ...{file.filename}',
|
||||
label=f'Installing plugin from local {file.filename}',
|
||||
context=ctx,
|
||||
)
|
||||
|
||||
|
||||
@@ -97,3 +97,51 @@ class EmbeddingModelsRouterGroup(group.RouterGroup):
|
||||
await self.ap.embedding_models_service.test_embedding_model(model_uuid, json_data)
|
||||
|
||||
return self.success()
|
||||
|
||||
|
||||
@group.group_class('models/rerank', '/api/v1/provider/models/rerank')
|
||||
class RerankModelsRouterGroup(group.RouterGroup):
|
||||
async def initialize(self) -> None:
|
||||
@self.route('', methods=['GET', 'POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _() -> str:
|
||||
if quart.request.method == 'GET':
|
||||
provider_uuid = quart.request.args.get('provider_uuid')
|
||||
if provider_uuid:
|
||||
return self.success(
|
||||
data={
|
||||
'models': await self.ap.rerank_models_service.get_rerank_models_by_provider(provider_uuid)
|
||||
}
|
||||
)
|
||||
return self.success(data={'models': await self.ap.rerank_models_service.get_rerank_models()})
|
||||
elif quart.request.method == 'POST':
|
||||
json_data = await quart.request.json
|
||||
model_uuid = await self.ap.rerank_models_service.create_rerank_model(json_data)
|
||||
return self.success(data={'uuid': model_uuid})
|
||||
|
||||
@self.route('/<model_uuid>', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _(model_uuid: str) -> str:
|
||||
if quart.request.method == 'GET':
|
||||
model = await self.ap.rerank_models_service.get_rerank_model(model_uuid)
|
||||
|
||||
if model is None:
|
||||
return self.http_status(404, -1, 'model not found')
|
||||
|
||||
return self.success(data={'model': model})
|
||||
elif quart.request.method == 'PUT':
|
||||
json_data = await quart.request.json
|
||||
|
||||
await self.ap.rerank_models_service.update_rerank_model(model_uuid, json_data)
|
||||
|
||||
return self.success()
|
||||
elif quart.request.method == 'DELETE':
|
||||
await self.ap.rerank_models_service.delete_rerank_model(model_uuid)
|
||||
|
||||
return self.success()
|
||||
|
||||
@self.route('/<model_uuid>/test', methods=['POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _(model_uuid: str) -> str:
|
||||
json_data = await quart.request.json
|
||||
|
||||
await self.ap.rerank_models_service.test_rerank_model(model_uuid, json_data)
|
||||
|
||||
return self.success()
|
||||
|
||||
@@ -15,6 +15,7 @@ class ModelProvidersRouterGroup(group.RouterGroup):
|
||||
counts = await self.ap.provider_service.get_provider_model_counts(provider['uuid'])
|
||||
provider['llm_count'] = counts['llm_count']
|
||||
provider['embedding_count'] = counts['embedding_count']
|
||||
provider['rerank_count'] = counts['rerank_count']
|
||||
return self.success(data={'providers': providers})
|
||||
elif quart.request.method == 'POST':
|
||||
json_data = await quart.request.json
|
||||
@@ -32,6 +33,7 @@ class ModelProvidersRouterGroup(group.RouterGroup):
|
||||
counts = await self.ap.provider_service.get_provider_model_counts(provider_uuid)
|
||||
provider['llm_count'] = counts['llm_count']
|
||||
provider['embedding_count'] = counts['embedding_count']
|
||||
provider['rerank_count'] = counts['rerank_count']
|
||||
return self.success(data={'provider': provider})
|
||||
elif quart.request.method == 'PUT':
|
||||
json_data = await quart.request.json
|
||||
@@ -43,3 +45,12 @@ class ModelProvidersRouterGroup(group.RouterGroup):
|
||||
return self.success()
|
||||
except ValueError as e:
|
||||
return self.http_status(400, -1, str(e))
|
||||
|
||||
@self.route('/<provider_uuid>/scan-models', methods=['GET'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _(provider_uuid: str) -> str:
|
||||
try:
|
||||
model_type = quart.request.args.get('type')
|
||||
result = await self.ap.provider_service.scan_provider_models(provider_uuid, model_type)
|
||||
return self.success(data=result)
|
||||
except ValueError as e:
|
||||
return self.http_status(400, -1, str(e))
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ... import group
|
||||
|
||||
|
||||
@group.group_class('tools', '/api/v1/tools')
|
||||
class ToolsRouterGroup(group.RouterGroup):
|
||||
async def initialize(self) -> None:
|
||||
@self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
"""获取所有可用工具列表"""
|
||||
tools = await self.ap.tool_mgr.get_all_tools()
|
||||
|
||||
tool_list = []
|
||||
for tool in tools:
|
||||
tool_list.append(
|
||||
{
|
||||
'name': tool.name,
|
||||
'description': tool.description,
|
||||
'human_desc': tool.human_desc,
|
||||
'parameters': tool.parameters,
|
||||
}
|
||||
)
|
||||
|
||||
return self.success(data={'tools': tool_list})
|
||||
|
||||
@self.route('/<tool_name>', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _(tool_name: str) -> str:
|
||||
"""获取特定工具详情"""
|
||||
tools = await self.ap.tool_mgr.get_all_tools()
|
||||
|
||||
for tool in tools:
|
||||
if tool.name == tool_name:
|
||||
return self.success(
|
||||
data={
|
||||
'tool': {
|
||||
'name': tool.name,
|
||||
'description': tool.description,
|
||||
'human_desc': tool.human_desc,
|
||||
'parameters': tool.parameters,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return self.http_status(404, -1, f'Tool not found: {tool_name}')
|
||||
@@ -1,7 +1,11 @@
|
||||
import json
|
||||
|
||||
import quart
|
||||
import sqlalchemy
|
||||
|
||||
from .. import group
|
||||
from .....utils import constants
|
||||
from .....entity.persistence.metadata import Metadata
|
||||
|
||||
|
||||
@group.group_class('system', '/api/v1/system')
|
||||
@@ -9,6 +13,24 @@ class SystemRouterGroup(group.RouterGroup):
|
||||
async def initialize(self) -> None:
|
||||
@self.route('/info', methods=['GET'], auth_type=group.AuthType.NONE)
|
||||
async def _() -> str:
|
||||
# Read wizard_status and wizard_progress from metadata table
|
||||
wizard_status = 'none'
|
||||
wizard_progress = None
|
||||
try:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(Metadata).where(Metadata.key.in_(['wizard_status', 'wizard_progress']))
|
||||
)
|
||||
for row in result:
|
||||
if row.key == 'wizard_status':
|
||||
wizard_status = row.value
|
||||
elif row.key == 'wizard_progress':
|
||||
try:
|
||||
wizard_progress = json.loads(row.value)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
wizard_progress = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'version': constants.semantic_version,
|
||||
@@ -27,17 +49,83 @@ class SystemRouterGroup(group.RouterGroup):
|
||||
'disable_models_service', False
|
||||
),
|
||||
'limitation': self.ap.instance_config.data.get('system', {}).get('limitation', {}),
|
||||
'wizard_status': wizard_status,
|
||||
'wizard_progress': wizard_progress,
|
||||
}
|
||||
)
|
||||
|
||||
@self.route('/wizard/completed', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
"""Mark wizard status in metadata table and clear progress.
|
||||
|
||||
Accepts JSON body: { "status": "skipped" | "completed" }
|
||||
"""
|
||||
data = await quart.request.get_json(silent=True) or {}
|
||||
status = data.get('status', 'completed')
|
||||
if status not in ('skipped', 'completed'):
|
||||
return self.http_status(400, 400, f'Invalid wizard status: {status}')
|
||||
|
||||
try:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(Metadata).where(Metadata.key == 'wizard_status')
|
||||
)
|
||||
if result.first():
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(Metadata).where(Metadata.key == 'wizard_status').values(value=status)
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(Metadata).values(key='wizard_status', value=status)
|
||||
)
|
||||
|
||||
# Clear wizard progress when wizard is completed/skipped
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(Metadata).where(Metadata.key == 'wizard_progress')
|
||||
)
|
||||
except Exception as e:
|
||||
return self.http_status(500, 500, f'Failed to update wizard status: {e}')
|
||||
|
||||
return self.success(data={})
|
||||
|
||||
@self.route('/wizard/progress', methods=['PUT'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
"""Save wizard progress to metadata table.
|
||||
|
||||
Accepts JSON body with wizard state fields:
|
||||
{ "step": int, "selected_adapter": str|null, "created_bot_uuid": str|null,
|
||||
"bot_saved": bool, "selected_runner": str|null }
|
||||
"""
|
||||
data = await quart.request.get_json(silent=True) or {}
|
||||
progress_json = json.dumps(data, ensure_ascii=False)
|
||||
|
||||
try:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(Metadata).where(Metadata.key == 'wizard_progress')
|
||||
)
|
||||
if result.first():
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(Metadata).where(Metadata.key == 'wizard_progress').values(value=progress_json)
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(Metadata).values(key='wizard_progress', value=progress_json)
|
||||
)
|
||||
except Exception as e:
|
||||
return self.http_status(500, 500, f'Failed to save wizard progress: {e}')
|
||||
|
||||
return self.success(data={})
|
||||
|
||||
@self.route('/tasks', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
task_type = quart.request.args.get('type')
|
||||
task_kind = quart.request.args.get('kind')
|
||||
|
||||
if task_type == '':
|
||||
task_type = None
|
||||
if task_kind == '':
|
||||
task_kind = None
|
||||
|
||||
return self.success(data=self.ap.task_mgr.get_tasks_dict(task_type))
|
||||
return self.success(data=self.ap.task_mgr.get_tasks_dict(task_type, task_kind))
|
||||
|
||||
@self.route('/tasks/<task_id>', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _(task_id: str) -> str:
|
||||
@@ -48,16 +136,9 @@ class SystemRouterGroup(group.RouterGroup):
|
||||
|
||||
return self.success(data=task.to_dict())
|
||||
|
||||
@self.route('/debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
@self.route('/storage-analysis', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
if not constants.debug_mode:
|
||||
return self.http_status(403, 403, 'Forbidden')
|
||||
|
||||
py_code = await quart.request.data
|
||||
|
||||
ap = self.ap
|
||||
|
||||
return self.success(data=exec(py_code, {'ap': ap}))
|
||||
return self.success(data=await self.ap.maintenance_service.get_storage_analysis())
|
||||
|
||||
@self.route(
|
||||
'/debug/plugin/action',
|
||||
|
||||
@@ -146,6 +146,7 @@ class UserRouterGroup(group.RouterGroup):
|
||||
return self.fail(3, str(e))
|
||||
except ValueError as e:
|
||||
traceback.print_exc()
|
||||
self.ap.logger.warning(f'Space OAuth callback failed: {e}')
|
||||
return self.fail(1, str(e))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
# Workflow router group
|
||||
from .workflows import WorkflowsRouterGroup, ExecutionsRouterGroup
|
||||
from .websocket_chat import WorkflowWebSocketChatRouterGroup
|
||||
|
||||
__all__ = ['WorkflowsRouterGroup', 'ExecutionsRouterGroup', 'WorkflowWebSocketChatRouterGroup']
|
||||
@@ -0,0 +1,260 @@
|
||||
"""Workflow WebSocket聊天路由 - 支持工作流调试的双向实时通信"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
|
||||
import quart
|
||||
|
||||
from ... import group
|
||||
from ......platform.sources.websocket_manager import ws_connection_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@group.group_class('workflow_websocket_chat', '/api/v1/workflows/<workflow_uuid>/ws')
|
||||
class WorkflowWebSocketChatRouterGroup(group.RouterGroup):
|
||||
async def initialize(self) -> None:
|
||||
@self.quart_app.websocket(self.path + '/connect')
|
||||
async def workflow_websocket_connect(workflow_uuid: str):
|
||||
"""
|
||||
建立工作流WebSocket连接
|
||||
|
||||
URL参数:
|
||||
- workflow_uuid: 工作流UUID
|
||||
- session_type: 会话类型 (person/group)
|
||||
"""
|
||||
try:
|
||||
session_type = quart.websocket.args.get('session_type', 'person')
|
||||
logger.info(
|
||||
'Workflow WebSocket connect request received',
|
||||
extra={
|
||||
'workflow_uuid': workflow_uuid,
|
||||
'session_type': session_type,
|
||||
'path': quart.websocket.path,
|
||||
'query_string': quart.websocket.query_string.decode('utf-8', errors='ignore'),
|
||||
'remote_addr': getattr(quart.websocket, 'remote_addr', None),
|
||||
'user_agent': quart.websocket.headers.get('User-Agent', ''),
|
||||
'host': quart.websocket.headers.get('Host', ''),
|
||||
'origin': quart.websocket.headers.get('Origin', ''),
|
||||
},
|
||||
)
|
||||
|
||||
if session_type not in ['person', 'group']:
|
||||
await quart.websocket.send(
|
||||
json.dumps({'type': 'error', 'message': 'session_type must be person or group'})
|
||||
)
|
||||
return
|
||||
|
||||
websocket_adapter = self.ap.platform_mgr.websocket_proxy_bot.adapter
|
||||
|
||||
if not websocket_adapter:
|
||||
logger.warning(
|
||||
'Workflow WebSocket adapter missing',
|
||||
extra={
|
||||
'workflow_uuid': workflow_uuid,
|
||||
'session_type': session_type,
|
||||
},
|
||||
)
|
||||
await quart.websocket.send(json.dumps({'type': 'error', 'message': 'WebSocket adapter not found'}))
|
||||
return
|
||||
|
||||
connection = await ws_connection_manager.add_connection(
|
||||
websocket=quart.websocket._get_current_object(),
|
||||
pipeline_uuid=workflow_uuid,
|
||||
session_type=session_type,
|
||||
metadata={'user_agent': quart.websocket.headers.get('User-Agent', ''), 'is_workflow': True},
|
||||
)
|
||||
|
||||
await quart.websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
'type': 'connected',
|
||||
'connection_id': connection.connection_id,
|
||||
'workflow_uuid': workflow_uuid,
|
||||
'session_type': session_type,
|
||||
'timestamp': connection.created_at.isoformat(),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f'Workflow WebSocket connection established: {connection.connection_id} '
|
||||
f'(workflow={workflow_uuid}, session_type={session_type})'
|
||||
)
|
||||
|
||||
receive_task = asyncio.create_task(self._handle_receive(connection, websocket_adapter))
|
||||
send_task = asyncio.create_task(self._handle_send(connection))
|
||||
|
||||
try:
|
||||
await asyncio.gather(receive_task, send_task)
|
||||
except Exception as e:
|
||||
logger.error(f'Workflow WebSocket task execution error: {e}')
|
||||
finally:
|
||||
await ws_connection_manager.remove_connection(connection.connection_id)
|
||||
logger.debug(f'Workflow WebSocket connection cleaned: {connection.connection_id}')
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'Workflow WebSocket connection error',
|
||||
exc_info=True,
|
||||
extra={
|
||||
'workflow_uuid': workflow_uuid,
|
||||
'session_type': quart.websocket.args.get('session_type', 'person'),
|
||||
'path': quart.websocket.path,
|
||||
'query_string': quart.websocket.query_string.decode('utf-8', errors='ignore'),
|
||||
'remote_addr': getattr(quart.websocket, 'remote_addr', None),
|
||||
},
|
||||
)
|
||||
try:
|
||||
await quart.websocket.send(json.dumps({'type': 'error', 'message': str(e)}))
|
||||
except Exception as send_error:
|
||||
logger.debug(
|
||||
'Failed to send error message to workflow websocket client',
|
||||
exc_info=True,
|
||||
extra={
|
||||
'workflow_uuid': workflow_uuid,
|
||||
'send_error': str(send_error),
|
||||
},
|
||||
)
|
||||
|
||||
@self.route('/messages/<session_type>', methods=['GET'])
|
||||
async def get_messages(workflow_uuid: str, session_type: str) -> str:
|
||||
"""获取工作流消息历史"""
|
||||
try:
|
||||
if session_type not in ['person', 'group']:
|
||||
return self.http_status(400, -1, 'session_type must be person or group')
|
||||
|
||||
websocket_adapter = self.ap.platform_mgr.websocket_proxy_bot.adapter
|
||||
|
||||
if not websocket_adapter:
|
||||
return self.http_status(404, -1, 'WebSocket adapter not found')
|
||||
|
||||
messages = websocket_adapter.get_websocket_messages(workflow_uuid, session_type)
|
||||
|
||||
return self.success(data={'messages': messages})
|
||||
|
||||
except Exception as e:
|
||||
return self.http_status(500, -1, f'Internal server error: {str(e)}')
|
||||
|
||||
@self.route('/reset/<session_type>', methods=['POST'])
|
||||
async def reset_session(workflow_uuid: str, session_type: str) -> str:
|
||||
"""重置工作流会话"""
|
||||
try:
|
||||
if session_type not in ['person', 'group']:
|
||||
return self.http_status(400, -1, 'session_type must be person or group')
|
||||
|
||||
websocket_adapter = self.ap.platform_mgr.websocket_proxy_bot.adapter
|
||||
|
||||
if not websocket_adapter:
|
||||
return self.http_status(404, -1, 'WebSocket adapter not found')
|
||||
|
||||
websocket_adapter.reset_session(workflow_uuid, session_type)
|
||||
|
||||
return self.success(data={'message': 'Session reset successfully'})
|
||||
|
||||
except Exception as e:
|
||||
return self.http_status(500, -1, f'Internal server error: {str(e)}')
|
||||
|
||||
@self.route('/connections', methods=['GET'])
|
||||
async def get_connections(workflow_uuid: str) -> str:
|
||||
"""获取当前工作流连接统计"""
|
||||
try:
|
||||
stats = ws_connection_manager.get_stats()
|
||||
connections = await ws_connection_manager.get_connections_by_pipeline(workflow_uuid)
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'stats': stats,
|
||||
'connections': [
|
||||
{
|
||||
'connection_id': conn.connection_id,
|
||||
'session_type': conn.session_type,
|
||||
'created_at': conn.created_at.isoformat(),
|
||||
'last_active': conn.last_active.isoformat(),
|
||||
'is_active': conn.is_active,
|
||||
}
|
||||
for conn in connections
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return self.http_status(500, -1, f'Internal server error: {str(e)}')
|
||||
|
||||
@self.route('/broadcast', methods=['POST'])
|
||||
async def broadcast_message(workflow_uuid: str) -> str:
|
||||
"""向所有工作流连接广播消息"""
|
||||
try:
|
||||
data = await quart.request.get_json()
|
||||
message = data.get('message')
|
||||
|
||||
if not message:
|
||||
return self.http_status(400, -1, 'message is required')
|
||||
|
||||
broadcast_data = {
|
||||
'type': 'broadcast',
|
||||
'message': message,
|
||||
'timestamp': datetime.datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
await ws_connection_manager.broadcast_to_pipeline(workflow_uuid, broadcast_data)
|
||||
|
||||
return self.success(data={'message': 'Broadcast sent successfully'})
|
||||
|
||||
except Exception as e:
|
||||
return self.http_status(500, -1, f'Internal server error: {str(e)}')
|
||||
|
||||
async def _handle_receive(self, connection, websocket_adapter):
|
||||
"""处理接收消息的任务"""
|
||||
try:
|
||||
while connection.is_active:
|
||||
message = await quart.websocket.receive()
|
||||
|
||||
await ws_connection_manager.update_activity(connection.connection_id)
|
||||
|
||||
try:
|
||||
data = json.loads(message)
|
||||
message_type = data.get('type', 'message')
|
||||
|
||||
if message_type == 'ping':
|
||||
await connection.send_queue.put(
|
||||
{'type': 'pong', 'timestamp': datetime.datetime.now().isoformat()}
|
||||
)
|
||||
|
||||
elif message_type == 'message':
|
||||
logger.debug(f'收到工作流消息: {data} from {connection.connection_id}')
|
||||
await websocket_adapter.handle_websocket_message(connection, data)
|
||||
|
||||
elif message_type == 'disconnect':
|
||||
logger.debug(f'Client disconnected: {connection.connection_id}')
|
||||
break
|
||||
|
||||
else:
|
||||
logger.warning(f'Unknown message type: {message_type}')
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f'Invalid JSON message: {message}')
|
||||
await connection.send_queue.put({'type': 'error', 'message': 'Invalid JSON format'})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'Receive message error: {e}', exc_info=True)
|
||||
finally:
|
||||
connection.is_active = False
|
||||
|
||||
async def _handle_send(self, connection):
|
||||
"""处理发送消息的任务"""
|
||||
try:
|
||||
while connection.is_active:
|
||||
try:
|
||||
message = await asyncio.wait_for(connection.send_queue.get(), timeout=1.0)
|
||||
await quart.websocket.send(json.dumps(message))
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'Send message error: {e}', exc_info=True)
|
||||
finally:
|
||||
connection.is_active = False
|
||||
@@ -0,0 +1,482 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import quart
|
||||
|
||||
from ... import group
|
||||
from ....service.workflow import WorkflowExecutionFailedError
|
||||
|
||||
|
||||
@group.group_class('workflows', '/api/v1/workflows')
|
||||
class WorkflowsRouterGroup(group.RouterGroup):
|
||||
"""Workflow API router group"""
|
||||
|
||||
async def initialize(self) -> None:
|
||||
# Workflow CRUD
|
||||
@self.route('', methods=['GET', 'POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _() -> str:
|
||||
if quart.request.method == 'GET':
|
||||
sort_by = quart.request.args.get('sort_by', 'created_at')
|
||||
sort_order = quart.request.args.get('sort_order', 'DESC')
|
||||
enabled_only = quart.request.args.get('enabled_only', 'false').lower() == 'true'
|
||||
return self.success(
|
||||
data={'workflows': await self.ap.workflow_service.get_workflows(sort_by, sort_order, enabled_only)}
|
||||
)
|
||||
elif quart.request.method == 'POST':
|
||||
json_data = await quart.request.json
|
||||
workflow_uuid = await self.ap.workflow_service.create_workflow(json_data)
|
||||
return self.success(data={'uuid': workflow_uuid})
|
||||
|
||||
# Get node types (available nodes for the editor)
|
||||
@self.route('/_/node-types', methods=['GET'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _() -> str:
|
||||
return self.success(
|
||||
data={
|
||||
'node_types': await self.ap.workflow_service.get_node_types(),
|
||||
'categories': await self.ap.workflow_service.get_node_types_by_category_meta(),
|
||||
}
|
||||
)
|
||||
|
||||
# Get node types by category
|
||||
@self.route('/_/node-types/categories', methods=['GET'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _() -> str:
|
||||
return self.success(data={'categories': await self.ap.workflow_service.get_node_types_by_category()})
|
||||
|
||||
# Single workflow operations
|
||||
@self.route(
|
||||
'/<workflow_uuid>', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY
|
||||
)
|
||||
async def _(workflow_uuid: str) -> str:
|
||||
if quart.request.method == 'GET':
|
||||
workflow = await self.ap.workflow_service.get_workflow(workflow_uuid)
|
||||
if workflow is None:
|
||||
return self.http_status(404, -1, 'workflow not found')
|
||||
return self.success(data={'workflow': workflow})
|
||||
elif quart.request.method == 'PUT':
|
||||
json_data = await quart.request.json
|
||||
try:
|
||||
await self.ap.workflow_service.update_workflow(workflow_uuid, json_data)
|
||||
return self.success()
|
||||
except ValueError as e:
|
||||
return self.http_status(404, -1, str(e))
|
||||
elif quart.request.method == 'DELETE':
|
||||
await self.ap.workflow_service.delete_workflow(workflow_uuid)
|
||||
return self.success()
|
||||
|
||||
# Publish workflow (enable)
|
||||
@self.route('/<workflow_uuid>/publish', methods=['POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _(workflow_uuid: str) -> str:
|
||||
try:
|
||||
await self.ap.workflow_service.publish_workflow(workflow_uuid)
|
||||
return self.success()
|
||||
except ValueError as e:
|
||||
return self.http_status(404, -1, str(e))
|
||||
|
||||
# Unpublish workflow (disable)
|
||||
@self.route('/<workflow_uuid>/unpublish', methods=['POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _(workflow_uuid: str) -> str:
|
||||
try:
|
||||
await self.ap.workflow_service.unpublish_workflow(workflow_uuid)
|
||||
return self.success()
|
||||
except ValueError as e:
|
||||
return self.http_status(404, -1, str(e))
|
||||
|
||||
# Copy workflow
|
||||
@self.route('/<workflow_uuid>/copy', methods=['POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _(workflow_uuid: str) -> str:
|
||||
try:
|
||||
new_uuid = await self.ap.workflow_service.copy_workflow(workflow_uuid)
|
||||
return self.success(data={'uuid': new_uuid})
|
||||
except ValueError as e:
|
||||
return self.http_status(404, -1, str(e))
|
||||
|
||||
# Execute workflow manually
|
||||
@self.route('/<workflow_uuid>/execute', methods=['POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _(workflow_uuid: str) -> str:
|
||||
json_data = await quart.request.json or {}
|
||||
trigger_data = json_data.get('trigger_data', {})
|
||||
session_id = json_data.get('session_id')
|
||||
user_id = json_data.get('user_id')
|
||||
bot_id = json_data.get('bot_id')
|
||||
|
||||
try:
|
||||
execution_id = await self.ap.workflow_service.execute_workflow(
|
||||
workflow_uuid,
|
||||
trigger_type='manual',
|
||||
trigger_data=trigger_data,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
bot_id=bot_id,
|
||||
)
|
||||
return self.success(data={'execution_id': execution_id})
|
||||
except ValueError as e:
|
||||
return self.http_status(404, -1, str(e))
|
||||
except WorkflowExecutionFailedError as e:
|
||||
return self.http_status(500, -1, e.message)
|
||||
|
||||
# Get workflow executions
|
||||
@self.route('/<workflow_uuid>/executions', methods=['GET'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _(workflow_uuid: str) -> str:
|
||||
limit = int(quart.request.args.get('limit', 50))
|
||||
offset = int(quart.request.args.get('offset', 0))
|
||||
executions = await self.ap.workflow_service.get_executions(
|
||||
workflow_uuid=workflow_uuid, limit=limit, offset=offset
|
||||
)
|
||||
return self.success(data=executions)
|
||||
|
||||
@self.route(
|
||||
'/<workflow_uuid>/executions/<execution_uuid>',
|
||||
methods=['GET'],
|
||||
auth_type=group.AuthType.USER_TOKEN_OR_API_KEY,
|
||||
)
|
||||
async def _(workflow_uuid: str, execution_uuid: str) -> str:
|
||||
execution = await self.ap.workflow_service.get_execution(execution_uuid)
|
||||
if execution is None:
|
||||
return self.http_status(404, -1, 'execution not found')
|
||||
if execution.get('workflow_uuid') != workflow_uuid:
|
||||
return self.http_status(404, -1, 'execution not found in workflow')
|
||||
return self.success(data={'execution': execution})
|
||||
|
||||
# Get workflow versions
|
||||
@self.route('/<workflow_uuid>/versions', methods=['GET'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _(workflow_uuid: str) -> str:
|
||||
versions = await self.ap.workflow_service.get_versions(workflow_uuid)
|
||||
return self.success(data={'versions': versions})
|
||||
|
||||
# Rollback to a specific version
|
||||
@self.route(
|
||||
'/<workflow_uuid>/rollback/<int:version>', methods=['POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY
|
||||
)
|
||||
async def _(workflow_uuid: str, version: int) -> str:
|
||||
try:
|
||||
await self.ap.workflow_service.rollback_to_version(workflow_uuid, version)
|
||||
return self.success()
|
||||
except ValueError as e:
|
||||
return self.http_status(404, -1, str(e))
|
||||
|
||||
# Workflow extensions (plugins and MCP servers)
|
||||
@self.route(
|
||||
'/<workflow_uuid>/extensions', methods=['GET', 'PUT'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY
|
||||
)
|
||||
async def _(workflow_uuid: str) -> str:
|
||||
if quart.request.method == 'GET':
|
||||
workflow = await self.ap.workflow_service.get_workflow(workflow_uuid)
|
||||
if workflow is None:
|
||||
return self.http_status(404, -1, 'workflow not found')
|
||||
|
||||
# Get available plugins and MCP servers
|
||||
pipeline_component_kinds = ['Command', 'EventListener', 'Tool']
|
||||
plugins = await self.ap.plugin_connector.list_plugins(component_kinds=pipeline_component_kinds)
|
||||
mcp_servers = await self.ap.mcp_service.get_mcp_servers(contain_runtime_info=True)
|
||||
|
||||
extensions_prefs = workflow.get('extensions_preferences', {})
|
||||
return self.success(
|
||||
data={
|
||||
'enable_all_plugins': extensions_prefs.get('enable_all_plugins', True),
|
||||
'enable_all_mcp_servers': extensions_prefs.get('enable_all_mcp_servers', True),
|
||||
'bound_plugins': extensions_prefs.get('plugins', []),
|
||||
'available_plugins': plugins,
|
||||
'bound_mcp_servers': extensions_prefs.get('mcp_servers', []),
|
||||
'available_mcp_servers': mcp_servers,
|
||||
}
|
||||
)
|
||||
elif quart.request.method == 'PUT':
|
||||
json_data = await quart.request.json
|
||||
enable_all_plugins = json_data.get('enable_all_plugins', True)
|
||||
enable_all_mcp_servers = json_data.get('enable_all_mcp_servers', True)
|
||||
bound_plugins = json_data.get('bound_plugins', [])
|
||||
bound_mcp_servers = json_data.get('bound_mcp_servers', [])
|
||||
|
||||
try:
|
||||
await self.ap.workflow_service.update_workflow_extensions(
|
||||
workflow_uuid, bound_plugins, bound_mcp_servers, enable_all_plugins, enable_all_mcp_servers
|
||||
)
|
||||
return self.success()
|
||||
except ValueError as e:
|
||||
return self.http_status(404, -1, str(e))
|
||||
|
||||
# Debug API - Start debug execution
|
||||
@self.route('/<workflow_uuid>/debug/start', methods=['POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _(workflow_uuid: str) -> str:
|
||||
json_data = await quart.request.json or {}
|
||||
context = json_data.get('context', {})
|
||||
variables = json_data.get('variables', {})
|
||||
breakpoints = json_data.get('breakpoints', [])
|
||||
|
||||
try:
|
||||
execution_id = await self.ap.workflow_service.start_debug_execution(
|
||||
workflow_uuid, context=context, variables=variables, breakpoints=breakpoints
|
||||
)
|
||||
return self.success(data={'execution_id': execution_id})
|
||||
except ValueError as e:
|
||||
return self.http_status(404, -1, str(e))
|
||||
|
||||
# Debug API - Pause execution
|
||||
@self.route(
|
||||
'/<workflow_uuid>/debug/<execution_uuid>/pause',
|
||||
methods=['POST'],
|
||||
auth_type=group.AuthType.USER_TOKEN_OR_API_KEY,
|
||||
)
|
||||
async def _(workflow_uuid: str, execution_uuid: str) -> str:
|
||||
try:
|
||||
await self.ap.workflow_service.pause_debug_execution(workflow_uuid, execution_uuid)
|
||||
return self.success()
|
||||
except ValueError as e:
|
||||
return self.http_status(404, -1, str(e))
|
||||
|
||||
# Debug API - Resume execution
|
||||
@self.route(
|
||||
'/<workflow_uuid>/debug/<execution_uuid>/resume',
|
||||
methods=['POST'],
|
||||
auth_type=group.AuthType.USER_TOKEN_OR_API_KEY,
|
||||
)
|
||||
async def _(workflow_uuid: str, execution_uuid: str) -> str:
|
||||
try:
|
||||
await self.ap.workflow_service.resume_debug_execution(workflow_uuid, execution_uuid)
|
||||
return self.success()
|
||||
except ValueError as e:
|
||||
return self.http_status(404, -1, str(e))
|
||||
|
||||
# Debug API - Step execution
|
||||
@self.route(
|
||||
'/<workflow_uuid>/debug/<execution_uuid>/step',
|
||||
methods=['POST'],
|
||||
auth_type=group.AuthType.USER_TOKEN_OR_API_KEY,
|
||||
)
|
||||
async def _(workflow_uuid: str, execution_uuid: str) -> str:
|
||||
try:
|
||||
result = await self.ap.workflow_service.step_debug_execution(workflow_uuid, execution_uuid)
|
||||
return self.success(data=result)
|
||||
except ValueError as e:
|
||||
return self.http_status(404, -1, str(e))
|
||||
|
||||
# Debug API - Stop execution
|
||||
@self.route(
|
||||
'/<workflow_uuid>/debug/<execution_uuid>/stop',
|
||||
methods=['POST'],
|
||||
auth_type=group.AuthType.USER_TOKEN_OR_API_KEY,
|
||||
)
|
||||
async def _(workflow_uuid: str, execution_uuid: str) -> str:
|
||||
try:
|
||||
await self.ap.workflow_service.stop_debug_execution(workflow_uuid, execution_uuid)
|
||||
return self.success()
|
||||
except ValueError as e:
|
||||
return self.http_status(404, -1, str(e))
|
||||
|
||||
# Debug API - Get debug state
|
||||
@self.route(
|
||||
'/<workflow_uuid>/debug/<execution_uuid>/state',
|
||||
methods=['GET'],
|
||||
auth_type=group.AuthType.USER_TOKEN_OR_API_KEY,
|
||||
)
|
||||
async def _(workflow_uuid: str, execution_uuid: str) -> str:
|
||||
try:
|
||||
state = await self.ap.workflow_service.get_debug_state(workflow_uuid, execution_uuid)
|
||||
return self.success(data=state)
|
||||
except ValueError as e:
|
||||
return self.http_status(404, -1, str(e))
|
||||
|
||||
# Get execution logs
|
||||
@self.route(
|
||||
'/<workflow_uuid>/executions/<execution_uuid>/logs',
|
||||
methods=['GET'],
|
||||
auth_type=group.AuthType.USER_TOKEN_OR_API_KEY,
|
||||
)
|
||||
async def _(workflow_uuid: str, execution_uuid: str) -> str:
|
||||
limit = int(quart.request.args.get('limit', 100))
|
||||
offset = int(quart.request.args.get('offset', 0))
|
||||
try:
|
||||
result = await self.ap.workflow_service.get_execution_logs(workflow_uuid, execution_uuid, limit, offset)
|
||||
return self.success(data=result)
|
||||
except ValueError as e:
|
||||
return self.http_status(404, -1, str(e))
|
||||
|
||||
# Rerun execution
|
||||
@self.route(
|
||||
'/<workflow_uuid>/executions/<execution_uuid>/rerun',
|
||||
methods=['POST'],
|
||||
auth_type=group.AuthType.USER_TOKEN_OR_API_KEY,
|
||||
)
|
||||
async def _(workflow_uuid: str, execution_uuid: str) -> str:
|
||||
try:
|
||||
new_execution_id = await self.ap.workflow_service.rerun_execution(workflow_uuid, execution_uuid)
|
||||
return self.success(data={'execution_uuid': new_execution_id})
|
||||
except ValueError as e:
|
||||
return self.http_status(404, -1, str(e))
|
||||
|
||||
# Get workflow statistics
|
||||
@self.route('/<workflow_uuid>/stats', methods=['GET'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _(workflow_uuid: str) -> str:
|
||||
try:
|
||||
stats = await self.ap.workflow_service.get_workflow_stats(workflow_uuid)
|
||||
return self.success(data=stats)
|
||||
except ValueError as e:
|
||||
return self.http_status(404, -1, str(e))
|
||||
|
||||
# LLM Node Performance Test Endpoint
|
||||
# Tests each step of LLM node execution with detailed timing
|
||||
@self.route('/_/test/llm-node', methods=['POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _() -> str:
|
||||
"""Test LLM node performance with detailed step-by-step timing.
|
||||
|
||||
Request body:
|
||||
{
|
||||
"model_uuid": "uuid-of-model",
|
||||
"system_prompt": "optional system prompt",
|
||||
"user_prompt": "test message",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 100
|
||||
}
|
||||
|
||||
Response includes timing for each step:
|
||||
- model_fetch: Time to get model from model_mgr
|
||||
- prompt_build: Time to build messages
|
||||
- llm_call: Time for actual LLM invocation
|
||||
- total: Total time
|
||||
- usage: Token usage information
|
||||
"""
|
||||
import time
|
||||
|
||||
json_data = await quart.request.json
|
||||
if not json_data:
|
||||
return self.http_status(400, -1, 'Request body is required')
|
||||
|
||||
model_uuid = json_data.get('model_uuid', '')
|
||||
if not model_uuid:
|
||||
return self.http_status(400, -1, 'model_uuid is required')
|
||||
|
||||
user_prompt = json_data.get('user_prompt', 'test')
|
||||
system_prompt = json_data.get('system_prompt', '')
|
||||
temperature = json_data.get('temperature')
|
||||
max_tokens = json_data.get('max_tokens', 0)
|
||||
|
||||
timings = {}
|
||||
errors = []
|
||||
|
||||
# Step 1: Model fetch
|
||||
t_start = time.perf_counter()
|
||||
try:
|
||||
runtime_model = await self.ap.model_mgr.get_model_by_uuid(model_uuid)
|
||||
timings['model_fetch_ms'] = round((time.perf_counter() - t_start) * 1000, 2)
|
||||
timings['model_found'] = True
|
||||
timings['model_name'] = runtime_model.model_entity.name if runtime_model else None
|
||||
except Exception as e:
|
||||
timings['model_fetch_ms'] = round((time.perf_counter() - t_start) * 1000, 2)
|
||||
timings['model_found'] = False
|
||||
errors.append(f'Model fetch failed: {str(e)}')
|
||||
return self.http_status(400, -1, {
|
||||
'error': errors[0],
|
||||
'timings': timings,
|
||||
})
|
||||
|
||||
# Step 2: Build messages
|
||||
t_start = time.perf_counter()
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append(provider_message.Message(role='system', content=system_prompt))
|
||||
messages.append(provider_message.Message(role='user', content=user_prompt))
|
||||
timings['prompt_build_ms'] = round((time.perf_counter() - t_start) * 1000, 2)
|
||||
|
||||
# Step 3: Build extra args
|
||||
extra_args = {}
|
||||
if temperature is not None:
|
||||
extra_args['temperature'] = float(temperature)
|
||||
if max_tokens and int(max_tokens) > 0:
|
||||
extra_args['max_tokens'] = int(max_tokens)
|
||||
|
||||
# Step 4: LLM call
|
||||
t_start = time.perf_counter()
|
||||
try:
|
||||
result_message = await runtime_model.provider.invoke_llm(
|
||||
query=None,
|
||||
model=runtime_model,
|
||||
messages=messages,
|
||||
funcs=None,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
timings['llm_call_ms'] = round((time.perf_counter() - t_start) * 1000, 2)
|
||||
timings['llm_call_success'] = True
|
||||
|
||||
# Extract response text
|
||||
response_text = ''
|
||||
if isinstance(result_message.content, str):
|
||||
response_text = result_message.content
|
||||
elif isinstance(result_message.content, list):
|
||||
for elem in result_message.content:
|
||||
if hasattr(elem, 'text') and elem.text:
|
||||
response_text += elem.text
|
||||
elif isinstance(elem, str):
|
||||
response_text += elem
|
||||
|
||||
timings['response_length'] = len(response_text)
|
||||
timings['response_preview'] = response_text[:200]
|
||||
|
||||
# Extract usage
|
||||
usage = {'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0}
|
||||
if hasattr(result_message, 'usage') and result_message.usage:
|
||||
u = result_message.usage
|
||||
usage = {
|
||||
'prompt_tokens': getattr(u, 'prompt_tokens', 0) or 0,
|
||||
'completion_tokens': getattr(u, 'completion_tokens', 0) or 0,
|
||||
'total_tokens': getattr(u, 'total_tokens', 0) or 0,
|
||||
}
|
||||
timings['usage'] = usage
|
||||
|
||||
except Exception as e:
|
||||
timings['llm_call_ms'] = round((time.perf_counter() - t_start) * 1000, 2)
|
||||
timings['llm_call_success'] = False
|
||||
errors.append(f'LLM call failed: {str(e)}')
|
||||
|
||||
# Calculate total
|
||||
timings['total_ms'] = round(sum([
|
||||
timings.get('model_fetch_ms', 0),
|
||||
timings.get('prompt_build_ms', 0),
|
||||
timings.get('llm_call_ms', 0),
|
||||
]), 2)
|
||||
|
||||
# Add breakdown percentage
|
||||
if timings['total_ms'] > 0:
|
||||
timings['breakdown'] = {
|
||||
'model_fetch_pct': round(timings.get('model_fetch_ms', 0) / timings['total_ms'] * 100, 1),
|
||||
'prompt_build_pct': round(timings.get('prompt_build_ms', 0) / timings['total_ms'] * 100, 1),
|
||||
'llm_call_pct': round(timings.get('llm_call_ms', 0) / timings['total_ms'] * 100, 1),
|
||||
}
|
||||
|
||||
if errors:
|
||||
timings['errors'] = errors
|
||||
|
||||
return self.success(data={'test_result': timings})
|
||||
|
||||
|
||||
@group.group_class('executions', '/api/v1/executions')
|
||||
class ExecutionsRouterGroup(group.RouterGroup):
|
||||
"""Workflow execution API router group"""
|
||||
|
||||
async def initialize(self) -> None:
|
||||
# Get all executions (across all workflows)
|
||||
@self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _() -> str:
|
||||
limit = int(quart.request.args.get('limit', 50))
|
||||
offset = int(quart.request.args.get('offset', 0))
|
||||
status = quart.request.args.get('status')
|
||||
executions = await self.ap.workflow_service.get_executions(limit=limit, offset=offset, status=status)
|
||||
return self.success(data=executions)
|
||||
|
||||
# Get single execution
|
||||
@self.route('/<execution_uuid>', methods=['GET'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _(execution_uuid: str) -> str:
|
||||
execution = await self.ap.workflow_service.get_execution(execution_uuid)
|
||||
if execution is None:
|
||||
return self.http_status(404, -1, 'execution not found')
|
||||
return self.success(data={'execution': execution})
|
||||
|
||||
# Cancel execution
|
||||
@self.route('/<execution_uuid>/cancel', methods=['POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _(execution_uuid: str) -> str:
|
||||
try:
|
||||
await self.ap.workflow_service.cancel_execution(execution_uuid)
|
||||
return self.success()
|
||||
except ValueError as e:
|
||||
return self.http_status(404, -1, str(e))
|
||||
except RuntimeError as e:
|
||||
return self.http_status(400, -1, str(e))
|
||||
@@ -17,6 +17,7 @@ from .groups import platform as groups_platform
|
||||
from .groups import pipelines as groups_pipelines
|
||||
from .groups import knowledge as groups_knowledge
|
||||
from .groups import resources as groups_resources
|
||||
from .groups import workflows as groups_workflows
|
||||
|
||||
importutil.import_modules_in_pkg(groups)
|
||||
importutil.import_modules_in_pkg(groups_provider)
|
||||
@@ -24,6 +25,7 @@ importutil.import_modules_in_pkg(groups_platform)
|
||||
importutil.import_modules_in_pkg(groups_pipelines)
|
||||
importutil.import_modules_in_pkg(groups_knowledge)
|
||||
importutil.import_modules_in_pkg(groups_resources)
|
||||
importutil.import_modules_in_pkg(groups_workflows)
|
||||
|
||||
|
||||
class HTTPController:
|
||||
@@ -105,6 +107,29 @@ class HTTPController:
|
||||
):
|
||||
if os.path.exists(os.path.join(frontend_path, path + '.html')):
|
||||
path += '.html'
|
||||
elif not path.startswith('api/'):
|
||||
# SPA fallback: serve index.html for all non-API, non-static routes
|
||||
# so that React Router can handle client-side routing (Vite SPA).
|
||||
# For /home/* sub-routes, first try parent .html files (pre-rendered pages).
|
||||
if path.startswith('home/'):
|
||||
segments = path.rstrip('/').split('/')
|
||||
for i in range(len(segments) - 1, 0, -1):
|
||||
parent_path = '/'.join(segments[:i]) + '.html'
|
||||
if os.path.exists(os.path.join(frontend_path, parent_path)):
|
||||
response = await quart.send_from_directory(
|
||||
frontend_path, parent_path, mimetype='text/html'
|
||||
)
|
||||
response.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate'
|
||||
response.headers['Pragma'] = 'no-cache'
|
||||
response.headers['Expires'] = '0'
|
||||
return response
|
||||
|
||||
# Fallback to index.html for SPA client-side routing
|
||||
response = await quart.send_from_directory(frontend_path, 'index.html', mimetype='text/html')
|
||||
response.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate'
|
||||
response.headers['Pragma'] = 'no-cache'
|
||||
response.headers['Expires'] = '0'
|
||||
return response
|
||||
else:
|
||||
return await quart.send_from_directory(frontend_path, '404.html')
|
||||
|
||||
|
||||
@@ -52,6 +52,9 @@ class ApiKeyService:
|
||||
|
||||
async def verify_api_key(self, key: str) -> bool:
|
||||
"""Verify if an API key is valid"""
|
||||
if not isinstance(key, str) or not key.startswith('lbk_'):
|
||||
return False
|
||||
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(apikey.ApiKey).where(apikey.ApiKey.key == key)
|
||||
)
|
||||
|
||||
@@ -70,12 +70,17 @@ class BotService:
|
||||
'lark',
|
||||
]:
|
||||
webhook_prefix = self.ap.instance_config.data['api'].get('webhook_prefix', 'http://127.0.0.1:5300')
|
||||
extra_webhook_prefix = self.ap.instance_config.data['api'].get('extra_webhook_prefix', '')
|
||||
webhook_url = f'/bots/{bot_uuid}'
|
||||
adapter_runtime_values['webhook_url'] = webhook_url
|
||||
adapter_runtime_values['webhook_full_url'] = f'{webhook_prefix}{webhook_url}'
|
||||
adapter_runtime_values['extra_webhook_full_url'] = (
|
||||
f'{extra_webhook_prefix}{webhook_url}' if extra_webhook_prefix else ''
|
||||
)
|
||||
else:
|
||||
adapter_runtime_values['webhook_url'] = None
|
||||
adapter_runtime_values['webhook_full_url'] = None
|
||||
adapter_runtime_values['extra_webhook_full_url'] = None
|
||||
|
||||
persistence_bot['adapter_runtime_values'] = adapter_runtime_values
|
||||
|
||||
@@ -94,7 +99,11 @@ class BotService:
|
||||
# TODO: 检查配置信息格式
|
||||
bot_data['uuid'] = str(uuid.uuid4())
|
||||
|
||||
# checkout the default pipeline
|
||||
# Set default binding_type if not provided
|
||||
if 'binding_type' not in bot_data:
|
||||
bot_data['binding_type'] = 'pipeline'
|
||||
|
||||
# checkout the default pipeline (for backward compatibility)
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
|
||||
persistence_pipeline.LegacyPipeline.is_default == True
|
||||
@@ -104,6 +113,9 @@ class BotService:
|
||||
if pipeline is not None:
|
||||
bot_data['use_pipeline_uuid'] = pipeline.uuid
|
||||
bot_data['use_pipeline_name'] = pipeline.name
|
||||
# Also set binding_uuid for new unified binding model
|
||||
if 'binding_uuid' not in bot_data:
|
||||
bot_data['binding_uuid'] = pipeline.uuid
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_bot.Bot).values(bot_data))
|
||||
|
||||
@@ -118,7 +130,11 @@ class BotService:
|
||||
if 'uuid' in bot_data:
|
||||
del bot_data['uuid']
|
||||
|
||||
# set use_pipeline_name
|
||||
# Handle binding_type and binding_uuid for the new unified binding model
|
||||
# If binding_type is explicitly set to 'workflow', skip pipeline validation
|
||||
binding_type = bot_data.get('binding_type')
|
||||
|
||||
# set use_pipeline_name (for backward compatibility with 'pipeline' binding_type)
|
||||
if 'use_pipeline_uuid' in bot_data:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
|
||||
@@ -128,9 +144,19 @@ class BotService:
|
||||
pipeline = result.first()
|
||||
if pipeline is not None:
|
||||
bot_data['use_pipeline_name'] = pipeline.name
|
||||
# Also sync to binding_uuid if binding_type is 'pipeline' or not set
|
||||
if binding_type is None or binding_type == 'pipeline':
|
||||
bot_data['binding_uuid'] = bot_data['use_pipeline_uuid']
|
||||
bot_data['binding_type'] = 'pipeline'
|
||||
else:
|
||||
raise Exception('Pipeline not found')
|
||||
|
||||
# If binding_uuid is set directly (for workflow), sync use_pipeline_uuid for backward compatibility
|
||||
if 'binding_uuid' in bot_data and binding_type == 'workflow':
|
||||
# For workflow binding, we don't sync to use_pipeline_uuid
|
||||
# but we ensure binding_type is correctly set
|
||||
bot_data['binding_type'] = 'workflow'
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_bot.Bot).values(bot_data).where(persistence_bot.Bot.uuid == bot_uuid)
|
||||
)
|
||||
|
||||
@@ -31,15 +31,126 @@ class KnowledgeService:
|
||||
if not knowledge_engine_plugin_id:
|
||||
raise ValueError('knowledge_engine_plugin_id is required')
|
||||
|
||||
creation_settings = kb_data.get('creation_settings', {})
|
||||
retrieval_settings = kb_data.get('retrieval_settings', {})
|
||||
|
||||
# Validate required fields based on plugin's creation_schema and retrieval_schema
|
||||
await self._validate_schema_required_fields(
|
||||
knowledge_engine_plugin_id,
|
||||
creation_settings,
|
||||
retrieval_settings,
|
||||
)
|
||||
|
||||
kb = await self.ap.rag_mgr.create_knowledge_base(
|
||||
name=kb_data.get('name', 'Untitled'),
|
||||
knowledge_engine_plugin_id=knowledge_engine_plugin_id,
|
||||
creation_settings=kb_data.get('creation_settings', {}),
|
||||
retrieval_settings=kb_data.get('retrieval_settings', {}),
|
||||
creation_settings=creation_settings,
|
||||
retrieval_settings=retrieval_settings,
|
||||
description=kb_data.get('description', ''),
|
||||
)
|
||||
return kb.uuid
|
||||
|
||||
async def _validate_schema_required_fields(
|
||||
self,
|
||||
plugin_id: str,
|
||||
creation_settings: dict,
|
||||
retrieval_settings: dict,
|
||||
) -> None:
|
||||
"""Validate required fields based on plugin's creation_schema and retrieval_schema.
|
||||
|
||||
This is a business-agnostic validation that checks all fields marked as
|
||||
required in the plugin's schema, regardless of field type.
|
||||
|
||||
Args:
|
||||
plugin_id: Knowledge Engine plugin ID.
|
||||
creation_settings: User-provided creation settings.
|
||||
retrieval_settings: User-provided retrieval settings.
|
||||
|
||||
Raises:
|
||||
ValueError: If any required field is missing or empty.
|
||||
"""
|
||||
# Validate creation_schema
|
||||
try:
|
||||
creation_schema = await self.ap.plugin_connector.get_rag_creation_schema(plugin_id)
|
||||
self._check_required_fields(creation_schema, creation_settings, 'creation_settings')
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(f'Failed to get creation_schema for validation: {e}')
|
||||
|
||||
# Validate retrieval_schema
|
||||
try:
|
||||
retrieval_schema = await self.ap.plugin_connector.get_rag_retrieval_schema(plugin_id)
|
||||
self._check_required_fields(retrieval_schema, retrieval_settings, 'retrieval_settings')
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(f'Failed to get retrieval_schema for validation: {e}')
|
||||
|
||||
def _check_required_fields(
|
||||
self,
|
||||
schema: dict | list,
|
||||
settings: dict,
|
||||
context: str,
|
||||
) -> None:
|
||||
"""Check required fields in schema against provided settings.
|
||||
|
||||
Args:
|
||||
schema: Plugin-defined schema (can be list or dict with 'schema' key).
|
||||
settings: User-provided settings values.
|
||||
context: Context name for error messages (e.g., 'creation_settings').
|
||||
|
||||
Raises:
|
||||
ValueError: If a required field is missing or empty.
|
||||
"""
|
||||
if not schema:
|
||||
return
|
||||
|
||||
# schema can be a list directly, or a dict with 'schema' key
|
||||
items = schema if isinstance(schema, list) else schema.get('schema', [])
|
||||
if not items:
|
||||
return
|
||||
|
||||
for item in items:
|
||||
field_name = item.get('name')
|
||||
if not field_name:
|
||||
continue
|
||||
|
||||
is_required = item.get('required', False)
|
||||
if not is_required:
|
||||
continue
|
||||
|
||||
# Check show_if condition - if field is conditionally shown, only validate when condition is met
|
||||
show_if = item.get('show_if')
|
||||
if show_if:
|
||||
depend_field = show_if.get('field')
|
||||
operator = show_if.get('operator')
|
||||
expected_value = show_if.get('value')
|
||||
|
||||
if depend_field and operator:
|
||||
depend_value = settings.get(depend_field)
|
||||
# If show_if condition is not met, skip validation for this field
|
||||
if operator == 'eq' and depend_value != expected_value:
|
||||
continue
|
||||
if operator == 'neq' and depend_value == expected_value:
|
||||
continue
|
||||
if operator == 'in' and isinstance(expected_value, list) and depend_value not in expected_value:
|
||||
continue
|
||||
|
||||
value = settings.get(field_name)
|
||||
|
||||
# Validate required field has a non-empty value
|
||||
if value is None or (isinstance(value, str) and value.strip() == ''):
|
||||
# Get field label for friendly error message
|
||||
label = item.get('label', {})
|
||||
field_label = (
|
||||
label.get('en_US', field_name)
|
||||
or label.get('zh_Hans', field_name)
|
||||
or label.get('zh_Hant', field_name)
|
||||
or field_name
|
||||
)
|
||||
raise ValueError(f'{field_label} is required ({context}.{field_name})')
|
||||
|
||||
async def update_knowledge_base(self, kb_uuid: str, kb_data: dict) -> None:
|
||||
"""更新知识库"""
|
||||
# Filter to only mutable fields
|
||||
|
||||
309
src/langbot/pkg/api/http/service/maintenance.py
Normal file
309
src/langbot/pkg/api/http/service/maintenance.py
Normal file
@@ -0,0 +1,309 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
from ....core import app
|
||||
from ....entity.persistence import bstorage as persistence_bstorage
|
||||
from ....entity.persistence import monitoring as persistence_monitoring
|
||||
|
||||
|
||||
LOG_FILE_PATTERN = re.compile(r'^langbot-(\d{4}-\d{2}-\d{2})\.log(?:\.\d+)?$')
|
||||
DEFAULT_UPLOAD_FILE_RETENTION_DAYS = 7
|
||||
DEFAULT_LOG_RETENTION_DAYS = 3
|
||||
|
||||
|
||||
class MaintenanceService:
|
||||
"""Storage maintenance and diagnostics."""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application) -> None:
|
||||
self.ap = ap
|
||||
|
||||
async def cleanup_expired_files(self) -> dict[str, int]:
|
||||
cleanup_cfg = self.ap.instance_config.data.get('storage', {}).get('cleanup', {})
|
||||
upload_retention_days = self._positive_int(
|
||||
cleanup_cfg.get('uploaded_file_retention_days'),
|
||||
DEFAULT_UPLOAD_FILE_RETENTION_DAYS,
|
||||
'storage.cleanup.uploaded_file_retention_days',
|
||||
)
|
||||
log_retention_days = self._positive_int(
|
||||
cleanup_cfg.get('log_retention_days'),
|
||||
DEFAULT_LOG_RETENTION_DAYS,
|
||||
'storage.cleanup.log_retention_days',
|
||||
)
|
||||
|
||||
return {
|
||||
'uploaded_files': await self._cleanup_expired_uploaded_files(upload_retention_days),
|
||||
'log_files': self._cleanup_expired_log_files(log_retention_days),
|
||||
}
|
||||
|
||||
async def get_storage_analysis(self) -> dict[str, Any]:
|
||||
cleanup_cfg = self.ap.instance_config.data.get('storage', {}).get('cleanup', {})
|
||||
upload_retention_days = self._positive_int(
|
||||
cleanup_cfg.get('uploaded_file_retention_days'),
|
||||
DEFAULT_UPLOAD_FILE_RETENTION_DAYS,
|
||||
'storage.cleanup.uploaded_file_retention_days',
|
||||
)
|
||||
log_retention_days = self._positive_int(
|
||||
cleanup_cfg.get('log_retention_days'),
|
||||
DEFAULT_LOG_RETENTION_DAYS,
|
||||
'storage.cleanup.log_retention_days',
|
||||
)
|
||||
|
||||
database_cfg = self.ap.instance_config.data.get('database', {})
|
||||
database_type = database_cfg.get('use', 'sqlite')
|
||||
database_path = (
|
||||
Path(database_cfg.get('sqlite', {}).get('path', 'data/langbot.db')) if database_type == 'sqlite' else None
|
||||
)
|
||||
roots: list[tuple[str, Path | None]] = [
|
||||
('database', database_path),
|
||||
('logs', Path('data/logs')),
|
||||
('storage', Path('data/storage')),
|
||||
('vector_store', Path('data/chroma')),
|
||||
('plugins', Path('data/plugins')),
|
||||
('mcp', Path('data/mcp')),
|
||||
('temp', Path('data/temp')),
|
||||
]
|
||||
|
||||
sections = []
|
||||
for key, path in roots:
|
||||
sections.append(
|
||||
{
|
||||
'key': key,
|
||||
'path': str(path) if path else '',
|
||||
'exists': path.exists() if path else False,
|
||||
'size_bytes': self._path_size(path) if path else 0,
|
||||
'file_count': self._file_count(path) if path else 0,
|
||||
}
|
||||
)
|
||||
|
||||
monitoring_counts = await self._monitoring_counts()
|
||||
binary_storage = await self._binary_storage_stats()
|
||||
upload_candidates = await self._expired_uploaded_candidates(upload_retention_days)
|
||||
log_candidates = self._expired_log_candidates(log_retention_days)
|
||||
|
||||
return {
|
||||
'generated_at': datetime.datetime.now(datetime.timezone.utc).isoformat(),
|
||||
'cleanup_policy': {
|
||||
'uploaded_file_retention_days': upload_retention_days,
|
||||
'log_retention_days': log_retention_days,
|
||||
},
|
||||
'sections': sections,
|
||||
'database': {
|
||||
'type': database_type,
|
||||
'monitoring_counts': monitoring_counts,
|
||||
'binary_storage': binary_storage,
|
||||
},
|
||||
'cleanup_candidates': {
|
||||
'uploaded_files': upload_candidates,
|
||||
'log_files': log_candidates,
|
||||
},
|
||||
'tasks': self.ap.task_mgr.get_stats() if self.ap.task_mgr else {},
|
||||
}
|
||||
|
||||
async def _cleanup_expired_uploaded_files(self, retention_days: int) -> int:
|
||||
provider = self.ap.storage_mgr.storage_provider
|
||||
provider_name = provider.__class__.__name__
|
||||
if provider_name == 'LocalStorageProvider':
|
||||
candidates = self._expired_local_upload_candidates(retention_days, include_paths=True)
|
||||
deleted = 0
|
||||
for item in candidates:
|
||||
try:
|
||||
os.remove(item['path'])
|
||||
deleted += 1
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(f'Failed to delete expired uploaded file {item["key"]}: {e}')
|
||||
return deleted
|
||||
|
||||
if provider_name == 'S3StorageProvider':
|
||||
return await self._cleanup_expired_s3_uploaded_files(retention_days)
|
||||
|
||||
return 0
|
||||
|
||||
async def _expired_uploaded_candidates(self, retention_days: int) -> list[dict[str, Any]]:
|
||||
provider_name = self.ap.storage_mgr.storage_provider.__class__.__name__
|
||||
if provider_name == 'LocalStorageProvider':
|
||||
return self._expired_local_upload_candidates(retention_days)
|
||||
if provider_name == 'S3StorageProvider':
|
||||
return await self._expired_s3_upload_candidates(retention_days)
|
||||
return []
|
||||
|
||||
async def _cleanup_expired_s3_uploaded_files(self, retention_days: int) -> int:
|
||||
provider = self.ap.storage_mgr.storage_provider
|
||||
candidates = await self._expired_s3_upload_candidates(retention_days)
|
||||
deleted = 0
|
||||
for item in candidates:
|
||||
await provider.delete(item['key'])
|
||||
deleted += 1
|
||||
return deleted
|
||||
|
||||
async def _expired_s3_upload_candidates(self, retention_days: int) -> list[dict[str, Any]]:
|
||||
provider = self.ap.storage_mgr.storage_provider
|
||||
cutoff = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(days=retention_days)
|
||||
candidates = []
|
||||
paginator = provider.s3_client.get_paginator('list_objects_v2')
|
||||
|
||||
for page in paginator.paginate(Bucket=provider.bucket_name):
|
||||
for obj in page.get('Contents', []):
|
||||
key = obj.get('Key', '')
|
||||
last_modified = obj.get('LastModified')
|
||||
if not self._is_uploaded_file_key(key):
|
||||
continue
|
||||
if last_modified and last_modified < cutoff:
|
||||
candidates.append(
|
||||
{
|
||||
'key': key,
|
||||
'size_bytes': obj.get('Size', 0),
|
||||
'modified_at': last_modified.isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
return candidates
|
||||
|
||||
def _cleanup_expired_log_files(self, retention_days: int) -> int:
|
||||
deleted = 0
|
||||
for item in self._expired_log_candidates(retention_days, include_paths=True):
|
||||
try:
|
||||
os.remove(item['path'])
|
||||
deleted += 1
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(f'Failed to delete expired log file {item["name"]}: {e}')
|
||||
return deleted
|
||||
|
||||
def _expired_local_upload_candidates(
|
||||
self, retention_days: int, include_paths: bool = False
|
||||
) -> list[dict[str, Any]]:
|
||||
storage_root = Path('data/storage')
|
||||
if not storage_root.exists():
|
||||
return []
|
||||
|
||||
cutoff = datetime.datetime.now().timestamp() - retention_days * 86400
|
||||
candidates = []
|
||||
for entry in storage_root.iterdir():
|
||||
if not entry.is_file() or not self._is_uploaded_file_key(entry.name):
|
||||
continue
|
||||
stat = entry.stat()
|
||||
if stat.st_mtime >= cutoff:
|
||||
continue
|
||||
item = {
|
||||
'key': entry.name,
|
||||
'size_bytes': stat.st_size,
|
||||
'modified_at': datetime.datetime.fromtimestamp(stat.st_mtime, datetime.timezone.utc).isoformat(),
|
||||
}
|
||||
if include_paths:
|
||||
item['path'] = str(entry)
|
||||
candidates.append(item)
|
||||
return candidates
|
||||
|
||||
def _expired_log_candidates(self, retention_days: int, include_paths: bool = False) -> list[dict[str, Any]]:
|
||||
log_root = Path('data/logs')
|
||||
if not log_root.exists():
|
||||
return []
|
||||
|
||||
cutoff_date = datetime.date.today() - datetime.timedelta(days=retention_days - 1)
|
||||
candidates = []
|
||||
for entry in log_root.iterdir():
|
||||
if not entry.is_file():
|
||||
continue
|
||||
match = LOG_FILE_PATTERN.match(entry.name)
|
||||
if not match:
|
||||
continue
|
||||
try:
|
||||
file_date = datetime.date.fromisoformat(match.group(1))
|
||||
except ValueError:
|
||||
continue
|
||||
if file_date >= cutoff_date:
|
||||
continue
|
||||
stat = entry.stat()
|
||||
item = {
|
||||
'name': entry.name,
|
||||
'date': file_date.isoformat(),
|
||||
'size_bytes': stat.st_size,
|
||||
}
|
||||
if include_paths:
|
||||
item['path'] = str(entry)
|
||||
candidates.append(item)
|
||||
return candidates
|
||||
|
||||
def _is_uploaded_file_key(self, key: str) -> bool:
|
||||
return '/' not in key and not key.startswith('plugin_config_')
|
||||
|
||||
async def _monitoring_counts(self) -> dict[str, int]:
|
||||
tables = {
|
||||
'messages': persistence_monitoring.MonitoringMessage.id,
|
||||
'llm_calls': persistence_monitoring.MonitoringLLMCall.id,
|
||||
'embedding_calls': persistence_monitoring.MonitoringEmbeddingCall.id,
|
||||
'errors': persistence_monitoring.MonitoringError.id,
|
||||
'sessions': persistence_monitoring.MonitoringSession.session_id,
|
||||
'feedback': persistence_monitoring.MonitoringFeedback.id,
|
||||
}
|
||||
counts: dict[str, int] = {}
|
||||
for key, column in tables.items():
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(sqlalchemy.func.count(column)))
|
||||
counts[key] = result.scalar() or 0
|
||||
return counts
|
||||
|
||||
async def _binary_storage_stats(self) -> dict[str, Any]:
|
||||
count_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(sqlalchemy.func.count(persistence_bstorage.BinaryStorage.unique_key))
|
||||
)
|
||||
size_bytes = None
|
||||
try:
|
||||
size_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(sqlalchemy.func.sum(sqlalchemy.func.length(persistence_bstorage.BinaryStorage.value)))
|
||||
)
|
||||
size_bytes = size_result.scalar() or 0
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(f'Failed to estimate binary storage size: {e}')
|
||||
|
||||
return {
|
||||
'count': count_result.scalar() or 0,
|
||||
'size_bytes': size_bytes,
|
||||
}
|
||||
|
||||
def _path_size(self, path: Path) -> int:
|
||||
if not path.exists():
|
||||
return 0
|
||||
if path.is_file():
|
||||
return path.stat().st_size
|
||||
total = 0
|
||||
for root, _, files in os.walk(path):
|
||||
for file_name in files:
|
||||
file_path = Path(root) / file_name
|
||||
try:
|
||||
total += file_path.stat().st_size
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
return total
|
||||
|
||||
def _file_count(self, path: Path) -> int:
|
||||
if not path.exists():
|
||||
return 0
|
||||
if path.is_file():
|
||||
return 1
|
||||
count = 0
|
||||
for _, _, files in os.walk(path):
|
||||
count += len(files)
|
||||
return count
|
||||
|
||||
def _positive_int(self, value: Any, default: int, name: str) -> int:
|
||||
try:
|
||||
parsed = int(value)
|
||||
except (TypeError, ValueError):
|
||||
self.ap.logger.warning(f'Invalid {name}: {value!r}, using {default}')
|
||||
return default
|
||||
if parsed < 1:
|
||||
self.ap.logger.warning(f'Invalid {name}: {value!r}, using {default}')
|
||||
return default
|
||||
return parsed
|
||||
@@ -23,6 +23,17 @@ def _parse_provider_api_keys(provider_dict: dict) -> dict:
|
||||
return provider_dict
|
||||
|
||||
|
||||
def _runtime_model_data(model_uuid: str, model_data: dict) -> dict:
|
||||
"""Return model data for rebuilding runtime models after an update.
|
||||
|
||||
Update payloads intentionally omit uuid before writing to the database.
|
||||
Runtime model entities still need the stable uuid so pipeline configs can
|
||||
resolve the in-memory model immediately after an edit, without requiring a
|
||||
process restart.
|
||||
"""
|
||||
return {**model_data, 'uuid': model_uuid}
|
||||
|
||||
|
||||
class LLMModelsService:
|
||||
ap: app.Application
|
||||
|
||||
@@ -105,11 +116,16 @@ class LLMModelsService:
|
||||
)
|
||||
)
|
||||
pipeline = result.first()
|
||||
if pipeline is not None and pipeline.config['ai']['local-agent']['model'] == '':
|
||||
pipeline_config = pipeline.config
|
||||
pipeline_config['ai']['local-agent']['model'] = model_data['uuid']
|
||||
pipeline_data = {'config': pipeline_config}
|
||||
await self.ap.pipeline_service.update_pipeline(pipeline.uuid, pipeline_data)
|
||||
if pipeline is not None:
|
||||
model_config = pipeline.config.get('ai', {}).get('local-agent', {}).get('model', {})
|
||||
if not model_config.get('primary', ''):
|
||||
pipeline_config = pipeline.config
|
||||
pipeline_config['ai']['local-agent']['model'] = {
|
||||
'primary': model_data['uuid'],
|
||||
'fallbacks': [],
|
||||
}
|
||||
pipeline_data = {'config': pipeline_config}
|
||||
await self.ap.pipeline_service.update_pipeline(pipeline.uuid, pipeline_data)
|
||||
|
||||
return model_data['uuid']
|
||||
|
||||
@@ -168,7 +184,7 @@ class LLMModelsService:
|
||||
raise Exception('provider not found')
|
||||
|
||||
runtime_llm_model = await self.ap.model_mgr.load_llm_model_with_provider(
|
||||
persistence_model.LLMModel(**model_data),
|
||||
persistence_model.LLMModel(**_runtime_model_data(model_uuid, model_data)),
|
||||
runtime_provider,
|
||||
)
|
||||
self.ap.model_mgr.llm_models.append(runtime_llm_model)
|
||||
@@ -329,7 +345,7 @@ class EmbeddingModelsService:
|
||||
raise Exception('provider not found')
|
||||
|
||||
runtime_embedding_model = await self.ap.model_mgr.load_embedding_model_with_provider(
|
||||
persistence_model.EmbeddingModel(**model_data),
|
||||
persistence_model.EmbeddingModel(**_runtime_model_data(model_uuid, model_data)),
|
||||
runtime_provider,
|
||||
)
|
||||
self.ap.model_mgr.embedding_models.append(runtime_embedding_model)
|
||||
@@ -362,3 +378,162 @@ class EmbeddingModelsService:
|
||||
input_text=['Hello, world!'],
|
||||
extra_args={},
|
||||
)
|
||||
|
||||
|
||||
class RerankModelsService:
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application) -> None:
|
||||
self.ap = ap
|
||||
|
||||
async def get_rerank_models(self) -> list[dict]:
|
||||
"""Get all rerank models with provider info"""
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.RerankModel))
|
||||
models = result.all()
|
||||
|
||||
providers_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.ModelProvider)
|
||||
)
|
||||
providers = {p.uuid: p for p in providers_result.all()}
|
||||
|
||||
models_list = []
|
||||
for model in models:
|
||||
model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.RerankModel, model)
|
||||
provider = providers.get(model.provider_uuid)
|
||||
if provider:
|
||||
provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
|
||||
model_dict['provider'] = _parse_provider_api_keys(provider_dict)
|
||||
models_list.append(model_dict)
|
||||
|
||||
return models_list
|
||||
|
||||
async def get_rerank_models_by_provider(self, provider_uuid: str) -> list[dict]:
|
||||
"""Get rerank models by provider UUID"""
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.RerankModel).where(
|
||||
persistence_model.RerankModel.provider_uuid == provider_uuid
|
||||
)
|
||||
)
|
||||
models = result.all()
|
||||
return [self.ap.persistence_mgr.serialize_model(persistence_model.RerankModel, m) for m in models]
|
||||
|
||||
async def create_rerank_model(self, model_data: dict, preserve_uuid: bool = False) -> str:
|
||||
"""Create a new rerank model"""
|
||||
if not preserve_uuid:
|
||||
model_data['uuid'] = str(uuid.uuid4())
|
||||
|
||||
if 'provider' in model_data:
|
||||
provider_data = model_data.pop('provider')
|
||||
if provider_data.get('uuid'):
|
||||
model_data['provider_uuid'] = provider_data['uuid']
|
||||
else:
|
||||
provider_uuid = await self.ap.provider_service.find_or_create_provider(
|
||||
requester=provider_data.get('requester', ''),
|
||||
base_url=provider_data.get('base_url', ''),
|
||||
api_keys=provider_data.get('api_keys', []),
|
||||
)
|
||||
model_data['provider_uuid'] = provider_uuid
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(persistence_model.RerankModel).values(**model_data)
|
||||
)
|
||||
|
||||
runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid'])
|
||||
if runtime_provider is None:
|
||||
raise Exception('provider not found')
|
||||
|
||||
runtime_rerank_model = await self.ap.model_mgr.load_rerank_model_with_provider(
|
||||
persistence_model.RerankModel(**model_data),
|
||||
runtime_provider,
|
||||
)
|
||||
self.ap.model_mgr.rerank_models.append(runtime_rerank_model)
|
||||
|
||||
return model_data['uuid']
|
||||
|
||||
async def get_rerank_model(self, model_uuid: str) -> dict | None:
|
||||
"""Get a single rerank model with provider info"""
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.RerankModel).where(persistence_model.RerankModel.uuid == model_uuid)
|
||||
)
|
||||
model = result.first()
|
||||
if model is None:
|
||||
return None
|
||||
|
||||
model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.RerankModel, model)
|
||||
|
||||
provider_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.ModelProvider).where(
|
||||
persistence_model.ModelProvider.uuid == model.provider_uuid
|
||||
)
|
||||
)
|
||||
provider = provider_result.first()
|
||||
if provider:
|
||||
provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
|
||||
model_dict['provider'] = _parse_provider_api_keys(provider_dict)
|
||||
|
||||
return model_dict
|
||||
|
||||
async def update_rerank_model(self, model_uuid: str, model_data: dict) -> None:
|
||||
"""Update an existing rerank model"""
|
||||
if 'uuid' in model_data:
|
||||
del model_data['uuid']
|
||||
|
||||
if 'provider' in model_data:
|
||||
provider_data = model_data.pop('provider')
|
||||
if provider_data.get('uuid'):
|
||||
model_data['provider_uuid'] = provider_data['uuid']
|
||||
else:
|
||||
provider_uuid = await self.ap.provider_service.find_or_create_provider(
|
||||
requester=provider_data.get('requester', ''),
|
||||
base_url=provider_data.get('base_url', ''),
|
||||
api_keys=provider_data.get('api_keys', []),
|
||||
)
|
||||
model_data['provider_uuid'] = provider_uuid
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_model.RerankModel)
|
||||
.where(persistence_model.RerankModel.uuid == model_uuid)
|
||||
.values(**model_data)
|
||||
)
|
||||
|
||||
await self.ap.model_mgr.remove_rerank_model(model_uuid)
|
||||
|
||||
runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid'])
|
||||
if runtime_provider is None:
|
||||
raise Exception('provider not found')
|
||||
|
||||
runtime_rerank_model = await self.ap.model_mgr.load_rerank_model_with_provider(
|
||||
persistence_model.RerankModel(**_runtime_model_data(model_uuid, model_data)),
|
||||
runtime_provider,
|
||||
)
|
||||
self.ap.model_mgr.rerank_models.append(runtime_rerank_model)
|
||||
|
||||
async def delete_rerank_model(self, model_uuid: str) -> None:
|
||||
"""Delete a rerank model"""
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_model.RerankModel).where(persistence_model.RerankModel.uuid == model_uuid)
|
||||
)
|
||||
await self.ap.model_mgr.remove_rerank_model(model_uuid)
|
||||
|
||||
async def test_rerank_model(self, model_uuid: str, model_data: dict) -> None:
|
||||
"""Test a rerank model"""
|
||||
runtime_rerank_model: model_requester.RuntimeRerankModel | None = None
|
||||
|
||||
if model_uuid != '_':
|
||||
for model in self.ap.model_mgr.rerank_models:
|
||||
if model.model_entity.uuid == model_uuid:
|
||||
runtime_rerank_model = model
|
||||
break
|
||||
if runtime_rerank_model is None:
|
||||
raise Exception('model not found')
|
||||
else:
|
||||
runtime_rerank_model = await self.ap.model_mgr.init_temporary_runtime_rerank_model(model_data)
|
||||
|
||||
await runtime_rerank_model.provider.invoke_rerank(
|
||||
model=runtime_rerank_model,
|
||||
query='What is artificial intelligence?',
|
||||
documents=[
|
||||
'Artificial intelligence is a branch of computer science.',
|
||||
'The weather is nice today.',
|
||||
],
|
||||
)
|
||||
|
||||
@@ -16,6 +16,121 @@ class MonitoringService:
|
||||
def __init__(self, ap: app.Application) -> None:
|
||||
self.ap = ap
|
||||
|
||||
# ========== Cleanup Methods ==========
|
||||
|
||||
async def cleanup_expired_records(self, retention_days: int, batch_size: int = 1000) -> dict[str, int]:
|
||||
"""Delete monitoring records older than the specified retention period.
|
||||
|
||||
Args:
|
||||
retention_days: Number of days to retain records.
|
||||
batch_size: Maximum rows to delete per table batch.
|
||||
|
||||
Returns:
|
||||
A dict mapping table name to the number of deleted rows.
|
||||
"""
|
||||
if retention_days < 1:
|
||||
raise ValueError('retention_days must be >= 1')
|
||||
if batch_size < 1:
|
||||
raise ValueError('batch_size must be >= 1')
|
||||
|
||||
cutoff = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - datetime.timedelta(
|
||||
days=retention_days
|
||||
)
|
||||
|
||||
tables_and_columns: list[tuple[str, type, sqlalchemy.Column, sqlalchemy.Column]] = [
|
||||
(
|
||||
'monitoring_messages',
|
||||
persistence_monitoring.MonitoringMessage,
|
||||
persistence_monitoring.MonitoringMessage.timestamp,
|
||||
persistence_monitoring.MonitoringMessage.id,
|
||||
),
|
||||
(
|
||||
'monitoring_llm_calls',
|
||||
persistence_monitoring.MonitoringLLMCall,
|
||||
persistence_monitoring.MonitoringLLMCall.timestamp,
|
||||
persistence_monitoring.MonitoringLLMCall.id,
|
||||
),
|
||||
(
|
||||
'monitoring_embedding_calls',
|
||||
persistence_monitoring.MonitoringEmbeddingCall,
|
||||
persistence_monitoring.MonitoringEmbeddingCall.timestamp,
|
||||
persistence_monitoring.MonitoringEmbeddingCall.id,
|
||||
),
|
||||
(
|
||||
'monitoring_errors',
|
||||
persistence_monitoring.MonitoringError,
|
||||
persistence_monitoring.MonitoringError.timestamp,
|
||||
persistence_monitoring.MonitoringError.id,
|
||||
),
|
||||
(
|
||||
'monitoring_sessions',
|
||||
persistence_monitoring.MonitoringSession,
|
||||
persistence_monitoring.MonitoringSession.last_activity,
|
||||
persistence_monitoring.MonitoringSession.session_id,
|
||||
),
|
||||
(
|
||||
'monitoring_feedback',
|
||||
persistence_monitoring.MonitoringFeedback,
|
||||
persistence_monitoring.MonitoringFeedback.timestamp,
|
||||
persistence_monitoring.MonitoringFeedback.id,
|
||||
),
|
||||
]
|
||||
|
||||
deleted_counts: dict[str, int] = {}
|
||||
|
||||
for table_name, model_cls, ts_column, pk_column in tables_and_columns:
|
||||
deleted_counts[table_name] = await self._delete_expired_in_batches(
|
||||
model_cls=model_cls,
|
||||
ts_column=ts_column,
|
||||
pk_column=pk_column,
|
||||
cutoff=cutoff,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
if sum(deleted_counts.values()) > 0:
|
||||
await self._release_sqlite_space()
|
||||
|
||||
return deleted_counts
|
||||
|
||||
async def _delete_expired_in_batches(
|
||||
self,
|
||||
model_cls: type,
|
||||
ts_column: sqlalchemy.Column,
|
||||
pk_column: sqlalchemy.Column,
|
||||
cutoff: datetime.datetime,
|
||||
batch_size: int,
|
||||
) -> int:
|
||||
deleted_total = 0
|
||||
|
||||
while True:
|
||||
select_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(pk_column).where(ts_column < cutoff).limit(batch_size)
|
||||
)
|
||||
pk_values = list(select_result.scalars().all())
|
||||
if not pk_values:
|
||||
break
|
||||
|
||||
delete_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(model_cls).where(pk_column.in_(pk_values))
|
||||
)
|
||||
deleted = delete_result.rowcount or 0
|
||||
deleted_total += deleted
|
||||
|
||||
if len(pk_values) < batch_size:
|
||||
break
|
||||
|
||||
return deleted_total
|
||||
|
||||
async def _release_sqlite_space(self) -> None:
|
||||
database_type = self.ap.instance_config.data.get('database', {}).get('use', 'sqlite')
|
||||
if database_type != 'sqlite':
|
||||
return
|
||||
|
||||
async with self.ap.persistence_mgr.get_db_engine().connect() as conn:
|
||||
autocommit_conn = await conn.execution_options(isolation_level='AUTOCOMMIT')
|
||||
await autocommit_conn.execute(sqlalchemy.text('PRAGMA wal_checkpoint(TRUNCATE)'))
|
||||
await autocommit_conn.execute(sqlalchemy.text('VACUUM'))
|
||||
|
||||
# ========== Recording Methods ==========
|
||||
|
||||
async def record_message(
|
||||
@@ -30,6 +145,7 @@ class MonitoringService:
|
||||
level: str = 'info',
|
||||
platform: str | None = None,
|
||||
user_id: str | None = None,
|
||||
user_name: str | None = None,
|
||||
runner_name: str | None = None,
|
||||
variables: str | None = None,
|
||||
role: str = 'user',
|
||||
@@ -49,6 +165,7 @@ class MonitoringService:
|
||||
'level': level,
|
||||
'platform': platform,
|
||||
'user_id': user_id,
|
||||
'user_name': user_name,
|
||||
'runner_name': runner_name,
|
||||
'variables': variables,
|
||||
'role': role,
|
||||
@@ -152,6 +269,7 @@ class MonitoringService:
|
||||
pipeline_name: str,
|
||||
platform: str | None = None,
|
||||
user_id: str | None = None,
|
||||
user_name: str | None = None,
|
||||
) -> None:
|
||||
"""Record a new session"""
|
||||
session_data = {
|
||||
@@ -166,6 +284,7 @@ class MonitoringService:
|
||||
'is_active': True,
|
||||
'platform': platform,
|
||||
'user_id': user_id,
|
||||
'user_name': user_name,
|
||||
}
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
@@ -1128,3 +1247,314 @@ class MonitoringService:
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
# ========== Feedback Methods ==========
|
||||
|
||||
async def record_feedback(
|
||||
self,
|
||||
feedback_id: str,
|
||||
feedback_type: int,
|
||||
feedback_content: str | None = None,
|
||||
inaccurate_reasons: list[str] | None = None,
|
||||
bot_id: str | None = None,
|
||||
bot_name: str | None = None,
|
||||
pipeline_id: str | None = None,
|
||||
pipeline_name: str | None = None,
|
||||
session_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
stream_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
platform: str | None = None,
|
||||
) -> str:
|
||||
"""Record user feedback (like/dislike) from AI Bot conversation.
|
||||
|
||||
Args:
|
||||
feedback_id: Unique feedback identifier from platform (e.g., WeChat Work)
|
||||
feedback_type: 1 = like (thumbs up), 2 = dislike (thumbs down)
|
||||
feedback_content: Optional user feedback text
|
||||
inaccurate_reasons: List of reasons for inaccurate response (for dislike)
|
||||
bot_id: Bot ID
|
||||
bot_name: Bot name
|
||||
pipeline_id: Pipeline ID
|
||||
pipeline_name: Pipeline name
|
||||
session_id: Session ID
|
||||
message_id: Message ID
|
||||
stream_id: Stream ID (for WeChat Work streaming messages)
|
||||
user_id: User ID
|
||||
platform: Platform name (e.g., 'wecom')
|
||||
|
||||
Returns:
|
||||
The record ID
|
||||
"""
|
||||
import json
|
||||
|
||||
now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
reasons_json = json.dumps(inaccurate_reasons, ensure_ascii=False) if inaccurate_reasons else None
|
||||
|
||||
MonitoringFeedback = persistence_monitoring.MonitoringFeedback
|
||||
|
||||
# Handle cancel feedback (type=3): delete existing record
|
||||
if feedback_type == 3:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(MonitoringFeedback).where(MonitoringFeedback.feedback_id == feedback_id)
|
||||
)
|
||||
return None
|
||||
|
||||
# Check if record with this feedback_id already exists
|
||||
existing_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(MonitoringFeedback).where(MonitoringFeedback.feedback_id == feedback_id)
|
||||
)
|
||||
existing_row = existing_result.first()
|
||||
|
||||
if existing_row:
|
||||
# UPDATE existing record
|
||||
existing = existing_row[0] if isinstance(existing_row, tuple) else existing_row
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(MonitoringFeedback)
|
||||
.where(MonitoringFeedback.feedback_id == feedback_id)
|
||||
.values(
|
||||
timestamp=now,
|
||||
feedback_type=feedback_type,
|
||||
feedback_content=feedback_content,
|
||||
inaccurate_reasons=reasons_json,
|
||||
bot_id=bot_id or existing.bot_id,
|
||||
bot_name=bot_name or existing.bot_name,
|
||||
pipeline_id=pipeline_id or existing.pipeline_id,
|
||||
pipeline_name=pipeline_name or existing.pipeline_name,
|
||||
session_id=session_id or existing.session_id,
|
||||
message_id=message_id or existing.message_id,
|
||||
stream_id=stream_id or existing.stream_id,
|
||||
user_id=user_id or existing.user_id,
|
||||
platform=platform or existing.platform,
|
||||
)
|
||||
)
|
||||
return existing.id
|
||||
else:
|
||||
# INSERT new record with IntegrityError defense
|
||||
record_id = str(uuid.uuid4())
|
||||
record_data = {
|
||||
'id': record_id,
|
||||
'timestamp': now,
|
||||
'feedback_id': feedback_id,
|
||||
'feedback_type': feedback_type,
|
||||
'feedback_content': feedback_content,
|
||||
'inaccurate_reasons': reasons_json,
|
||||
'bot_id': bot_id,
|
||||
'bot_name': bot_name,
|
||||
'pipeline_id': pipeline_id,
|
||||
'pipeline_name': pipeline_name,
|
||||
'session_id': session_id,
|
||||
'message_id': message_id,
|
||||
'stream_id': stream_id,
|
||||
'user_id': user_id,
|
||||
'platform': platform,
|
||||
}
|
||||
try:
|
||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(MonitoringFeedback).values(record_data))
|
||||
return record_id
|
||||
except Exception:
|
||||
# UNIQUE constraint conflict (concurrent feedback for same feedback_id)
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(MonitoringFeedback)
|
||||
.where(MonitoringFeedback.feedback_id == feedback_id)
|
||||
.values(
|
||||
timestamp=now,
|
||||
feedback_type=feedback_type,
|
||||
feedback_content=feedback_content,
|
||||
inaccurate_reasons=reasons_json,
|
||||
)
|
||||
)
|
||||
return feedback_id
|
||||
|
||||
async def get_feedback_stats(
|
||||
self,
|
||||
bot_ids: list[str] | None = None,
|
||||
pipeline_ids: list[str] | None = None,
|
||||
start_time: datetime.datetime | None = None,
|
||||
end_time: datetime.datetime | None = None,
|
||||
) -> dict:
|
||||
"""Get feedback statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with total likes, dislikes, and breakdown by bot/pipeline
|
||||
"""
|
||||
conditions = []
|
||||
|
||||
if bot_ids:
|
||||
conditions.append(persistence_monitoring.MonitoringFeedback.bot_id.in_(bot_ids))
|
||||
if pipeline_ids:
|
||||
conditions.append(persistence_monitoring.MonitoringFeedback.pipeline_id.in_(pipeline_ids))
|
||||
if start_time:
|
||||
conditions.append(persistence_monitoring.MonitoringFeedback.timestamp >= start_time)
|
||||
if end_time:
|
||||
conditions.append(persistence_monitoring.MonitoringFeedback.timestamp <= end_time)
|
||||
|
||||
# Get total likes (feedback_type = 1)
|
||||
likes_query = sqlalchemy.select(sqlalchemy.func.count(persistence_monitoring.MonitoringFeedback.id)).where(
|
||||
persistence_monitoring.MonitoringFeedback.feedback_type == 1
|
||||
)
|
||||
if conditions:
|
||||
likes_query = likes_query.where(sqlalchemy.and_(*conditions))
|
||||
likes_result = await self.ap.persistence_mgr.execute_async(likes_query)
|
||||
total_likes = likes_result.scalar() or 0
|
||||
|
||||
# Get total dislikes (feedback_type = 2)
|
||||
dislikes_query = sqlalchemy.select(sqlalchemy.func.count(persistence_monitoring.MonitoringFeedback.id)).where(
|
||||
persistence_monitoring.MonitoringFeedback.feedback_type == 2
|
||||
)
|
||||
if conditions:
|
||||
dislikes_query = dislikes_query.where(sqlalchemy.and_(*conditions))
|
||||
dislikes_result = await self.ap.persistence_mgr.execute_async(dislikes_query)
|
||||
total_dislikes = dislikes_result.scalar() or 0
|
||||
|
||||
# Get total feedback count
|
||||
total_query = sqlalchemy.select(sqlalchemy.func.count(persistence_monitoring.MonitoringFeedback.id))
|
||||
if conditions:
|
||||
total_query = total_query.where(sqlalchemy.and_(*conditions))
|
||||
total_result = await self.ap.persistence_mgr.execute_async(total_query)
|
||||
total_feedback = total_result.scalar() or 0
|
||||
|
||||
# Calculate satisfaction rate
|
||||
satisfaction_rate = (total_likes / total_feedback * 100) if total_feedback > 0 else 0
|
||||
|
||||
# Get feedback by bot
|
||||
bot_stats_query = sqlalchemy.select(
|
||||
persistence_monitoring.MonitoringFeedback.bot_id,
|
||||
persistence_monitoring.MonitoringFeedback.bot_name,
|
||||
sqlalchemy.func.count(persistence_monitoring.MonitoringFeedback.id).label('total'),
|
||||
sqlalchemy.func.sum(
|
||||
sqlalchemy.case((persistence_monitoring.MonitoringFeedback.feedback_type == 1, 1), else_=0)
|
||||
).label('likes'),
|
||||
sqlalchemy.func.sum(
|
||||
sqlalchemy.case((persistence_monitoring.MonitoringFeedback.feedback_type == 2, 1), else_=0)
|
||||
).label('dislikes'),
|
||||
).group_by(
|
||||
persistence_monitoring.MonitoringFeedback.bot_id,
|
||||
persistence_monitoring.MonitoringFeedback.bot_name,
|
||||
)
|
||||
if conditions:
|
||||
bot_stats_query = bot_stats_query.where(sqlalchemy.and_(*conditions))
|
||||
bot_stats_result = await self.ap.persistence_mgr.execute_async(bot_stats_query)
|
||||
bot_stats = [
|
||||
{
|
||||
'bot_id': row.bot_id,
|
||||
'bot_name': row.bot_name,
|
||||
'total': row.total,
|
||||
'likes': row.likes or 0,
|
||||
'dislikes': row.dislikes or 0,
|
||||
}
|
||||
for row in bot_stats_result.all()
|
||||
]
|
||||
|
||||
return {
|
||||
'total_feedback': total_feedback,
|
||||
'total_likes': total_likes,
|
||||
'total_dislikes': total_dislikes,
|
||||
'satisfaction_rate': round(satisfaction_rate, 2),
|
||||
'by_bot': bot_stats,
|
||||
}
|
||||
|
||||
async def get_feedback_list(
|
||||
self,
|
||||
bot_ids: list[str] | None = None,
|
||||
pipeline_ids: list[str] | None = None,
|
||||
feedback_type: int | None = None,
|
||||
start_time: datetime.datetime | None = None,
|
||||
end_time: datetime.datetime | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[dict], int]:
|
||||
"""Get feedback list with filters."""
|
||||
conditions = []
|
||||
|
||||
if bot_ids:
|
||||
conditions.append(persistence_monitoring.MonitoringFeedback.bot_id.in_(bot_ids))
|
||||
if pipeline_ids:
|
||||
conditions.append(persistence_monitoring.MonitoringFeedback.pipeline_id.in_(pipeline_ids))
|
||||
if feedback_type is not None:
|
||||
conditions.append(persistence_monitoring.MonitoringFeedback.feedback_type == feedback_type)
|
||||
if start_time:
|
||||
conditions.append(persistence_monitoring.MonitoringFeedback.timestamp >= start_time)
|
||||
if end_time:
|
||||
conditions.append(persistence_monitoring.MonitoringFeedback.timestamp <= end_time)
|
||||
|
||||
# Get total count
|
||||
count_query = sqlalchemy.select(sqlalchemy.func.count(persistence_monitoring.MonitoringFeedback.id))
|
||||
if conditions:
|
||||
count_query = count_query.where(sqlalchemy.and_(*conditions))
|
||||
count_result = await self.ap.persistence_mgr.execute_async(count_query)
|
||||
total = count_result.scalar() or 0
|
||||
|
||||
# Get feedback list
|
||||
query = sqlalchemy.select(persistence_monitoring.MonitoringFeedback).order_by(
|
||||
persistence_monitoring.MonitoringFeedback.timestamp.desc()
|
||||
)
|
||||
if conditions:
|
||||
query = query.where(sqlalchemy.and_(*conditions))
|
||||
query = query.limit(limit).offset(offset)
|
||||
|
||||
result = await self.ap.persistence_mgr.execute_async(query)
|
||||
rows = result.all()
|
||||
|
||||
return (
|
||||
[
|
||||
self.ap.persistence_mgr.serialize_model(
|
||||
persistence_monitoring.MonitoringFeedback, row[0] if isinstance(row, tuple) else row
|
||||
)
|
||||
for row in rows
|
||||
],
|
||||
total,
|
||||
)
|
||||
|
||||
async def export_feedback(
|
||||
self,
|
||||
bot_ids: list[str] | None = None,
|
||||
pipeline_ids: list[str] | None = None,
|
||||
start_time: datetime.datetime | None = None,
|
||||
end_time: datetime.datetime | None = None,
|
||||
limit: int = 100000,
|
||||
) -> list[dict]:
|
||||
"""Export feedback as list of dictionaries for CSV conversion."""
|
||||
conditions = []
|
||||
|
||||
if bot_ids:
|
||||
conditions.append(persistence_monitoring.MonitoringFeedback.bot_id.in_(bot_ids))
|
||||
if pipeline_ids:
|
||||
conditions.append(persistence_monitoring.MonitoringFeedback.pipeline_id.in_(pipeline_ids))
|
||||
if start_time:
|
||||
conditions.append(persistence_monitoring.MonitoringFeedback.timestamp >= start_time)
|
||||
if end_time:
|
||||
conditions.append(persistence_monitoring.MonitoringFeedback.timestamp <= end_time)
|
||||
|
||||
query = sqlalchemy.select(persistence_monitoring.MonitoringFeedback).order_by(
|
||||
persistence_monitoring.MonitoringFeedback.timestamp.desc()
|
||||
)
|
||||
if conditions:
|
||||
query = query.where(sqlalchemy.and_(*conditions))
|
||||
query = query.limit(limit)
|
||||
|
||||
result = await self.ap.persistence_mgr.execute_async(query)
|
||||
rows = result.all()
|
||||
|
||||
return [
|
||||
{
|
||||
'id': row[0].id if isinstance(row, tuple) else row.id,
|
||||
'timestamp': self._format_timestamp(row[0].timestamp if isinstance(row, tuple) else row.timestamp),
|
||||
'feedback_id': row[0].feedback_id if isinstance(row, tuple) else row.feedback_id,
|
||||
'feedback_type': 'like'
|
||||
if (row[0].feedback_type if isinstance(row, tuple) else row.feedback_type) == 1
|
||||
else 'dislike',
|
||||
'feedback_content': row[0].feedback_content if isinstance(row, tuple) else row.feedback_content,
|
||||
'inaccurate_reasons': row[0].inaccurate_reasons if isinstance(row, tuple) else row.inaccurate_reasons,
|
||||
'bot_id': row[0].bot_id if isinstance(row, tuple) else row.bot_id,
|
||||
'bot_name': row[0].bot_name if isinstance(row, tuple) else row.bot_name,
|
||||
'pipeline_id': row[0].pipeline_id if isinstance(row, tuple) else row.pipeline_id,
|
||||
'pipeline_name': row[0].pipeline_name if isinstance(row, tuple) else row.pipeline_name,
|
||||
'session_id': row[0].session_id if isinstance(row, tuple) else row.session_id,
|
||||
'message_id': row[0].message_id if isinstance(row, tuple) else row.message_id,
|
||||
'stream_id': row[0].stream_id if isinstance(row, tuple) else row.stream_id,
|
||||
'user_id': row[0].user_id if isinstance(row, tuple) else row.user_id,
|
||||
'platform': row[0].platform if isinstance(row, tuple) else row.platform,
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
@@ -73,6 +73,20 @@ class PipelineService:
|
||||
|
||||
return self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)
|
||||
|
||||
async def get_pipeline_by_name(self, pipeline_name: str) -> dict | None:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
|
||||
persistence_pipeline.LegacyPipeline.name == pipeline_name
|
||||
)
|
||||
)
|
||||
|
||||
pipeline = result.first()
|
||||
|
||||
if pipeline is None:
|
||||
return None
|
||||
|
||||
return self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)
|
||||
|
||||
async def create_pipeline(self, pipeline_data: dict, default: bool = False) -> str:
|
||||
from ....utils import paths as path_utils
|
||||
|
||||
@@ -113,14 +127,9 @@ class PipelineService:
|
||||
return pipeline_data['uuid']
|
||||
|
||||
async def update_pipeline(self, pipeline_uuid: str, pipeline_data: dict) -> None:
|
||||
if 'uuid' in pipeline_data:
|
||||
del pipeline_data['uuid']
|
||||
if 'for_version' in pipeline_data:
|
||||
del pipeline_data['for_version']
|
||||
if 'stages' in pipeline_data:
|
||||
del pipeline_data['stages']
|
||||
if 'is_default' in pipeline_data:
|
||||
del pipeline_data['is_default']
|
||||
pipeline_data = pipeline_data.copy()
|
||||
for protected_field in ('uuid', 'for_version', 'stages', 'is_default'):
|
||||
pipeline_data.pop(protected_field, None)
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_pipeline.LegacyPipeline)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
import traceback
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
@@ -16,6 +17,24 @@ class ModelProviderService:
|
||||
def __init__(self, ap: app.Application) -> None:
|
||||
self.ap = ap
|
||||
|
||||
@staticmethod
|
||||
def _normalize_api_keys(api_keys: str | list[str] | tuple[str, ...] | None) -> list[str]:
|
||||
if api_keys is None:
|
||||
return []
|
||||
|
||||
raw_keys = [api_keys] if isinstance(api_keys, str) else list(api_keys)
|
||||
normalized_keys = []
|
||||
seen_keys = set()
|
||||
|
||||
for raw_key in raw_keys:
|
||||
normalized_key = raw_key.strip() if isinstance(raw_key, str) else ''
|
||||
if not normalized_key or normalized_key in seen_keys:
|
||||
continue
|
||||
normalized_keys.append(normalized_key)
|
||||
seen_keys.add(normalized_key)
|
||||
|
||||
return normalized_keys
|
||||
|
||||
async def get_providers(self) -> list[dict]:
|
||||
"""Get all providers"""
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.ModelProvider))
|
||||
@@ -58,6 +77,7 @@ class ModelProviderService:
|
||||
async def create_provider(self, provider_data: dict) -> str:
|
||||
"""Create a new provider"""
|
||||
provider_data['uuid'] = str(uuid.uuid4())
|
||||
provider_data['api_keys'] = self._normalize_api_keys(provider_data.get('api_keys'))
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(persistence_model.ModelProvider).values(**provider_data)
|
||||
)
|
||||
@@ -71,6 +91,8 @@ class ModelProviderService:
|
||||
"""Update an existing provider"""
|
||||
if 'uuid' in provider_data:
|
||||
del provider_data['uuid']
|
||||
if 'api_keys' in provider_data:
|
||||
provider_data['api_keys'] = self._normalize_api_keys(provider_data.get('api_keys'))
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_model.ModelProvider)
|
||||
.where(persistence_model.ModelProvider.uuid == provider_uuid)
|
||||
@@ -97,6 +119,14 @@ class ModelProviderService:
|
||||
if embedding_result.first() is not None:
|
||||
raise ValueError('Cannot delete provider: Embedding models still reference it')
|
||||
|
||||
rerank_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.RerankModel).where(
|
||||
persistence_model.RerankModel.provider_uuid == provider_uuid
|
||||
)
|
||||
)
|
||||
if rerank_result.first() is not None:
|
||||
raise ValueError('Cannot delete provider: Rerank models still reference it')
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_model.ModelProvider).where(
|
||||
persistence_model.ModelProvider.uuid == provider_uuid
|
||||
@@ -121,10 +151,19 @@ class ModelProviderService:
|
||||
)
|
||||
embedding_count = embedding_result.scalar() or 0
|
||||
|
||||
return {'llm_count': llm_count, 'embedding_count': embedding_count}
|
||||
rerank_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(sqlalchemy.func.count())
|
||||
.select_from(persistence_model.RerankModel)
|
||||
.where(persistence_model.RerankModel.provider_uuid == provider_uuid)
|
||||
)
|
||||
rerank_count = rerank_result.scalar() or 0
|
||||
|
||||
return {'llm_count': llm_count, 'embedding_count': embedding_count, 'rerank_count': rerank_count}
|
||||
|
||||
async def find_or_create_provider(self, requester: str, base_url: str, api_keys: list) -> str:
|
||||
"""Find existing provider or create new one"""
|
||||
api_keys = self._normalize_api_keys(api_keys)
|
||||
|
||||
# Try to find existing provider with same config
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.ModelProvider).where(
|
||||
@@ -152,7 +191,7 @@ class ModelProviderService:
|
||||
'name': provider_name,
|
||||
'requester': requester,
|
||||
'base_url': base_url,
|
||||
'api_keys': api_keys or [],
|
||||
'api_keys': api_keys,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -161,6 +200,69 @@ class ModelProviderService:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_model.ModelProvider)
|
||||
.where(persistence_model.ModelProvider.uuid == '00000000-0000-0000-0000-000000000000')
|
||||
.values(api_keys=[api_key])
|
||||
.values(api_keys=self._normalize_api_keys(api_key))
|
||||
)
|
||||
await self.ap.model_mgr.reload_provider('00000000-0000-0000-0000-000000000000')
|
||||
|
||||
async def scan_provider_models(self, provider_uuid: str, model_type: str | None = None) -> dict:
|
||||
provider = await self.get_provider(provider_uuid)
|
||||
if provider is None:
|
||||
raise ValueError('provider not found')
|
||||
|
||||
runtime_provider = await self.ap.model_mgr.load_provider(provider)
|
||||
|
||||
try:
|
||||
scan_result = await runtime_provider.requester.scan_models(
|
||||
runtime_provider.token_mgr.get_token() if runtime_provider.token_mgr.tokens else None
|
||||
)
|
||||
except NotImplementedError:
|
||||
raise ValueError('current provider does not support model scanning')
|
||||
except Exception as exc:
|
||||
self.ap.logger.warning(
|
||||
f'Failed to scan models for provider {provider_uuid}: {exc}\n{traceback.format_exc()}'
|
||||
)
|
||||
raise ValueError(str(exc)) from exc
|
||||
|
||||
if isinstance(scan_result, dict):
|
||||
scanned_models = scan_result.get('models', [])
|
||||
debug_info = scan_result.get('debug')
|
||||
else:
|
||||
scanned_models = scan_result
|
||||
debug_info = None
|
||||
|
||||
llm_models = await self.ap.llm_model_service.get_llm_models_by_provider(provider_uuid)
|
||||
embedding_models = await self.ap.embedding_models_service.get_embedding_models_by_provider(provider_uuid)
|
||||
existing_llm_names = {model['name'] for model in llm_models}
|
||||
existing_embedding_names = {model['name'] for model in embedding_models}
|
||||
|
||||
filtered_models = []
|
||||
for model in scanned_models:
|
||||
scanned_type = model.get('type', 'llm')
|
||||
if model_type and scanned_type != model_type:
|
||||
continue
|
||||
|
||||
model_name = model.get('name') or model.get('id')
|
||||
if not model_name:
|
||||
continue
|
||||
|
||||
filtered_models.append(
|
||||
{
|
||||
'id': model.get('id', model_name),
|
||||
'name': model_name,
|
||||
'type': scanned_type,
|
||||
'abilities': model.get('abilities', []),
|
||||
'display_name': model.get('display_name'),
|
||||
'description': model.get('description'),
|
||||
'context_length': model.get('context_length'),
|
||||
'owned_by': model.get('owned_by'),
|
||||
'input_modalities': model.get('input_modalities', []),
|
||||
'output_modalities': model.get('output_modalities', []),
|
||||
'already_added': (
|
||||
model_name in existing_embedding_names
|
||||
if scanned_type == 'embedding'
|
||||
else model_name in existing_llm_names
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
return {'models': filtered_models, 'debug': debug_info}
|
||||
|
||||
@@ -179,7 +179,7 @@ class SpaceService:
|
||||
space_url = space_config['url']
|
||||
|
||||
session = httpclient.get_session()
|
||||
async with session.get(f'{space_url}/api/v1/models') as response:
|
||||
async with session.get(f'{space_url}/api/v1/models', params={'page_size': 100}) as response:
|
||||
if response.status != 200:
|
||||
raise ValueError(f'Failed to get models: {await response.text()}')
|
||||
data = await response.json()
|
||||
|
||||
@@ -65,8 +65,8 @@ class UserService:
|
||||
|
||||
user_obj = result_list[0]
|
||||
|
||||
# Check if this is a Space account
|
||||
if user_obj.account_type == 'space':
|
||||
# Check if this user has a local password set
|
||||
if not user_obj.password:
|
||||
raise ValueError('请使用 Space 账户登录')
|
||||
|
||||
ph = argon2.PasswordHasher()
|
||||
@@ -108,9 +108,8 @@ class UserService:
|
||||
if user_obj is None:
|
||||
raise ValueError('User not found')
|
||||
|
||||
# Space accounts cannot change password locally
|
||||
if user_obj.account_type == 'space':
|
||||
raise ValueError('Space account cannot change password locally')
|
||||
if not user_obj.password:
|
||||
raise ValueError('No local password set, please set a password first')
|
||||
|
||||
ph.verify(user_obj.password, current_password)
|
||||
|
||||
|
||||
1175
src/langbot/pkg/api/http/service/workflow.py
Normal file
1175
src/langbot/pkg/api/http/service/workflow.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -9,6 +9,7 @@ from ..platform import botmgr as im_mgr
|
||||
from ..platform.webhook_pusher import WebhookPusher
|
||||
from ..provider.session import sessionmgr as llm_session_mgr
|
||||
from ..provider.modelmgr import modelmgr as llm_model_mgr
|
||||
|
||||
from langbot.pkg.provider.tools import toolmgr as llm_tool_mgr
|
||||
from ..config import manager as config_mgr
|
||||
from ..command import cmdmgr
|
||||
@@ -30,6 +31,9 @@ from ..api.http.service import mcp as mcp_service
|
||||
from ..api.http.service import apikey as apikey_service
|
||||
from ..api.http.service import webhook as webhook_service
|
||||
from ..api.http.service import monitoring as monitoring_service
|
||||
from ..api.http.service import workflow as workflow_service
|
||||
from ..api.http.service import maintenance as maintenance_service
|
||||
|
||||
from ..discover import engine as discover_engine
|
||||
from ..storage import mgr as storagemgr
|
||||
from ..utils import logcache
|
||||
@@ -131,6 +135,8 @@ class Application:
|
||||
|
||||
embedding_models_service: model_service.EmbeddingModelsService = None
|
||||
|
||||
rerank_models_service: model_service.RerankModelsService = None
|
||||
|
||||
provider_service: provider_service.ModelProviderService = None
|
||||
|
||||
pipeline_service: pipeline_service.PipelineService = None
|
||||
@@ -145,12 +151,16 @@ class Application:
|
||||
|
||||
webhook_service: webhook_service.WebhookService = None
|
||||
|
||||
workflow_service: workflow_service.WorkflowService = None
|
||||
|
||||
telemetry: telemetry_module.TelemetryManager = None
|
||||
|
||||
survey: survey_module.SurveyManager = None
|
||||
|
||||
monitoring_service: monitoring_service.MonitoringService = None
|
||||
|
||||
maintenance_service: maintenance_service.MaintenanceService = None
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@@ -186,6 +196,93 @@ class Application:
|
||||
scopes=[core_entities.LifecycleControlScope.APPLICATION],
|
||||
)
|
||||
|
||||
# Start monitoring data cleanup task if enabled
|
||||
monitoring_cfg = self.instance_config.data.get('monitoring', {})
|
||||
auto_cleanup_cfg = monitoring_cfg.get('auto_cleanup', {})
|
||||
if auto_cleanup_cfg.get('enabled', True):
|
||||
retention_days = self._get_positive_int_config(
|
||||
auto_cleanup_cfg.get('retention_days', 30),
|
||||
default=30,
|
||||
name='monitoring.auto_cleanup.retention_days',
|
||||
)
|
||||
delete_batch_size = self._get_positive_int_config(
|
||||
auto_cleanup_cfg.get('delete_batch_size', 1000),
|
||||
default=1000,
|
||||
name='monitoring.auto_cleanup.delete_batch_size',
|
||||
)
|
||||
check_interval_hours = self._get_positive_float_config(
|
||||
auto_cleanup_cfg.get('check_interval_hours', 1),
|
||||
default=1,
|
||||
name='monitoring.auto_cleanup.check_interval_hours',
|
||||
)
|
||||
|
||||
async def monitoring_cleanup_loop():
|
||||
check_interval_seconds = check_interval_hours * 3600
|
||||
while True:
|
||||
try:
|
||||
deleted = await self.monitoring_service.cleanup_expired_records(
|
||||
retention_days,
|
||||
batch_size=delete_batch_size,
|
||||
)
|
||||
total_deleted = sum(deleted.values())
|
||||
if total_deleted > 0:
|
||||
self.logger.info(
|
||||
f'Monitoring auto-cleanup: deleted {total_deleted} expired records '
|
||||
f'(retention={retention_days}d): {deleted}'
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.warning(f'Monitoring auto-cleanup error: {e}')
|
||||
await asyncio.sleep(check_interval_seconds)
|
||||
|
||||
self.task_mgr.create_task(
|
||||
monitoring_cleanup_loop(),
|
||||
name='monitoring-cleanup',
|
||||
scopes=[core_entities.LifecycleControlScope.APPLICATION],
|
||||
)
|
||||
|
||||
async def workflow_execution_cleanup_loop():
|
||||
check_interval_seconds = 60
|
||||
while True:
|
||||
try:
|
||||
cancelled = await self.workflow_service.cleanup_stale_executions()
|
||||
if cancelled > 0:
|
||||
self.logger.info(f'Workflow execution auto-cleanup: cancelled {cancelled} stale executions')
|
||||
except Exception as e:
|
||||
self.logger.warning(f'Workflow execution auto-cleanup error: {e}')
|
||||
await asyncio.sleep(check_interval_seconds)
|
||||
|
||||
self.task_mgr.create_task(
|
||||
workflow_execution_cleanup_loop(),
|
||||
name='workflow-execution-cleanup',
|
||||
scopes=[core_entities.LifecycleControlScope.APPLICATION],
|
||||
)
|
||||
# Start storage/log maintenance task if enabled
|
||||
storage_cleanup_cfg = self.instance_config.data.get('storage', {}).get('cleanup', {})
|
||||
if storage_cleanup_cfg.get('enabled', True) and self.maintenance_service is not None:
|
||||
check_interval_hours = self._get_positive_float_config(
|
||||
storage_cleanup_cfg.get('check_interval_hours', 1),
|
||||
default=1,
|
||||
name='storage.cleanup.check_interval_hours',
|
||||
)
|
||||
|
||||
async def storage_cleanup_loop():
|
||||
check_interval_seconds = check_interval_hours * 3600
|
||||
while True:
|
||||
try:
|
||||
deleted = await self.maintenance_service.cleanup_expired_files()
|
||||
total_deleted = sum(deleted.values())
|
||||
if total_deleted > 0:
|
||||
self.logger.info(f'Storage maintenance: deleted expired files: {deleted}')
|
||||
except Exception as e:
|
||||
self.logger.warning(f'Storage maintenance error: {e}')
|
||||
await asyncio.sleep(check_interval_seconds)
|
||||
|
||||
self.task_mgr.create_task(
|
||||
storage_cleanup_loop(),
|
||||
name='storage-maintenance',
|
||||
scopes=[core_entities.LifecycleControlScope.APPLICATION],
|
||||
)
|
||||
|
||||
self.task_mgr.create_task(
|
||||
never_ending(),
|
||||
name='never-ending-task',
|
||||
@@ -200,6 +297,28 @@ class Application:
|
||||
self.logger.error(f'Application runtime fatal exception: {e}')
|
||||
self.logger.debug(f'Traceback: {traceback.format_exc()}')
|
||||
|
||||
def _get_positive_int_config(self, value, default: int, name: str) -> int:
|
||||
try:
|
||||
parsed = int(value)
|
||||
except (TypeError, ValueError):
|
||||
self.logger.warning(f'Invalid {name}: {value!r}, using {default}')
|
||||
return default
|
||||
if parsed < 1:
|
||||
self.logger.warning(f'Invalid {name}: {value!r}, using {default}')
|
||||
return default
|
||||
return parsed
|
||||
|
||||
def _get_positive_float_config(self, value, default: float, name: str) -> float:
|
||||
try:
|
||||
parsed = float(value)
|
||||
except (TypeError, ValueError):
|
||||
self.logger.warning(f'Invalid {name}: {value!r}, using {default}')
|
||||
return default
|
||||
if parsed <= 0:
|
||||
self.logger.warning(f'Invalid {name}: {value!r}, using {default}')
|
||||
return default
|
||||
return parsed
|
||||
|
||||
def dispose(self):
|
||||
self.plugin_connector.dispose()
|
||||
|
||||
|
||||
@@ -46,12 +46,14 @@ async def make_app(loop: asyncio.AbstractEventLoop) -> app.Application:
|
||||
|
||||
|
||||
async def main(loop: asyncio.AbstractEventLoop):
|
||||
app_inst: app.Application | None = None
|
||||
try:
|
||||
# Hang system signal processing
|
||||
import signal
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
app_inst.dispose()
|
||||
if app_inst is not None:
|
||||
app_inst.dispose()
|
||||
print('[Signal] Program exit.')
|
||||
os._exit(0)
|
||||
|
||||
|
||||
@@ -28,6 +28,8 @@ from ...api.http.service import mcp as mcp_service
|
||||
from ...api.http.service import apikey as apikey_service
|
||||
from ...api.http.service import webhook as webhook_service
|
||||
from ...api.http.service import monitoring as monitoring_service
|
||||
from ...api.http.service import workflow as workflow_service
|
||||
from ...api.http.service import maintenance as maintenance_service
|
||||
from ...discover import engine as discover_engine
|
||||
from ...storage import mgr as storagemgr
|
||||
from ...utils import logcache
|
||||
@@ -61,6 +63,9 @@ class BuildAppStage(stage.BootingStage):
|
||||
embedding_models_service_inst = model_service.EmbeddingModelsService(ap)
|
||||
ap.embedding_models_service = embedding_models_service_inst
|
||||
|
||||
rerank_models_service_inst = model_service.RerankModelsService(ap)
|
||||
ap.rerank_models_service = rerank_models_service_inst
|
||||
|
||||
provider_service_inst = provider_service.ModelProviderService(ap)
|
||||
ap.provider_service = provider_service_inst
|
||||
|
||||
@@ -82,6 +87,9 @@ class BuildAppStage(stage.BootingStage):
|
||||
webhook_service_inst = webhook_service.WebhookService(ap)
|
||||
ap.webhook_service = webhook_service_inst
|
||||
|
||||
workflow_service_inst = workflow_service.WorkflowService(ap)
|
||||
ap.workflow_service = workflow_service_inst
|
||||
|
||||
proxy_mgr = proxy.ProxyManager(ap)
|
||||
await proxy_mgr.initialize()
|
||||
ap.proxy_mgr = proxy_mgr
|
||||
@@ -164,6 +172,9 @@ class BuildAppStage(stage.BootingStage):
|
||||
monitoring_service_inst = monitoring_service.MonitoringService(ap)
|
||||
ap.monitoring_service = monitoring_service_inst
|
||||
|
||||
maintenance_service_inst = maintenance_service.MaintenanceService(ap)
|
||||
ap.maintenance_service = maintenance_service_inst
|
||||
|
||||
async def runtime_disconnect_callback(connector: plugin_connector.PluginRuntimeConnector) -> None:
|
||||
await asyncio.sleep(3)
|
||||
await plugin_connector_inst.initialize()
|
||||
|
||||
@@ -74,20 +74,30 @@ def _apply_env_overrides_to_config(cfg: dict) -> dict:
|
||||
current = cfg
|
||||
|
||||
for i, key in enumerate(keys):
|
||||
if not isinstance(current, dict) or key not in current:
|
||||
if not isinstance(current, dict):
|
||||
break
|
||||
|
||||
if i == len(keys) - 1:
|
||||
# At the final key - check if it's a scalar value
|
||||
if isinstance(current[key], (dict, list)):
|
||||
# Skip dict and list types
|
||||
pass
|
||||
# At the final key
|
||||
if key in current:
|
||||
if isinstance(current[key], list):
|
||||
# Convert comma-separated string to list
|
||||
# e.g., SYSTEM__DISABLED_ADAPTERS="aiocqhttp,dingtalk"
|
||||
current[key] = [item.strip() for item in env_value.split(',') if item.strip()]
|
||||
elif isinstance(current[key], dict):
|
||||
# Skip dict types
|
||||
pass
|
||||
else:
|
||||
# Valid scalar value - convert and set it
|
||||
converted_value = convert_value(env_value, current[key])
|
||||
current[key] = converted_value
|
||||
else:
|
||||
# Valid scalar value - convert and set it
|
||||
converted_value = convert_value(env_value, current[key])
|
||||
current[key] = converted_value
|
||||
# Key doesn't exist yet - create it as string
|
||||
current[key] = env_value
|
||||
else:
|
||||
# Navigate deeper
|
||||
# Navigate deeper - create intermediate dict if needed
|
||||
if key not in current:
|
||||
current[key] = {}
|
||||
current = current[key]
|
||||
|
||||
return cfg
|
||||
@@ -146,16 +156,50 @@ class LoadConfigStage(stage.BootingStage):
|
||||
await ap.instance_config.dump_config()
|
||||
|
||||
# load or generate instance id
|
||||
ap.instance_id = await config.load_json_config(
|
||||
'data/labels/instance_id.json',
|
||||
template_data={
|
||||
'instance_id': f'instance_{str(uuid.uuid4())}',
|
||||
'instance_create_ts': int(time.time()),
|
||||
},
|
||||
completion=False,
|
||||
)
|
||||
# Priority:
|
||||
# 1. system.instance_id from config.yaml (can be set via SYSTEM__INSTANCE_ID env var)
|
||||
# 2. data/labels/instance_id.json (if file exists)
|
||||
# 3. Generate new and save to file
|
||||
config_instance_id = ap.instance_config.data.get('system', {}).get('instance_id', '')
|
||||
|
||||
constants.instance_id = ap.instance_id.data['instance_id']
|
||||
if config_instance_id:
|
||||
# Use the instance_id from config.yaml
|
||||
constants.instance_id = config_instance_id
|
||||
# Still load/create the file for backward compat, but don't use its value
|
||||
ap.instance_id = await config.load_json_config(
|
||||
'data/labels/instance_id.json',
|
||||
template_data={
|
||||
'instance_id': f'instance_{str(uuid.uuid4())}',
|
||||
'instance_create_ts': int(time.time()),
|
||||
},
|
||||
completion=False,
|
||||
)
|
||||
else:
|
||||
# Try loading file-based instance id
|
||||
instance_id_path = os.path.join('data', 'labels', 'instance_id.json')
|
||||
if os.path.exists(instance_id_path):
|
||||
# File exists, read it
|
||||
ap.instance_id = await config.load_json_config(
|
||||
'data/labels/instance_id.json',
|
||||
template_data={
|
||||
'instance_id': '',
|
||||
'instance_create_ts': 0,
|
||||
},
|
||||
completion=False,
|
||||
)
|
||||
constants.instance_id = ap.instance_id.data['instance_id']
|
||||
else:
|
||||
# Neither config nor file, generate new and save to file
|
||||
new_id = f'instance_{str(uuid.uuid4())}'
|
||||
ap.instance_id = await config.load_json_config(
|
||||
'data/labels/instance_id.json',
|
||||
template_data={
|
||||
'instance_id': new_id,
|
||||
'instance_create_ts': int(time.time()),
|
||||
},
|
||||
completion=False,
|
||||
)
|
||||
constants.instance_id = new_id
|
||||
constants.edition = ap.instance_config.data.get('system', {}).get('edition', 'community')
|
||||
|
||||
print(f'LangBot instance id: {constants.instance_id}')
|
||||
@@ -177,3 +221,34 @@ class LoadConfigStage(stage.BootingStage):
|
||||
ap.pipeline_config_meta_safety = await load_resource_yaml_template_data('metadata/pipeline/safety.yaml')
|
||||
ap.pipeline_config_meta_ai = await load_resource_yaml_template_data('metadata/pipeline/ai.yaml')
|
||||
ap.pipeline_config_meta_output = await load_resource_yaml_template_data('metadata/pipeline/output.yaml')
|
||||
|
||||
# Load workflow node metadata from YAML files. YAML is the source of
|
||||
# truth for workflow editor metadata; Python classes provide execution
|
||||
# logic and are bound through the registry.
|
||||
from langbot.pkg.workflow.metadata import NodeMetadataLoader
|
||||
from langbot.pkg.workflow.registry import NodeTypeRegistry
|
||||
|
||||
workflow_metadata_loader = NodeMetadataLoader()
|
||||
workflow_node_count = await workflow_metadata_loader.load_core_metadata()
|
||||
ap.workflow_node_configs = workflow_metadata_loader.get_all_metadata()
|
||||
ap.workflow_node_metadata_loader = workflow_metadata_loader
|
||||
|
||||
workflow_registry = NodeTypeRegistry.instance()
|
||||
for node_config in ap.workflow_node_configs.values():
|
||||
workflow_registry.register_metadata(node_config, source=node_config.get('_source', 'core'))
|
||||
|
||||
# Auto-discover and register workflow nodes using discovery engine
|
||||
if hasattr(ap, 'discover') and ap.discover is not None:
|
||||
workflow_registry.discover_nodes(ap.discover)
|
||||
|
||||
workflow_load_errors = workflow_metadata_loader.get_load_errors()
|
||||
if workflow_load_errors:
|
||||
print(f'Workflow node metadata load errors: {len(workflow_load_errors)}')
|
||||
for error in workflow_load_errors:
|
||||
print(f" - {error.get('file')}: {error.get('error')}")
|
||||
|
||||
print(
|
||||
f'Loaded {workflow_node_count} workflow node metadata files; '
|
||||
f'registered {workflow_registry.metadata_count()} metadata definitions, '
|
||||
f'{workflow_registry.count()} node types'
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import typing
|
||||
import datetime
|
||||
import time
|
||||
|
||||
from . import app
|
||||
from . import entities as core_entities
|
||||
@@ -17,9 +18,13 @@ class TaskContext:
|
||||
log: str
|
||||
"""Log"""
|
||||
|
||||
metadata: dict
|
||||
"""Structured metadata for progress reporting"""
|
||||
|
||||
def __init__(self):
|
||||
self.current_action = 'default'
|
||||
self.log = ''
|
||||
self.metadata = {}
|
||||
|
||||
def _log(self, msg: str):
|
||||
self.log += msg + '\n'
|
||||
@@ -38,7 +43,7 @@ class TaskContext:
|
||||
self._log(f'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")} | {self.current_action} | {msg}')
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {'current_action': self.current_action, 'log': self.log}
|
||||
return {'current_action': self.current_action, 'log': self.log, 'metadata': self.metadata}
|
||||
|
||||
@staticmethod
|
||||
def new() -> TaskContext:
|
||||
@@ -115,6 +120,7 @@ class TaskWrapper:
|
||||
self.label = label if label != '' else name
|
||||
self.task.set_name(name)
|
||||
self.scopes = scopes
|
||||
self.created_at = time.time()
|
||||
|
||||
def assume_exception(self):
|
||||
try:
|
||||
@@ -150,6 +156,7 @@ class TaskWrapper:
|
||||
'name': self.name,
|
||||
'label': self.label,
|
||||
'scopes': [scope.value for scope in self.scopes],
|
||||
'created_at': self.created_at,
|
||||
'task_context': self.task_context.to_dict(),
|
||||
'runtime': {
|
||||
'done': self.task.done(),
|
||||
@@ -189,6 +196,8 @@ class AsyncTaskManager:
|
||||
) -> TaskWrapper:
|
||||
wrapper = TaskWrapper(self.ap, coro, task_type, kind, name, label, context, scopes)
|
||||
self.tasks.append(wrapper)
|
||||
wrapper.task.add_done_callback(lambda _: self._prune_completed_tasks())
|
||||
self._prune_completed_tasks()
|
||||
return wrapper
|
||||
|
||||
def create_user_task(
|
||||
@@ -211,9 +220,23 @@ class AsyncTaskManager:
|
||||
def get_tasks_dict(
|
||||
self,
|
||||
type: str = None,
|
||||
kind: str = None,
|
||||
) -> dict:
|
||||
return {
|
||||
'tasks': [t.to_dict() for t in self.tasks if type is None or t.task_type == type],
|
||||
'tasks': [
|
||||
t.to_dict()
|
||||
for t in self.tasks
|
||||
if (type is None or t.task_type == type) and (kind is None or t.kind == kind)
|
||||
],
|
||||
'id_index': TaskWrapper._id_index,
|
||||
}
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
completed = sum(1 for t in self.tasks if t.task.done())
|
||||
return {
|
||||
'total': len(self.tasks),
|
||||
'running': len(self.tasks) - completed,
|
||||
'completed': completed,
|
||||
'id_index': TaskWrapper._id_index,
|
||||
}
|
||||
|
||||
@@ -234,3 +257,27 @@ class AsyncTaskManager:
|
||||
if not wrapper.task.done():
|
||||
wrapper.task.cancel()
|
||||
return
|
||||
|
||||
def _prune_completed_tasks(self):
|
||||
completed_limit = (
|
||||
self.ap.instance_config.data.get('system', {})
|
||||
.get('task_retention', {})
|
||||
.get(
|
||||
'completed_limit',
|
||||
200,
|
||||
)
|
||||
)
|
||||
try:
|
||||
completed_limit = int(completed_limit)
|
||||
except (TypeError, ValueError):
|
||||
completed_limit = 200
|
||||
if completed_limit < 1:
|
||||
completed_limit = 1
|
||||
|
||||
completed_tasks = [wrapper for wrapper in self.tasks if wrapper.task.done()]
|
||||
overflow = len(completed_tasks) - completed_limit
|
||||
if overflow <= 0:
|
||||
return
|
||||
|
||||
remove_ids = {wrapper.id for wrapper in completed_tasks[:overflow]}
|
||||
self.tasks = [wrapper for wrapper in self.tasks if wrapper.id not in remove_ids]
|
||||
|
||||
@@ -17,11 +17,23 @@ class I18nString(pydantic.BaseModel):
|
||||
"""英文"""
|
||||
|
||||
zh_Hans: typing.Optional[str] = None
|
||||
"""中文"""
|
||||
"""简体中文"""
|
||||
|
||||
zh_Hant: typing.Optional[str] = None
|
||||
"""繁体中文"""
|
||||
|
||||
ja_JP: typing.Optional[str] = None
|
||||
"""日文"""
|
||||
|
||||
th_TH: typing.Optional[str] = None
|
||||
"""泰文"""
|
||||
|
||||
vi_VN: typing.Optional[str] = None
|
||||
"""越南文"""
|
||||
|
||||
es_ES: typing.Optional[str] = None
|
||||
"""西班牙文"""
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典"""
|
||||
dic = {}
|
||||
@@ -29,8 +41,16 @@ class I18nString(pydantic.BaseModel):
|
||||
dic['en_US'] = self.en_US
|
||||
if self.zh_Hans is not None:
|
||||
dic['zh_Hans'] = self.zh_Hans
|
||||
if self.zh_Hant is not None:
|
||||
dic['zh_Hant'] = self.zh_Hant
|
||||
if self.ja_JP is not None:
|
||||
dic['ja_JP'] = self.ja_JP
|
||||
if self.th_TH is not None:
|
||||
dic['th_TH'] = self.th_TH
|
||||
if self.vi_VN is not None:
|
||||
dic['vi_VN'] = self.vi_VN
|
||||
if self.es_ES is not None:
|
||||
dic['es_ES'] = self.es_ES
|
||||
return dic
|
||||
|
||||
|
||||
@@ -284,3 +304,65 @@ class ComponentDiscoveryEngine:
|
||||
if component.kind == kind:
|
||||
result.append(component)
|
||||
return result
|
||||
|
||||
def discover_workflow_nodes(self, nodes_dir: str) -> typing.List[typing.Type]:
|
||||
"""Discover workflow node classes from a directory of Python modules.
|
||||
|
||||
Scans all .py files in the given directory, imports them, and collects
|
||||
classes that are subclasses of WorkflowNode.
|
||||
|
||||
Args:
|
||||
nodes_dir: Directory path like 'pkg/workflow/nodes/'
|
||||
|
||||
Returns:
|
||||
List of WorkflowNode subclasses found
|
||||
"""
|
||||
from langbot.pkg.workflow.node import WorkflowNode
|
||||
|
||||
node_classes: typing.List[typing.Type[WorkflowNode]] = []
|
||||
|
||||
# Normalize path
|
||||
if nodes_dir.endswith('/'):
|
||||
nodes_dir = nodes_dir[:-1]
|
||||
|
||||
# Import the nodes package to trigger all module imports
|
||||
module_path = nodes_dir.replace('/', '.').replace('\\', '.')
|
||||
package_path = module_path
|
||||
|
||||
try:
|
||||
# Import the package __init__ to trigger submodule imports
|
||||
importlib.import_module(f'langbot.{package_path}')
|
||||
except ImportError:
|
||||
self.ap.logger.warning(f'Failed to import workflow nodes package: langbot.{package_path}')
|
||||
|
||||
# Since workflow/__init__.py is empty, explicitly import all .py files in the nodes directory
|
||||
import os
|
||||
# engine.py is in langbot/pkg/discover/, nodes are in langbot/pkg/workflow/nodes/
|
||||
nodes_abs_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'workflow', 'nodes'))
|
||||
if os.path.isdir(nodes_abs_path):
|
||||
for filename in os.listdir(nodes_abs_path):
|
||||
if filename.endswith('.py') and not filename.startswith('_'):
|
||||
module_name = filename[:-3]
|
||||
try:
|
||||
importlib.import_module(f'langbot.{package_path}.{module_name}')
|
||||
except ImportError as e:
|
||||
self.ap.logger.warning(f'Failed to import workflow node module: {module_name}: {e}')
|
||||
|
||||
# Now collect all WorkflowNode subclasses from sys.modules
|
||||
import sys
|
||||
prefix = f'langbot.{package_path}.'
|
||||
for mod_name, mod in sys.modules.items():
|
||||
if mod_name.startswith(prefix) and mod is not None:
|
||||
for attr_name in dir(mod):
|
||||
attr = getattr(mod, attr_name)
|
||||
if (
|
||||
isinstance(attr, type)
|
||||
and issubclass(attr, WorkflowNode)
|
||||
and attr is not WorkflowNode
|
||||
and hasattr(attr, 'type_name')
|
||||
and attr.type_name
|
||||
):
|
||||
if attr not in node_classes:
|
||||
node_classes.append(attr)
|
||||
|
||||
return node_classes
|
||||
|
||||
@@ -16,6 +16,14 @@ class Bot(Base):
|
||||
enable = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=False)
|
||||
use_pipeline_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
use_pipeline_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
pipeline_routing_rules = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, server_default='[]')
|
||||
|
||||
# New unified binding fields
|
||||
# binding_type: 'pipeline' or 'workflow'
|
||||
binding_type = sqlalchemy.Column(sqlalchemy.String(32), nullable=False, server_default='pipeline')
|
||||
# binding_uuid: UUID of the bound Pipeline or Workflow
|
||||
binding_uuid = sqlalchemy.Column(sqlalchemy.String(64), nullable=True)
|
||||
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
||||
updated_at = sqlalchemy.Column(
|
||||
sqlalchemy.DateTime,
|
||||
|
||||
@@ -59,3 +59,22 @@ class EmbeddingModel(Base):
|
||||
server_default=sqlalchemy.func.now(),
|
||||
onupdate=sqlalchemy.func.now(),
|
||||
)
|
||||
|
||||
|
||||
class RerankModel(Base):
|
||||
"""Rerank model"""
|
||||
|
||||
__tablename__ = 'rerank_models'
|
||||
|
||||
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
|
||||
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
provider_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
|
||||
prefered_ranking = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0)
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
||||
updated_at = sqlalchemy.Column(
|
||||
sqlalchemy.DateTime,
|
||||
nullable=False,
|
||||
server_default=sqlalchemy.func.now(),
|
||||
onupdate=sqlalchemy.func.now(),
|
||||
)
|
||||
|
||||
@@ -20,6 +20,7 @@ class MonitoringMessage(Base):
|
||||
level = sqlalchemy.Column(sqlalchemy.String(50), nullable=False) # info, warning, error, debug
|
||||
platform = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
user_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
user_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) # User display name
|
||||
runner_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) # Runner name for this query
|
||||
variables = sqlalchemy.Column(sqlalchemy.Text, nullable=True) # Query variables as JSON string
|
||||
role = sqlalchemy.Column(sqlalchemy.String(50), nullable=True, default='user') # user, assistant
|
||||
@@ -64,6 +65,7 @@ class MonitoringSession(Base):
|
||||
is_active = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=True, index=True)
|
||||
platform = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
user_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
user_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) # User display name
|
||||
|
||||
|
||||
class MonitoringError(Base):
|
||||
@@ -104,3 +106,26 @@ class MonitoringEmbeddingCall(Base):
|
||||
session_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
message_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
call_type = sqlalchemy.Column(sqlalchemy.String(50), nullable=True) # embedding, retrieve
|
||||
|
||||
|
||||
class MonitoringFeedback(Base):
|
||||
"""User feedback records (like/dislike) from AI Bot conversations"""
|
||||
|
||||
__tablename__ = 'monitoring_feedback'
|
||||
|
||||
id = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True)
|
||||
timestamp = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, index=True)
|
||||
feedback_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, unique=True, index=True)
|
||||
feedback_type = sqlalchemy.Column(sqlalchemy.Integer, nullable=False) # 1=like, 2=dislike
|
||||
feedback_content = sqlalchemy.Column(sqlalchemy.Text, nullable=True) # User feedback text
|
||||
inaccurate_reasons = sqlalchemy.Column(sqlalchemy.Text, nullable=True) # JSON list of inaccurate reasons
|
||||
# Context fields
|
||||
bot_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
bot_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
pipeline_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
pipeline_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
session_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
message_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
stream_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
user_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
platform = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) # e.g., wecom
|
||||
|
||||
126
src/langbot/pkg/entity/persistence/workflow.py
Normal file
126
src/langbot/pkg/entity/persistence/workflow.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""Workflow persistence entities"""
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
from .base import Base
|
||||
|
||||
|
||||
class Workflow(Base):
|
||||
"""Workflow definition"""
|
||||
|
||||
__tablename__ = 'workflows'
|
||||
|
||||
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
|
||||
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
description = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
|
||||
emoji = sqlalchemy.Column(sqlalchemy.String(10), nullable=True, default='🔄')
|
||||
version = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=1)
|
||||
is_enabled = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=True)
|
||||
|
||||
# Workflow definition stored as JSON
|
||||
# Contains: nodes, edges, variables, settings
|
||||
definition = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
|
||||
|
||||
# Global config (inherited from Pipeline capabilities)
|
||||
# Contains: safety, output configs
|
||||
global_config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
|
||||
|
||||
# Extensions preferences (same as Pipeline)
|
||||
extensions_preferences = sqlalchemy.Column(
|
||||
sqlalchemy.JSON,
|
||||
nullable=False,
|
||||
default={'enable_all_plugins': True, 'enable_all_mcp_servers': True, 'plugins': [], 'mcp_servers': []},
|
||||
)
|
||||
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
||||
updated_at = sqlalchemy.Column(
|
||||
sqlalchemy.DateTime,
|
||||
nullable=False,
|
||||
server_default=sqlalchemy.func.now(),
|
||||
onupdate=sqlalchemy.func.now(),
|
||||
)
|
||||
|
||||
|
||||
class WorkflowVersion(Base):
|
||||
"""Workflow version history"""
|
||||
|
||||
__tablename__ = 'workflow_versions'
|
||||
|
||||
id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
|
||||
workflow_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, index=True)
|
||||
version = sqlalchemy.Column(sqlalchemy.Integer, nullable=False)
|
||||
definition = sqlalchemy.Column(sqlalchemy.JSON, nullable=False)
|
||||
global_config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
||||
created_by = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
|
||||
__table_args__ = (sqlalchemy.UniqueConstraint('workflow_uuid', 'version', name='uq_workflow_version'),)
|
||||
|
||||
|
||||
class WorkflowTrigger(Base):
|
||||
"""Workflow trigger configuration"""
|
||||
|
||||
__tablename__ = 'workflow_triggers'
|
||||
|
||||
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
|
||||
workflow_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, index=True)
|
||||
type = sqlalchemy.Column(sqlalchemy.String(50), nullable=False) # message, cron, event, webhook
|
||||
config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
|
||||
is_enabled = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=True)
|
||||
priority = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0)
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
||||
updated_at = sqlalchemy.Column(
|
||||
sqlalchemy.DateTime,
|
||||
nullable=False,
|
||||
server_default=sqlalchemy.func.now(),
|
||||
onupdate=sqlalchemy.func.now(),
|
||||
)
|
||||
|
||||
|
||||
class WorkflowExecution(Base):
|
||||
"""Workflow execution record"""
|
||||
|
||||
__tablename__ = 'workflow_executions'
|
||||
|
||||
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
|
||||
workflow_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, index=True)
|
||||
workflow_version = sqlalchemy.Column(sqlalchemy.Integer, nullable=False)
|
||||
status = sqlalchemy.Column(sqlalchemy.String(20), nullable=False) # pending, running, completed, failed, cancelled
|
||||
trigger_type = sqlalchemy.Column(sqlalchemy.String(50), nullable=True)
|
||||
trigger_data = sqlalchemy.Column(sqlalchemy.JSON, nullable=True)
|
||||
variables = sqlalchemy.Column(sqlalchemy.JSON, nullable=True)
|
||||
start_time = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
|
||||
end_time = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
|
||||
error = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
||||
|
||||
|
||||
class WorkflowNodeExecution(Base):
|
||||
"""Workflow node execution record"""
|
||||
|
||||
__tablename__ = 'workflow_node_executions'
|
||||
|
||||
id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
|
||||
execution_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, index=True)
|
||||
node_id = sqlalchemy.Column(sqlalchemy.String(100), nullable=False)
|
||||
node_type = sqlalchemy.Column(sqlalchemy.String(50), nullable=False)
|
||||
status = sqlalchemy.Column(sqlalchemy.String(20), nullable=False) # pending, running, completed, failed, skipped
|
||||
inputs = sqlalchemy.Column(sqlalchemy.JSON, nullable=True)
|
||||
outputs = sqlalchemy.Column(sqlalchemy.JSON, nullable=True)
|
||||
start_time = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
|
||||
end_time = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
|
||||
error = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
|
||||
retry_count = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0)
|
||||
|
||||
|
||||
class ScheduledJob(Base):
|
||||
"""Scheduled job for cron triggers"""
|
||||
|
||||
__tablename__ = 'workflow_scheduled_jobs'
|
||||
|
||||
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
|
||||
trigger_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, index=True)
|
||||
cron_expression = sqlalchemy.Column(sqlalchemy.String(100), nullable=True)
|
||||
next_run_time = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
|
||||
last_run_time = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
|
||||
is_enabled = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=True)
|
||||
0
src/langbot/pkg/persistence/alembic/__init__.py
Normal file
0
src/langbot/pkg/persistence/alembic/__init__.py
Normal file
51
src/langbot/pkg/persistence/alembic/env.py
Normal file
51
src/langbot/pkg/persistence/alembic/env.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""Alembic environment for LangBot.
|
||||
|
||||
This env.py is designed to be called programmatically (not via CLI).
|
||||
It supports both SQLite and PostgreSQL.
|
||||
|
||||
The sync connection is passed via config attributes by the runner.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy.engine import Connection
|
||||
|
||||
from langbot.pkg.entity.persistence.base import Base
|
||||
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode — emit SQL without a live connection."""
|
||||
url = context.config.get_main_option('sqlalchemy.url')
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={'paramstyle': 'named'},
|
||||
)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations with a live sync connection passed via config attributes."""
|
||||
connection: Connection = context.config.attributes.get('connection')
|
||||
if connection is None:
|
||||
raise RuntimeError('connection not provided in alembic config attributes')
|
||||
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
# render_as_batch=True is critical for SQLite ALTER TABLE support
|
||||
render_as_batch=True,
|
||||
)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
24
src/langbot/pkg/persistence/alembic/script.py.mako
Normal file
24
src/langbot/pkg/persistence/alembic/script.py.mako
Normal file
@@ -0,0 +1,24 @@
|
||||
# Alembic script.py.mako — template for auto-generated revisions
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers
|
||||
revision = ${repr(up_revision)}
|
||||
down_revision = ${repr(down_revision)}
|
||||
branch_labels = ${repr(branch_labels)}
|
||||
depends_on = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
@@ -0,0 +1,24 @@
|
||||
"""baseline: stamp existing schema (db version 25)
|
||||
|
||||
This is a no-op migration that marks the starting point for Alembic.
|
||||
All tables already exist via create_all() + legacy DBMigration system.
|
||||
|
||||
Revision ID: 0001_baseline
|
||||
Revises: None
|
||||
Create Date: 2026-04-08
|
||||
"""
|
||||
|
||||
revision = '0001_baseline'
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# No-op: existing schema is already at database_version=25
|
||||
# This revision serves as the Alembic baseline.
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
62
src/langbot/pkg/persistence/alembic/versions/0002_sample.py
Normal file
62
src/langbot/pkg/persistence/alembic/versions/0002_sample.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""example: sample migration demonstrating Alembic patterns
|
||||
|
||||
This is a SAMPLE showing how to write migrations that work
|
||||
seamlessly across SQLite and PostgreSQL. Delete or adapt as needed.
|
||||
|
||||
Revision ID: 0002_sample
|
||||
Revises: 0001_baseline
|
||||
Create Date: 2026-04-08
|
||||
|
||||
Patterns demonstrated:
|
||||
1. Schema change (add column) — works on both DBs via render_as_batch
|
||||
2. Data migration (read + modify JSON) — pure SQLAlchemy, no dialect branching
|
||||
"""
|
||||
|
||||
revision = '0002_sample'
|
||||
down_revision = '0001_baseline'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""
|
||||
EXAMPLE: Uncomment to use. This shows the patterns.
|
||||
|
||||
# --- Pattern 1: Schema change (add/drop column) ---
|
||||
# render_as_batch=True in env.py makes this work on SQLite too.
|
||||
#
|
||||
# op.add_column('pipelines', sa.Column('description', sa.String(512), server_default=''))
|
||||
|
||||
# --- Pattern 2: Data migration (read + modify JSON field) ---
|
||||
# No if/else for sqlite vs postgres needed!
|
||||
#
|
||||
# conn = op.get_bind()
|
||||
# rows = conn.execute(sa.text("SELECT uuid, config FROM pipelines")).fetchall()
|
||||
# for row in rows:
|
||||
# config = json.loads(row[1]) if isinstance(row[1], str) else row[1]
|
||||
# # Modify the config
|
||||
# config.setdefault('ai', {}).setdefault('some_new_key', 'default_value')
|
||||
# conn.execute(
|
||||
# sa.text("UPDATE pipelines SET config = :cfg WHERE uuid = :uuid"),
|
||||
# {"cfg": json.dumps(config), "uuid": row[0]}
|
||||
# )
|
||||
|
||||
# --- Pattern 3: Create a new table ---
|
||||
#
|
||||
# op.create_table(
|
||||
# 'audit_log',
|
||||
# sa.Column('id', sa.Integer, primary_key=True, autoincrement=True),
|
||||
# sa.Column('action', sa.String(255), nullable=False),
|
||||
# sa.Column('detail', sa.Text),
|
||||
# sa.Column('created_at', sa.DateTime, server_default=sa.func.now()),
|
||||
# )
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""
|
||||
# op.drop_column('pipelines', 'description')
|
||||
# op.drop_table('audit_log')
|
||||
"""
|
||||
pass
|
||||
@@ -0,0 +1,35 @@
|
||||
"""add rerank_models table
|
||||
|
||||
Revision ID: 0003_add_rerank_models
|
||||
Revises: 0002_sample
|
||||
Create Date: 2026-04-19
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = '0003_add_rerank_models'
|
||||
down_revision = '0002_sample'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Check if table already exists (may have been created by create_all())
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
if 'rerank_models' not in inspector.get_table_names():
|
||||
op.create_table(
|
||||
'rerank_models',
|
||||
sa.Column('uuid', sa.String(255), primary_key=True, unique=True),
|
||||
sa.Column('name', sa.String(255), nullable=False),
|
||||
sa.Column('provider_uuid', sa.String(255), nullable=False),
|
||||
sa.Column('extra_args', sa.JSON, nullable=False, server_default='{}'),
|
||||
sa.Column('prefered_ranking', sa.Integer, nullable=False, server_default='0'),
|
||||
sa.Column('created_at', sa.DateTime, nullable=False, server_default=sa.func.now()),
|
||||
sa.Column('updated_at', sa.DateTime, nullable=False, server_default=sa.func.now()),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table('rerank_models')
|
||||
150
src/langbot/pkg/persistence/alembic_runner.py
Normal file
150
src/langbot/pkg/persistence/alembic_runner.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""Programmatic Alembic runner for LangBot.
|
||||
|
||||
Usage from async code:
|
||||
from langbot.pkg.persistence.alembic_runner import run_alembic_upgrade
|
||||
await run_alembic_upgrade(async_engine)
|
||||
|
||||
CLI usage (autogenerate):
|
||||
python -m langbot.pkg.persistence.alembic_runner autogenerate "add description column"
|
||||
python -m langbot.pkg.persistence.alembic_runner upgrade
|
||||
python -m langbot.pkg.persistence.alembic_runner current
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from alembic.config import Config
|
||||
from alembic import command
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
from sqlalchemy.engine import Connection
|
||||
|
||||
|
||||
_ALEMBIC_DIR = os.path.join(os.path.dirname(__file__), 'alembic')
|
||||
|
||||
|
||||
def _build_config(connection: Connection) -> Config:
|
||||
"""Build an Alembic Config with sync connection attached."""
|
||||
cfg = Config()
|
||||
cfg.set_main_option('script_location', _ALEMBIC_DIR)
|
||||
cfg.attributes['connection'] = connection
|
||||
return cfg
|
||||
|
||||
|
||||
def _do_upgrade(connection: Connection, revision: str = 'head') -> None:
|
||||
"""Synchronous upgrade — runs inside run_sync."""
|
||||
cfg = _build_config(connection)
|
||||
command.upgrade(cfg, revision)
|
||||
|
||||
|
||||
def _do_stamp(connection: Connection, revision: str = 'head') -> None:
|
||||
"""Synchronous stamp — runs inside run_sync."""
|
||||
cfg = _build_config(connection)
|
||||
command.stamp(cfg, revision)
|
||||
|
||||
|
||||
def _do_get_current(connection: Connection) -> str | None:
|
||||
"""Get current alembic revision synchronously."""
|
||||
ctx = MigrationContext.configure(connection)
|
||||
return ctx.get_current_revision()
|
||||
|
||||
|
||||
def _do_autogenerate(connection: Connection, message: str = 'auto migration') -> None:
|
||||
"""Synchronous autogenerate — runs inside run_sync."""
|
||||
cfg = _build_config(connection)
|
||||
command.revision(cfg, message=message, autogenerate=True)
|
||||
|
||||
|
||||
async def run_alembic_upgrade(async_engine: AsyncEngine, revision: str = 'head') -> None:
|
||||
"""Run Alembic upgrade to the given revision."""
|
||||
async with async_engine.connect() as conn:
|
||||
await conn.run_sync(_do_upgrade, revision)
|
||||
await conn.commit()
|
||||
|
||||
|
||||
async def run_alembic_stamp(async_engine: AsyncEngine, revision: str = 'head') -> None:
|
||||
"""Stamp the database with a revision without running migrations."""
|
||||
async with async_engine.connect() as conn:
|
||||
await conn.run_sync(_do_stamp, revision)
|
||||
await conn.commit()
|
||||
|
||||
|
||||
async def get_alembic_current(async_engine: AsyncEngine) -> str | None:
|
||||
"""Get current alembic revision, or None if not stamped."""
|
||||
async with async_engine.connect() as conn:
|
||||
return await conn.run_sync(_do_get_current)
|
||||
|
||||
|
||||
async def run_alembic_autogenerate(async_engine: AsyncEngine, message: str = 'auto migration') -> None:
|
||||
"""Compare ORM models against DB schema and generate a migration script."""
|
||||
async with async_engine.connect() as conn:
|
||||
await conn.run_sync(_do_autogenerate, message)
|
||||
|
||||
|
||||
# CLI entrypoint: python -m langbot.pkg.persistence.alembic_runner <command> [args]
|
||||
if __name__ == '__main__':
|
||||
import sys
|
||||
import asyncio
|
||||
|
||||
def _get_engine():
|
||||
"""Create engine from data/config.yaml or default SQLite."""
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
try:
|
||||
import yaml
|
||||
|
||||
with open('data/config.yaml') as f:
|
||||
config = yaml.safe_load(f)
|
||||
db_cfg = config.get('database', {})
|
||||
db_type = db_cfg.get('use', 'sqlite')
|
||||
if db_type == 'postgresql':
|
||||
pg = db_cfg.get('postgresql', {})
|
||||
url = (
|
||||
f'postgresql+asyncpg://{pg.get("user", "postgres")}:{pg.get("password", "postgres")}'
|
||||
f'@{pg.get("host", "127.0.0.1")}:{pg.get("port", 5432)}/{pg.get("database", "postgres")}'
|
||||
)
|
||||
else:
|
||||
path = db_cfg.get('sqlite', {}).get('path', 'data/langbot.db')
|
||||
url = f'sqlite+aiosqlite:///{path}'
|
||||
except Exception:
|
||||
url = 'sqlite+aiosqlite:///data/langbot.db'
|
||||
|
||||
return create_async_engine(url)
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 2:
|
||||
print('Usage: python -m langbot.pkg.persistence.alembic_runner <command> [args]')
|
||||
print('Commands:')
|
||||
print(' autogenerate "message" — Generate migration from ORM model diff')
|
||||
print(' upgrade [revision] — Upgrade database (default: head)')
|
||||
print(' stamp [revision] — Stamp revision without running (default: head)')
|
||||
print(' current — Show current revision')
|
||||
sys.exit(1)
|
||||
|
||||
cmd = sys.argv[1]
|
||||
engine = _get_engine()
|
||||
|
||||
if cmd == 'autogenerate':
|
||||
msg = sys.argv[2] if len(sys.argv) > 2 else 'auto migration'
|
||||
asyncio.run(run_alembic_autogenerate(engine, msg))
|
||||
print(f'Migration generated: {msg}')
|
||||
elif cmd == 'upgrade':
|
||||
rev = sys.argv[2] if len(sys.argv) > 2 else 'head'
|
||||
asyncio.run(run_alembic_upgrade(engine, rev))
|
||||
print(f'Upgraded to: {rev}')
|
||||
elif cmd == 'stamp':
|
||||
rev = sys.argv[2] if len(sys.argv) > 2 else 'head'
|
||||
asyncio.run(run_alembic_stamp(engine, rev))
|
||||
print(f'Stamped: {rev}')
|
||||
elif cmd == 'current':
|
||||
rev = asyncio.run(get_alembic_current(engine))
|
||||
print(f'Current revision: {rev}')
|
||||
else:
|
||||
print(f'Unknown command: {cmd}')
|
||||
sys.exit(1)
|
||||
|
||||
main()
|
||||
@@ -2,18 +2,16 @@ from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import typing
|
||||
import json
|
||||
import uuid
|
||||
|
||||
|
||||
import sqlalchemy.ext.asyncio as sqlalchemy_asyncio
|
||||
import sqlalchemy
|
||||
|
||||
from . import database, migration
|
||||
from ..entity.persistence import base, pipeline, metadata, model as persistence_model
|
||||
from ..entity.persistence import base, metadata, model as persistence_model
|
||||
from ..entity import persistence
|
||||
from ..core import app
|
||||
from ..utils import constants, importutil
|
||||
from ..api.http.service import pipeline as pipeline_service
|
||||
from . import databases, migrations
|
||||
|
||||
importutil.import_modules_in_pkg(databases)
|
||||
@@ -78,7 +76,9 @@ class PersistenceManager:
|
||||
|
||||
self.ap.logger.info(f'Successfully upgraded database to version {last_migration_number}.')
|
||||
|
||||
await self.write_default_pipeline()
|
||||
# Run Alembic migrations (new migration system)
|
||||
await self._run_alembic_migrations()
|
||||
|
||||
await self.write_space_model_providers()
|
||||
|
||||
async def create_tables(self):
|
||||
@@ -101,29 +101,6 @@ class PersistenceManager:
|
||||
if row is None:
|
||||
await self.execute_async(sqlalchemy.insert(metadata.Metadata).values(item))
|
||||
|
||||
async def write_default_pipeline(self):
|
||||
# write default pipeline
|
||||
result = await self.execute_async(sqlalchemy.select(pipeline.LegacyPipeline))
|
||||
default_pipeline_uuid = None
|
||||
if result.first() is None:
|
||||
self.ap.logger.info('Creating default pipeline...')
|
||||
|
||||
pipeline_config = json.loads(importutil.read_resource_file('templates/default-pipeline-config.json'))
|
||||
|
||||
default_pipeline_uuid = str(uuid.uuid4())
|
||||
pipeline_data = {
|
||||
'uuid': default_pipeline_uuid,
|
||||
'for_version': self.ap.ver_mgr.get_current_version(),
|
||||
'stages': pipeline_service.default_stage_order,
|
||||
'is_default': True,
|
||||
'name': 'ChatPipeline',
|
||||
'description': 'Default pipeline, new bots will be bound to this pipeline | 默认提供的流水线,您配置的机器人将自动绑定到此流水线',
|
||||
'config': pipeline_config,
|
||||
'extensions_preferences': {},
|
||||
}
|
||||
|
||||
await self.execute_async(sqlalchemy.insert(pipeline.LegacyPipeline).values(pipeline_data))
|
||||
|
||||
async def write_space_model_providers(self):
|
||||
space_models_gateway_api_url = self.ap.instance_config.data.get('space', {}).get(
|
||||
'models_gateway_api_url', 'https://api.langbot.cloud/v1'
|
||||
@@ -161,6 +138,28 @@ class PersistenceManager:
|
||||
|
||||
# =================================
|
||||
|
||||
async def _run_alembic_migrations(self):
|
||||
"""Run Alembic-based migrations after legacy migrations complete."""
|
||||
from . import alembic_runner
|
||||
|
||||
engine = self.get_db_engine()
|
||||
|
||||
try:
|
||||
current_rev = await alembic_runner.get_alembic_current(engine)
|
||||
|
||||
if current_rev is None:
|
||||
# First time: stamp baseline so Alembic knows existing schema is up-to-date
|
||||
self.ap.logger.info('Alembic: no revision found, stamping baseline...')
|
||||
await alembic_runner.run_alembic_stamp(engine, '0001_baseline')
|
||||
current_rev = '0001_baseline'
|
||||
|
||||
# Upgrade to head
|
||||
await alembic_runner.run_alembic_upgrade(engine, 'head')
|
||||
self.ap.logger.info('Alembic migrations completed.')
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Alembic migration failed: {e}', exc_info=True)
|
||||
raise
|
||||
|
||||
async def execute_async(self, *args, **kwargs) -> sqlalchemy.engine.cursor.CursorResult:
|
||||
async with self.get_db_engine().connect() as conn:
|
||||
result = await conn.execute(*args, **kwargs)
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
from .. import migration
|
||||
|
||||
import sqlalchemy
|
||||
import json
|
||||
|
||||
|
||||
@migration.migration_class(21)
|
||||
class DBMigrateMergeExceptionHandling(migration.DBMigration):
|
||||
"""Merge hide-exception and block-failed-request-output into a single exception-handling select option,
|
||||
and add failure-hint field.
|
||||
|
||||
Conversion logic:
|
||||
- block-failed-request-output=true -> exception-handling: hide
|
||||
- hide-exception=true -> exception-handling: show-hint
|
||||
- hide-exception=false -> exception-handling: show-error
|
||||
"""
|
||||
|
||||
async def upgrade(self):
|
||||
"""Upgrade"""
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('SELECT uuid, config FROM legacy_pipelines')
|
||||
)
|
||||
pipelines = result.fetchall()
|
||||
|
||||
current_version = self.ap.ver_mgr.get_current_version()
|
||||
|
||||
for pipeline_row in pipelines:
|
||||
uuid = pipeline_row[0]
|
||||
config = json.loads(pipeline_row[1]) if isinstance(pipeline_row[1], str) else pipeline_row[1]
|
||||
|
||||
if 'output' not in config:
|
||||
config['output'] = {}
|
||||
if 'misc' not in config['output']:
|
||||
config['output']['misc'] = {}
|
||||
|
||||
misc = config['output']['misc']
|
||||
|
||||
# Determine new exception-handling value from legacy fields
|
||||
hide_exception = misc.get('hide-exception', True)
|
||||
block_failed = misc.get('block-failed-request-output', False)
|
||||
|
||||
if block_failed:
|
||||
exception_handling = 'hide'
|
||||
elif hide_exception:
|
||||
exception_handling = 'show-hint'
|
||||
else:
|
||||
exception_handling = 'show-error'
|
||||
|
||||
misc['exception-handling'] = exception_handling
|
||||
|
||||
# Add failure-hint with default value
|
||||
misc['failure-hint'] = 'Request failed.'
|
||||
|
||||
# Remove legacy fields
|
||||
misc.pop('hide-exception', None)
|
||||
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET config = :config::jsonb, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': json.dumps(config), 'for_version': current_version, 'uuid': uuid},
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET config = :config, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': json.dumps(config), 'for_version': current_version, 'uuid': uuid},
|
||||
)
|
||||
|
||||
async def downgrade(self):
|
||||
"""Downgrade"""
|
||||
pass
|
||||
@@ -0,0 +1,73 @@
|
||||
import sqlalchemy
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class(22)
|
||||
class DBMigrateMonitoringUserId(migration.DBMigration):
|
||||
"""Add user_id and user_name columns to monitoring_sessions table
|
||||
|
||||
This migration adds the missing user_id column and also ensures user_name
|
||||
column exists (in case migration 21 failed or was skipped).
|
||||
"""
|
||||
|
||||
async def _table_exists(self, table_name: str) -> bool:
|
||||
"""Check if a table exists (works for both SQLite and PostgreSQL)."""
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = :table_name);'
|
||||
).bindparams(table_name=table_name)
|
||||
)
|
||||
return bool(result.scalar())
|
||||
else:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text("SELECT name FROM sqlite_master WHERE type='table' AND name=:table_name;").bindparams(
|
||||
table_name=table_name
|
||||
)
|
||||
)
|
||||
return result.first() is not None
|
||||
|
||||
async def _get_table_columns(self, table_name: str) -> list[str]:
|
||||
"""Get column names from a table (works for both SQLite and PostgreSQL)."""
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'SELECT column_name FROM information_schema.columns WHERE table_name = :table_name;'
|
||||
).bindparams(table_name=table_name)
|
||||
)
|
||||
return [row[0] for row in result.fetchall()]
|
||||
else:
|
||||
if not table_name.isidentifier():
|
||||
raise ValueError(f'Invalid table name: {table_name}')
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.text(f'PRAGMA table_info({table_name});'))
|
||||
return [row[1] for row in result.fetchall()]
|
||||
|
||||
async def _add_column_if_not_exists(self, table_name: str, column_name: str, column_type: str):
|
||||
"""Add a column to a table if it does not already exist."""
|
||||
columns = await self._get_table_columns(table_name)
|
||||
if column_name in columns:
|
||||
self.ap.logger.debug('%s column already exists in %s.', column_name, table_name)
|
||||
return
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(f'ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type};')
|
||||
)
|
||||
self.ap.logger.info('Added %s column to %s table.', column_name, table_name)
|
||||
|
||||
async def upgrade(self):
|
||||
# Check if monitoring_sessions table exists
|
||||
if not await self._table_exists('monitoring_sessions'):
|
||||
self.ap.logger.warning('monitoring_sessions table does not exist, skipping migration.')
|
||||
return
|
||||
|
||||
# Add user_id column to monitoring_sessions table
|
||||
await self._add_column_if_not_exists('monitoring_sessions', 'user_id', 'VARCHAR(255)')
|
||||
|
||||
# Add user_name column to monitoring_sessions table (in case migration 21 failed)
|
||||
await self._add_column_if_not_exists('monitoring_sessions', 'user_name', 'VARCHAR(255)')
|
||||
|
||||
# Add user_name column to monitoring_messages table (in case migration 21 failed)
|
||||
if await self._table_exists('monitoring_messages'):
|
||||
await self._add_column_if_not_exists('monitoring_messages', 'user_name', 'VARCHAR(255)')
|
||||
|
||||
async def downgrade(self):
|
||||
pass
|
||||
@@ -0,0 +1,102 @@
|
||||
from .. import migration
|
||||
|
||||
import sqlalchemy
|
||||
import json
|
||||
|
||||
|
||||
@migration.migration_class(23)
|
||||
class DBMigrateModelFallbackConfig(migration.DBMigration):
|
||||
"""Convert model field from plain UUID string to object with primary/fallbacks"""
|
||||
|
||||
async def upgrade(self):
|
||||
"""Upgrade"""
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('SELECT uuid, config FROM legacy_pipelines')
|
||||
)
|
||||
pipelines = result.fetchall()
|
||||
|
||||
current_version = self.ap.ver_mgr.get_current_version()
|
||||
|
||||
for pipeline_row in pipelines:
|
||||
uuid = pipeline_row[0]
|
||||
config = json.loads(pipeline_row[1]) if isinstance(pipeline_row[1], str) else pipeline_row[1]
|
||||
|
||||
if 'ai' not in config or 'local-agent' not in config['ai']:
|
||||
continue
|
||||
|
||||
local_agent = config['ai']['local-agent']
|
||||
changed = False
|
||||
|
||||
# Convert model from string to object
|
||||
model_value = local_agent.get('model', '')
|
||||
if isinstance(model_value, str):
|
||||
local_agent['model'] = {
|
||||
'primary': model_value,
|
||||
'fallbacks': [],
|
||||
}
|
||||
changed = True
|
||||
|
||||
# Remove leftover fallback-models field if present
|
||||
if 'fallback-models' in local_agent:
|
||||
del local_agent['fallback-models']
|
||||
changed = True
|
||||
|
||||
if not changed:
|
||||
continue
|
||||
|
||||
# Update using raw SQL with compatibility for both SQLite and PostgreSQL
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET config = :config::jsonb, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': json.dumps(config), 'for_version': current_version, 'uuid': uuid},
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET config = :config, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': json.dumps(config), 'for_version': current_version, 'uuid': uuid},
|
||||
)
|
||||
|
||||
async def downgrade(self):
|
||||
"""Downgrade"""
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('SELECT uuid, config FROM legacy_pipelines')
|
||||
)
|
||||
pipelines = result.fetchall()
|
||||
|
||||
current_version = self.ap.ver_mgr.get_current_version()
|
||||
|
||||
for pipeline_row in pipelines:
|
||||
uuid = pipeline_row[0]
|
||||
config = json.loads(pipeline_row[1]) if isinstance(pipeline_row[1], str) else pipeline_row[1]
|
||||
|
||||
if 'ai' not in config or 'local-agent' not in config['ai']:
|
||||
continue
|
||||
|
||||
local_agent = config['ai']['local-agent']
|
||||
|
||||
# Convert model from object back to string
|
||||
model_value = local_agent.get('model', '')
|
||||
if isinstance(model_value, dict):
|
||||
local_agent['model'] = model_value.get('primary', '')
|
||||
else:
|
||||
continue
|
||||
|
||||
# Update using raw SQL with compatibility for both SQLite and PostgreSQL
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET config = :config::jsonb, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': json.dumps(config), 'for_version': current_version, 'uuid': uuid},
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET config = :config, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': json.dumps(config), 'for_version': current_version, 'uuid': uuid},
|
||||
)
|
||||
@@ -0,0 +1,49 @@
|
||||
from .. import migration
|
||||
|
||||
import sqlalchemy
|
||||
import json
|
||||
|
||||
|
||||
@migration.migration_class(24)
|
||||
class DBMigrateWecomBotWebSocketMode(migration.DBMigration):
|
||||
"""Add enable-webhook field to existing wecombot adapter configs.
|
||||
|
||||
Existing wecombot bots were all using webhook mode, so we set
|
||||
enable-webhook=true to preserve their behavior after the new
|
||||
WebSocket long connection mode is introduced as default.
|
||||
"""
|
||||
|
||||
async def upgrade(self):
|
||||
"""Upgrade"""
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text("SELECT uuid, adapter_config FROM bots WHERE adapter = 'wecombot'")
|
||||
)
|
||||
bots = result.fetchall()
|
||||
|
||||
for bot_row in bots:
|
||||
bot_uuid = bot_row[0]
|
||||
adapter_config = json.loads(bot_row[1]) if isinstance(bot_row[1], str) else bot_row[1]
|
||||
|
||||
if 'enable-webhook' in adapter_config:
|
||||
continue
|
||||
|
||||
# Determine mode based on existing config: if webhook fields are present, keep webhook mode
|
||||
has_webhook_config = bool(
|
||||
adapter_config.get('Token') and adapter_config.get('EncodingAESKey') and adapter_config.get('Corpid')
|
||||
)
|
||||
adapter_config['enable-webhook'] = has_webhook_config
|
||||
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('UPDATE bots SET adapter_config = :config::jsonb WHERE uuid = :uuid'),
|
||||
{'config': json.dumps(adapter_config), 'uuid': bot_uuid},
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('UPDATE bots SET adapter_config = :config WHERE uuid = :uuid'),
|
||||
{'config': json.dumps(adapter_config), 'uuid': bot_uuid},
|
||||
)
|
||||
|
||||
async def downgrade(self):
|
||||
"""Downgrade"""
|
||||
pass
|
||||
@@ -0,0 +1,15 @@
|
||||
import sqlalchemy
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class(25)
|
||||
class DBMigrateBotPipelineRoutingRules(migration.DBMigration):
|
||||
"""Add pipeline_routing_rules column to bots table"""
|
||||
|
||||
async def upgrade(self):
|
||||
sql_text = sqlalchemy.text("ALTER TABLE bots ADD COLUMN pipeline_routing_rules JSON NOT NULL DEFAULT '[]'")
|
||||
await self.ap.persistence_mgr.execute_async(sql_text)
|
||||
|
||||
async def downgrade(self):
|
||||
sql_text = sqlalchemy.text('ALTER TABLE bots DROP COLUMN pipeline_routing_rules')
|
||||
await self.ap.persistence_mgr.execute_async(sql_text)
|
||||
158
src/langbot/pkg/persistence/migrations/dbm026_workflow_tables.py
Normal file
158
src/langbot/pkg/persistence/migrations/dbm026_workflow_tables.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""Add workflow tables and update bot binding fields"""
|
||||
|
||||
import sqlalchemy
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class(26)
|
||||
class DBMigrateWorkflowTables(migration.DBMigration):
|
||||
"""Add workflow tables and update bot binding fields"""
|
||||
|
||||
async def upgrade(self):
|
||||
# Create workflows table
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text("""
|
||||
CREATE TABLE IF NOT EXISTS workflows (
|
||||
uuid VARCHAR(255) PRIMARY KEY,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
description TEXT,
|
||||
emoji VARCHAR(10) DEFAULT '🔄',
|
||||
version INTEGER NOT NULL DEFAULT 1,
|
||||
is_enabled BOOLEAN NOT NULL DEFAULT 1,
|
||||
definition JSON NOT NULL DEFAULT '{}',
|
||||
global_config JSON NOT NULL DEFAULT '{}',
|
||||
extensions_preferences JSON NOT NULL DEFAULT '{"enable_all_plugins": true, "enable_all_mcp_servers": true, "plugins": [], "mcp_servers": []}',
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
)
|
||||
|
||||
# Create workflow_versions table
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text("""
|
||||
CREATE TABLE IF NOT EXISTS workflow_versions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
workflow_uuid VARCHAR(255) NOT NULL,
|
||||
version INTEGER NOT NULL,
|
||||
definition JSON NOT NULL,
|
||||
global_config JSON NOT NULL DEFAULT '{}',
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
created_by VARCHAR(255),
|
||||
UNIQUE(workflow_uuid, version)
|
||||
)
|
||||
""")
|
||||
)
|
||||
|
||||
# Create workflow_triggers table
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text("""
|
||||
CREATE TABLE IF NOT EXISTS workflow_triggers (
|
||||
uuid VARCHAR(255) PRIMARY KEY,
|
||||
workflow_uuid VARCHAR(255) NOT NULL,
|
||||
type VARCHAR(50) NOT NULL,
|
||||
config JSON NOT NULL DEFAULT '{}',
|
||||
is_enabled BOOLEAN NOT NULL DEFAULT 1,
|
||||
priority INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
)
|
||||
|
||||
# Create workflow_executions table
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text("""
|
||||
CREATE TABLE IF NOT EXISTS workflow_executions (
|
||||
uuid VARCHAR(255) PRIMARY KEY,
|
||||
workflow_uuid VARCHAR(255) NOT NULL,
|
||||
workflow_version INTEGER NOT NULL,
|
||||
status VARCHAR(20) NOT NULL,
|
||||
trigger_type VARCHAR(50),
|
||||
trigger_data JSON,
|
||||
variables JSON,
|
||||
start_time TIMESTAMP,
|
||||
end_time TIMESTAMP,
|
||||
error TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
)
|
||||
|
||||
# Create workflow_node_executions table
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text("""
|
||||
CREATE TABLE IF NOT EXISTS workflow_node_executions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
execution_uuid VARCHAR(255) NOT NULL,
|
||||
node_id VARCHAR(100) NOT NULL,
|
||||
node_type VARCHAR(50) NOT NULL,
|
||||
status VARCHAR(20) NOT NULL,
|
||||
inputs JSON,
|
||||
outputs JSON,
|
||||
start_time TIMESTAMP,
|
||||
end_time TIMESTAMP,
|
||||
error TEXT,
|
||||
retry_count INTEGER NOT NULL DEFAULT 0
|
||||
)
|
||||
""")
|
||||
)
|
||||
|
||||
# Create workflow_scheduled_jobs table
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text("""
|
||||
CREATE TABLE IF NOT EXISTS workflow_scheduled_jobs (
|
||||
uuid VARCHAR(255) PRIMARY KEY,
|
||||
trigger_uuid VARCHAR(255) NOT NULL,
|
||||
cron_expression VARCHAR(100),
|
||||
next_run_time TIMESTAMP,
|
||||
last_run_time TIMESTAMP,
|
||||
is_enabled BOOLEAN NOT NULL DEFAULT 1
|
||||
)
|
||||
""")
|
||||
)
|
||||
|
||||
# Create indexes
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('CREATE INDEX IF NOT EXISTS idx_workflow_versions_uuid ON workflow_versions(workflow_uuid)')
|
||||
)
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('CREATE INDEX IF NOT EXISTS idx_workflow_triggers_uuid ON workflow_triggers(workflow_uuid)')
|
||||
)
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'CREATE INDEX IF NOT EXISTS idx_workflow_executions_uuid ON workflow_executions(workflow_uuid)'
|
||||
)
|
||||
)
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'CREATE INDEX IF NOT EXISTS idx_workflow_node_executions_uuid ON workflow_node_executions(execution_uuid)'
|
||||
)
|
||||
)
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'CREATE INDEX IF NOT EXISTS idx_workflow_scheduled_jobs_trigger ON workflow_scheduled_jobs(trigger_uuid)'
|
||||
)
|
||||
)
|
||||
|
||||
# Update bots table: add binding_type column (default to 'pipeline' for backward compatibility)
|
||||
# Check if column exists first (SQLite doesn't support IF NOT EXISTS for columns)
|
||||
try:
|
||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.text('SELECT binding_type FROM bots LIMIT 1'))
|
||||
except Exception:
|
||||
# Column doesn't exist, add it
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text("ALTER TABLE bots ADD COLUMN binding_type VARCHAR(20) NOT NULL DEFAULT 'pipeline'")
|
||||
)
|
||||
|
||||
async def downgrade(self):
|
||||
# Drop tables in reverse order
|
||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.text('DROP TABLE IF EXISTS workflow_scheduled_jobs'))
|
||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.text('DROP TABLE IF EXISTS workflow_node_executions'))
|
||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.text('DROP TABLE IF EXISTS workflow_executions'))
|
||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.text('DROP TABLE IF EXISTS workflow_triggers'))
|
||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.text('DROP TABLE IF EXISTS workflow_versions'))
|
||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.text('DROP TABLE IF EXISTS workflows'))
|
||||
|
||||
# Remove binding_type column from bots (SQLite doesn't support DROP COLUMN directly)
|
||||
# This would need a table recreation in SQLite, so we'll skip it in downgrade
|
||||
@@ -0,0 +1,49 @@
|
||||
"""Add binding_uuid field to bots table and migrate data"""
|
||||
|
||||
import sqlalchemy
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class(27)
|
||||
class DBMigrateBotBindingFields(migration.DBMigration):
|
||||
"""Add binding_uuid field to bots table and migrate existing data"""
|
||||
|
||||
async def upgrade(self):
|
||||
# Add binding_uuid column to bots table
|
||||
# Check if column exists first (SQLite doesn't support IF NOT EXISTS for columns)
|
||||
try:
|
||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.text('SELECT binding_uuid FROM bots LIMIT 1'))
|
||||
except Exception:
|
||||
# Column doesn't exist, add it
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('ALTER TABLE bots ADD COLUMN binding_uuid VARCHAR(64)')
|
||||
)
|
||||
|
||||
# Migrate existing data: copy use_pipeline_uuid to binding_uuid for records
|
||||
# that have a pipeline bound and binding_uuid is not set yet
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text("""
|
||||
UPDATE bots
|
||||
SET binding_uuid = use_pipeline_uuid
|
||||
WHERE use_pipeline_uuid IS NOT NULL
|
||||
AND use_pipeline_uuid != ''
|
||||
AND (binding_uuid IS NULL OR binding_uuid = '')
|
||||
""")
|
||||
)
|
||||
|
||||
# Ensure binding_type is 'pipeline' for records that were migrated
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text("""
|
||||
UPDATE bots
|
||||
SET binding_type = 'pipeline'
|
||||
WHERE binding_uuid IS NOT NULL
|
||||
AND binding_uuid != ''
|
||||
AND (binding_type IS NULL OR binding_type = '')
|
||||
""")
|
||||
)
|
||||
|
||||
async def downgrade(self):
|
||||
# SQLite doesn't support DROP COLUMN directly
|
||||
# This would need a table recreation in SQLite, so we'll skip it in downgrade
|
||||
# The column will remain but won't be used
|
||||
pass
|
||||
@@ -37,6 +37,7 @@ class PendingMessage:
|
||||
message_chain: platform_message.MessageChain
|
||||
adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter
|
||||
pipeline_uuid: typing.Optional[str]
|
||||
routed_by_rule: bool = False
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
@@ -125,6 +126,7 @@ class MessageAggregator:
|
||||
message_chain: platform_message.MessageChain,
|
||||
adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
|
||||
pipeline_uuid: typing.Optional[str] = None,
|
||||
routed_by_rule: bool = False,
|
||||
) -> None:
|
||||
"""Add a message to the aggregation buffer
|
||||
|
||||
@@ -145,6 +147,7 @@ class MessageAggregator:
|
||||
message_chain=message_chain,
|
||||
adapter=adapter,
|
||||
pipeline_uuid=pipeline_uuid,
|
||||
routed_by_rule=routed_by_rule,
|
||||
)
|
||||
return
|
||||
|
||||
@@ -159,6 +162,7 @@ class MessageAggregator:
|
||||
message_chain=message_chain,
|
||||
adapter=adapter,
|
||||
pipeline_uuid=pipeline_uuid,
|
||||
routed_by_rule=routed_by_rule,
|
||||
)
|
||||
|
||||
force_flush = False
|
||||
@@ -217,6 +221,7 @@ class MessageAggregator:
|
||||
message_chain=msg.message_chain,
|
||||
adapter=msg.adapter,
|
||||
pipeline_uuid=msg.pipeline_uuid,
|
||||
routed_by_rule=msg.routed_by_rule,
|
||||
)
|
||||
return
|
||||
|
||||
@@ -231,6 +236,7 @@ class MessageAggregator:
|
||||
message_chain=merged_msg.message_chain,
|
||||
adapter=merged_msg.adapter,
|
||||
pipeline_uuid=merged_msg.pipeline_uuid,
|
||||
routed_by_rule=merged_msg.routed_by_rule,
|
||||
)
|
||||
|
||||
def _merge_messages(self, messages: list[PendingMessage]) -> PendingMessage:
|
||||
@@ -269,6 +275,7 @@ class MessageAggregator:
|
||||
message_chain=merged_chain,
|
||||
adapter=base_msg.adapter,
|
||||
pipeline_uuid=base_msg.pipeline_uuid,
|
||||
routed_by_rule=any(msg.routed_by_rule for msg in messages),
|
||||
)
|
||||
|
||||
async def flush_all(self) -> None:
|
||||
|
||||
@@ -63,6 +63,14 @@ class Controller:
|
||||
pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(pipeline_uuid)
|
||||
if pipeline:
|
||||
await pipeline.run(selected_query)
|
||||
else:
|
||||
self.ap.logger.warning(
|
||||
f'Pipeline {pipeline_uuid} not found for query {selected_query.query_id}, query dropped'
|
||||
)
|
||||
else:
|
||||
self.ap.logger.warning(
|
||||
f'No pipeline_uuid for query {selected_query.query_id}, query dropped'
|
||||
)
|
||||
|
||||
async with self.ap.query_pool:
|
||||
(await self.ap.sess_mgr.get_session(selected_query))._semaphore.release()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user