Compare commits
2 Commits
fix/plugin
...
feat/human
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d3d366b569 | ||
|
|
db68c5d0c9 |
25
.github/workflows/check-i18n.yml
vendored
@@ -1,25 +0,0 @@
|
||||
name: Check i18n Keys
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- master
|
||||
|
||||
jobs:
|
||||
check-i18n:
|
||||
name: Check i18n Key Consistency
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
|
||||
- name: Check i18n keys against en-US reference
|
||||
run: node web/scripts/check-i18n.mjs
|
||||
171
.github/workflows/test-migrations.yml
vendored
@@ -1,171 +0,0 @@
|
||||
name: Test Migrations
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- master
|
||||
- dev
|
||||
paths:
|
||||
- 'src/langbot/pkg/persistence/**'
|
||||
- 'src/langbot/pkg/entity/persistence/**'
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
paths:
|
||||
- 'src/langbot/pkg/persistence/**'
|
||||
- 'src/langbot/pkg/entity/persistence/**'
|
||||
|
||||
jobs:
|
||||
test-migrations-sqlite:
|
||||
name: Migrations (SQLite)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --dev
|
||||
|
||||
- name: Test Alembic upgrade (SQLite)
|
||||
run: |
|
||||
uv run python -c "
|
||||
import asyncio
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from langbot.pkg.entity.persistence.base import Base
|
||||
from langbot.pkg.persistence.alembic_runner import run_alembic_upgrade, run_alembic_stamp, get_alembic_current
|
||||
|
||||
async def main():
|
||||
engine = create_async_engine('sqlite+aiosqlite:///test_migrations.db')
|
||||
|
||||
# Create all tables (simulates existing DB)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
# Stamp baseline
|
||||
await run_alembic_stamp(engine, '0001_baseline')
|
||||
rev = await get_alembic_current(engine)
|
||||
assert rev == '0001_baseline', f'Expected 0001_baseline, got {rev}'
|
||||
print(f'Stamped: {rev}')
|
||||
|
||||
# Upgrade to head
|
||||
await run_alembic_upgrade(engine, 'head')
|
||||
rev = await get_alembic_current(engine)
|
||||
print(f'After upgrade: {rev}')
|
||||
assert rev is not None, 'Expected a revision after upgrade'
|
||||
|
||||
# Verify idempotent
|
||||
await run_alembic_upgrade(engine, 'head')
|
||||
rev2 = await get_alembic_current(engine)
|
||||
assert rev2 == rev, f'Expected {rev}, got {rev2}'
|
||||
print(f'Idempotent check passed: {rev2}')
|
||||
|
||||
# Fresh DB: upgrade from scratch
|
||||
engine2 = create_async_engine('sqlite+aiosqlite:///test_migrations_fresh.db')
|
||||
async with engine2.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
await run_alembic_upgrade(engine2, 'head')
|
||||
rev3 = await get_alembic_current(engine2)
|
||||
print(f'Fresh DB upgrade: {rev3}')
|
||||
assert rev3 is not None
|
||||
|
||||
print('All SQLite migration tests passed!')
|
||||
|
||||
asyncio.run(main())
|
||||
"
|
||||
|
||||
test-migrations-postgres:
|
||||
name: Migrations (PostgreSQL)
|
||||
runs-on: ubuntu-latest
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:16
|
||||
env:
|
||||
POSTGRES_USER: langbot
|
||||
POSTGRES_PASSWORD: langbot
|
||||
POSTGRES_DB: langbot_test
|
||||
ports:
|
||||
- 5432:5432
|
||||
options: >-
|
||||
--health-cmd="pg_isready -U langbot"
|
||||
--health-interval=5s
|
||||
--health-timeout=5s
|
||||
--health-retries=5
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --dev
|
||||
|
||||
- name: Test Alembic upgrade (PostgreSQL)
|
||||
run: |
|
||||
uv run python -c "
|
||||
import asyncio
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from langbot.pkg.entity.persistence.base import Base
|
||||
from langbot.pkg.persistence.alembic_runner import run_alembic_upgrade, run_alembic_stamp, get_alembic_current
|
||||
|
||||
DB_URL = 'postgresql+asyncpg://langbot:langbot@localhost:5432/langbot_test'
|
||||
|
||||
async def main():
|
||||
engine = create_async_engine(DB_URL)
|
||||
|
||||
# Create all tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
# Stamp baseline
|
||||
await run_alembic_stamp(engine, '0001_baseline')
|
||||
rev = await get_alembic_current(engine)
|
||||
assert rev == '0001_baseline', f'Expected 0001_baseline, got {rev}'
|
||||
print(f'Stamped: {rev}')
|
||||
|
||||
# Upgrade to head
|
||||
await run_alembic_upgrade(engine, 'head')
|
||||
rev = await get_alembic_current(engine)
|
||||
print(f'After upgrade: {rev}')
|
||||
assert rev is not None
|
||||
|
||||
# Verify idempotent
|
||||
await run_alembic_upgrade(engine, 'head')
|
||||
rev2 = await get_alembic_current(engine)
|
||||
assert rev2 == rev, f'Expected {rev}, got {rev2}'
|
||||
print(f'Idempotent check passed: {rev2}')
|
||||
|
||||
# Fresh DB: drop all and upgrade from scratch
|
||||
engine2 = create_async_engine(DB_URL.replace('langbot_test', 'langbot_fresh'))
|
||||
|
||||
# Create fresh database
|
||||
from sqlalchemy import text
|
||||
async with engine.connect() as conn:
|
||||
await conn.execute(text('COMMIT'))
|
||||
await conn.execute(text('CREATE DATABASE langbot_fresh'))
|
||||
|
||||
async with engine2.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
await run_alembic_upgrade(engine2, 'head')
|
||||
rev3 = await get_alembic_current(engine2)
|
||||
print(f'Fresh DB upgrade: {rev3}')
|
||||
assert rev3 is not None
|
||||
|
||||
print('All PostgreSQL migration tests passed!')
|
||||
|
||||
asyncio.run(main())
|
||||
"
|
||||
1
.gitignore
vendored
@@ -47,7 +47,6 @@ plugins.bak
|
||||
coverage.xml
|
||||
.coverage
|
||||
src/langbot/web/
|
||||
testsdk/
|
||||
|
||||
# Build artifacts
|
||||
/dist
|
||||
|
||||
@@ -70,7 +70,7 @@ Plugin Runtime automatically starts each installed plugin and interacts through
|
||||
- type: must be a specific type, such as feat (new feature), fix (bug fix), docs (documentation), style (code style), refactor (refactoring), perf (performance optimization), etc.
|
||||
- scope: the scope of the commit, such as the package name, the file name, the function name, the class name, the module name, etc.
|
||||
- subject: the subject of the commit, such as the description of the commit, the reason for the commit, the impact of the commit, etc.
|
||||
- LangBot uses [Alembic](https://alembic.sqlalchemy.org/) to manage database migrations, supporting both SQLite and PostgreSQL. Migration files are located in `src/langbot/pkg/persistence/alembic/versions/`. If you changed the definition of database entities (ORM models), generate a new migration script by running `uv run python -m langbot.pkg.persistence.alembic_runner autogenerate "description of your change"` in the project root (requires `data/config.yaml` to exist). Review and edit the generated script before committing. Migrations are executed automatically on LangBot startup. For data migrations (e.g. modifying JSON field content), you need to manually add the migration code in the generated script.
|
||||
- If you changed the definition of database entities, please update the migration file in `src/langbot/pkg/persistence/migrations/` and update the constants.py file in `src/langbot/pkg/utils/constants.py` with the new migration number.
|
||||
|
||||
## Some Principles
|
||||
|
||||
|
||||
80
README.md
@@ -84,48 +84,45 @@ docker compose up -d
|
||||
|
||||
| Platform | Status | Notes |
|
||||
|----------|--------|-------|
|
||||
| Discord | ✅ | Official |
|
||||
| Telegram | ✅ | Official |
|
||||
| Slack | ✅ | Official |
|
||||
| LINE | ✅ | Official |
|
||||
| QQ | ✅ | Personal & Official API (Channel, DM, Group) |
|
||||
| Discord | ✅ | |
|
||||
| Telegram | ✅ | |
|
||||
| Slack | ✅ | |
|
||||
| LINE | ✅ | |
|
||||
| QQ | ✅ | Personal & Official API |
|
||||
| WeCom | ✅ | Enterprise WeChat, External CS, AI Bot |
|
||||
| WeChat | ✅ | Personal & Official Account |
|
||||
| Lark | ✅ | Official |
|
||||
| DingTalk | ✅ | Official |
|
||||
| KOOK | ✅ | Official |
|
||||
| Lark | ✅ | |
|
||||
| DingTalk | ✅ | |
|
||||
| KOOK | ✅ | |
|
||||
| Satori | ✅ | |
|
||||
| Email | ✅ | Matrix, Satori |
|
||||
| Matrix | ✅ | Supports multiple bridged platforms such as Signal, WhatsApp, Messenger, iMessage, Mattermost, Google Chat, IRC, XMPP, Zulip, and more |
|
||||
|
||||
---
|
||||
|
||||
## Supported LLMs & Integrations
|
||||
|
||||
| Provider | Type | Status |
|
||||
| ----------------------------------------------------------------------------------------------------------------- | ------------ | ------ |
|
||||
| [OpenAI](https://platform.openai.com/) | LLM | ✅ |
|
||||
| [Anthropic](https://www.anthropic.com/) | LLM | ✅ |
|
||||
| [DeepSeek](https://www.deepseek.com/) | LLM | ✅ |
|
||||
| [Google Gemini](https://aistudio.google.com/prompts/new_chat) | LLM | ✅ |
|
||||
| [xAI](https://x.ai/) | LLM | ✅ |
|
||||
| [Moonshot](https://www.moonshot.cn/) | LLM | ✅ |
|
||||
| [Zhipu AI](https://open.bigmodel.cn/) | LLM | ✅ |
|
||||
| [Ollama](https://ollama.com/) | Local LLM | ✅ |
|
||||
| [LM Studio](https://lmstudio.ai/) | Local LLM | ✅ |
|
||||
| [Dify](https://dify.ai) | LLMOps | ✅ |
|
||||
| [MCP](https://modelcontextprotocol.io/) | Protocol | ✅ |
|
||||
| [SiliconFlow](https://siliconflow.cn/) | Gateway | ✅ |
|
||||
| [Aliyun Bailian](https://bailian.console.aliyun.com/) | Gateway | ✅ |
|
||||
| [Volc Engine Ark](https://console.volcengine.com/ark/region:ark+cn-beijing/model?vendor=Bytedance&view=LIST_VIEW) | Gateway | ✅ |
|
||||
| [ModelScope](https://modelscope.cn/docs/model-service/API-Inference/intro) | Gateway | ✅ |
|
||||
| [GiteeAI](https://ai.gitee.com/) | Gateway | ✅ |
|
||||
| [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_langbot) | GPU Platform | ✅ |
|
||||
| [PPIO](https://ppinfra.com/user/register?invited_by=QJKFYD&utm_source=github_langbot) | GPU Platform | ✅ |
|
||||
| [ShengSuanYun](https://www.shengsuanyun.com/?from=CH_KYIPP758) | GPU Platform | ✅ |
|
||||
| [接口 AI](https://jiekou.ai/) | Gateway | ✅ |
|
||||
| [302.AI](https://share.302.ai/SuTG99) | Gateway | ✅ |
|
||||
| [Qiniu](https://www.qiniu.com/ai/agent) | Gateway | ✅ |
|
||||
| Provider | Type | Status |
|
||||
|----------|------|--------|
|
||||
| [OpenAI](https://platform.openai.com/) | LLM | ✅ |
|
||||
| [Anthropic](https://www.anthropic.com/) | LLM | ✅ |
|
||||
| [DeepSeek](https://www.deepseek.com/) | LLM | ✅ |
|
||||
| [Google Gemini](https://aistudio.google.com/prompts/new_chat) | LLM | ✅ |
|
||||
| [xAI](https://x.ai/) | LLM | ✅ |
|
||||
| [Moonshot](https://www.moonshot.cn/) | LLM | ✅ |
|
||||
| [Zhipu AI](https://open.bigmodel.cn/) | LLM | ✅ |
|
||||
| [Ollama](https://ollama.com/) | Local LLM | ✅ |
|
||||
| [LM Studio](https://lmstudio.ai/) | Local LLM | ✅ |
|
||||
| [Dify](https://dify.ai) | LLMOps | ✅ |
|
||||
| [MCP](https://modelcontextprotocol.io/) | Protocol | ✅ |
|
||||
| [SiliconFlow](https://siliconflow.cn/) | Gateway | ✅ |
|
||||
| [Aliyun Bailian](https://bailian.console.aliyun.com/) | Gateway | ✅ |
|
||||
| [Volc Engine Ark](https://console.volcengine.com/ark/region:ark+cn-beijing/model?vendor=Bytedance&view=LIST_VIEW) | Gateway | ✅ |
|
||||
| [ModelScope](https://modelscope.cn/docs/model-service/API-Inference/intro) | Gateway | ✅ |
|
||||
| [GiteeAI](https://ai.gitee.com/) | Gateway | ✅ |
|
||||
| [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_langbot) | GPU Platform | ✅ |
|
||||
| [PPIO](https://ppinfra.com/user/register?invited_by=QJKFYD&utm_source=github_langbot) | GPU Platform | ✅ |
|
||||
| [ShengSuanYun](https://www.shengsuanyun.com/?from=CH_KYIPP758) | GPU Platform | ✅ |
|
||||
| [接口 AI](https://jiekou.ai/) | Gateway | ✅ |
|
||||
| [302.AI](https://share.302.ai/SuTG99) | Gateway | ✅ |
|
||||
|
||||
[→ View all integrations](https://link.langbot.app/en/docs/features)
|
||||
|
||||
@@ -133,23 +130,22 @@ docker compose up -d
|
||||
|
||||
## Why LangBot?
|
||||
|
||||
| Use Case | How LangBot Helps |
|
||||
| --------------------------- | ------------------------------------------------------------------------------------------ |
|
||||
| **Customer Support** | Deploy AI agents to Slack/Discord/Telegram that answer questions using your knowledge base |
|
||||
| **Internal Tools** | Connect n8n/Dify workflows to WeCom/DingTalk for automated business processes |
|
||||
| **Community Management** | Moderate QQ/Discord groups with AI-powered content filtering and interaction |
|
||||
| **Multi-Platform Presence** | One bot, all platforms. Manage from a single dashboard |
|
||||
| Use Case | How LangBot Helps |
|
||||
|----------|-------------------|
|
||||
| **Customer Support** | Deploy AI agents to Slack/Discord/Telegram that answer questions using your knowledge base |
|
||||
| **Internal Tools** | Connect n8n/Dify workflows to WeCom/DingTalk for automated business processes |
|
||||
| **Community Management** | Moderate QQ/Discord groups with AI-powered content filtering and interaction |
|
||||
| **Multi-Platform Presence** | One bot, all platforms. Manage from a single dashboard |
|
||||
|
||||
---
|
||||
|
||||
## Live Demo
|
||||
|
||||
**Try it now:** https://demo.langbot.dev/
|
||||
|
||||
- Email: `demo@langbot.app`
|
||||
- Password: `langbot123456`
|
||||
|
||||
_Note: Public demo environment. Do not enter sensitive information._
|
||||
*Note: Public demo environment. Do not enter sensitive information.*
|
||||
|
||||
---
|
||||
|
||||
|
||||
18
README_CN.md
@@ -87,16 +87,13 @@ docker compose up -d
|
||||
| QQ | ✅ | 个人号、官方机器人(频道、私聊、群聊) |
|
||||
| 微信 | ✅ | 个人微信、微信公众号 |
|
||||
| 企业微信 | ✅ | 应用消息、对外客服、智能机器人 |
|
||||
| 飞书 | ✅ | 官方 |
|
||||
| 钉钉 | ✅ | 官方 |
|
||||
| Satori | ✅ | |
|
||||
| Discord | ✅ | 官方 |
|
||||
| Telegram | ✅ | 官方 |
|
||||
| Slack | ✅ | 官方 |
|
||||
| LINE | ✅ | 官方 |
|
||||
| KOOK | ✅ | 官方 |
|
||||
| Email | ✅ | 只 Matrix、Satori |
|
||||
| Matrix | ✅ | 支持多种桥接平台,如 Signal、WhatsApp、Messenger、iMessage、Mattermost、Google Chat、IRC、XMPP、Zulip 等 |
|
||||
| 飞书 | ✅ | |
|
||||
| 钉钉 | ✅ | |
|
||||
| Discord | ✅ | |
|
||||
| Telegram | ✅ | |
|
||||
| Slack | ✅ | |
|
||||
| LINE | ✅ | |
|
||||
| KOOK | ✅ | |
|
||||
|
||||
---
|
||||
|
||||
@@ -127,7 +124,6 @@ docker compose up -d
|
||||
| [302.AI](https://share.302.ai/SuTG99) | 聚合平台 | ✅ |
|
||||
| [小马算力](https://www.tokenpony.cn/453z1) | 聚合平台 | ✅ |
|
||||
| [百宝箱Tbox](https://www.tbox.cn/open) | 智能体平台 | ✅ |
|
||||
| [七牛云Qiniu](https://www.qiniu.com/ai/agent) | 聚合平台 | ✅ |
|
||||
|
||||
[→ 查看完整集成列表](https://link.langbot.app/zh/docs/features)
|
||||
|
||||
|
||||
19
README_ES.md
@@ -83,19 +83,17 @@ docker compose up -d
|
||||
|
||||
| Plataforma | Estado | Notas |
|
||||
|----------|--------|-------|
|
||||
| Discord | ✅ | Oficial |
|
||||
| Telegram | ✅ | Oficial |
|
||||
| Slack | ✅ | Oficial |
|
||||
| LINE | ✅ | Oficial |
|
||||
| QQ | ✅ | Personal y API Oficial (Canal, DM, Grupo) |
|
||||
| Discord | ✅ | |
|
||||
| Telegram | ✅ | |
|
||||
| Slack | ✅ | |
|
||||
| LINE | ✅ | |
|
||||
| QQ | ✅ | Personal y API Oficial |
|
||||
| WeCom | ✅ | WeChat Empresarial, CS Externo, AI Bot |
|
||||
| WeChat | ✅ | Personal y Cuenta Oficial |
|
||||
| Lark | ✅ | Oficial |
|
||||
| DingTalk | ✅ | Oficial |
|
||||
| KOOK | ✅ | Oficial |
|
||||
| Lark | ✅ | |
|
||||
| DingTalk | ✅ | |
|
||||
| KOOK | ✅ | |
|
||||
| Satori | ✅ | |
|
||||
| Email | ✅ | Matrix, Satori |
|
||||
| Matrix | ✅ | Admite varias plataformas puenteadas como Signal, WhatsApp, Messenger, iMessage, Mattermost, Google Chat, IRC, XMPP, Zulip y más |
|
||||
|
||||
---
|
||||
|
||||
@@ -124,7 +122,6 @@ docker compose up -d
|
||||
| [ShengSuanYun](https://www.shengsuanyun.com/?from=CH_KYIPP758) | Plataforma GPU | ✅ |
|
||||
| [接口 AI](https://jiekou.ai/) | Pasarela | ✅ |
|
||||
| [302.AI](https://share.302.ai/SuTG99) | Pasarela | ✅ |
|
||||
| [Qiniu](https://www.qiniu.com/ai/agent) | Pasarela | ✅ |
|
||||
|
||||
[→ Ver todas las integraciones](https://link.langbot.app/en/docs/features)
|
||||
|
||||
|
||||
19
README_FR.md
@@ -83,19 +83,17 @@ docker compose up -d
|
||||
|
||||
| Plateforme | Statut | Notes |
|
||||
|----------|--------|-------|
|
||||
| Discord | ✅ | Officiel |
|
||||
| Telegram | ✅ | Officiel |
|
||||
| Slack | ✅ | Officiel |
|
||||
| LINE | ✅ | Officiel |
|
||||
| QQ | ✅ | Personnel & API Officielle (Canal, DM, Groupe) |
|
||||
| Discord | ✅ | |
|
||||
| Telegram | ✅ | |
|
||||
| Slack | ✅ | |
|
||||
| LINE | ✅ | |
|
||||
| QQ | ✅ | Personnel & API Officielle |
|
||||
| WeCom | ✅ | WeChat Entreprise, CS Externe, AI Bot |
|
||||
| WeChat | ✅ | Personnel & Compte Officiel |
|
||||
| Lark | ✅ | Officiel |
|
||||
| DingTalk | ✅ | Officiel |
|
||||
| KOOK | ✅ | Officiel |
|
||||
| Lark | ✅ | |
|
||||
| DingTalk | ✅ | |
|
||||
| KOOK | ✅ | |
|
||||
| Satori | ✅ | |
|
||||
| Email | ✅ | Matrix, Satori |
|
||||
| Matrix | ✅ | Prend en charge plusieurs plateformes via ponts, comme Signal, WhatsApp, Messenger, iMessage, Mattermost, Google Chat, IRC, XMPP, Zulip, etc. |
|
||||
|
||||
---
|
||||
|
||||
@@ -124,7 +122,6 @@ docker compose up -d
|
||||
| [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_langbot) | Plateforme GPU | ✅ |
|
||||
| [PPIO](https://ppinfra.com/user/register?invited_by=QJKFYD&utm_source=github_langbot) | Plateforme GPU | ✅ |
|
||||
| [ShengSuanYun](https://www.shengsuanyun.com/?from=CH_KYIPP758) | Plateforme GPU | ✅ |
|
||||
| [Qiniu](https://www.qiniu.com/ai/agent) | Passerelle | ✅ |
|
||||
|
||||
[→ Voir toutes les intégrations](https://link.langbot.app/en/docs/features)
|
||||
|
||||
|
||||
21
README_JP.md
@@ -83,19 +83,17 @@ docker compose up -d
|
||||
|
||||
| プラットフォーム | ステータス | 備考 |
|
||||
|----------|--------|-------|
|
||||
| Discord | ✅ | 公式 |
|
||||
| Telegram | ✅ | 公式 |
|
||||
| Slack | ✅ | 公式 |
|
||||
| LINE | ✅ | 公式 |
|
||||
| QQ | ✅ | 個人・公式API(チャンネル・DM・グループ) |
|
||||
| Discord | ✅ | |
|
||||
| Telegram | ✅ | |
|
||||
| Slack | ✅ | |
|
||||
| LINE | ✅ | |
|
||||
| QQ | ✅ | 個人 & 公式API |
|
||||
| WeCom | ✅ | 企業WeChat、外部CS、AIボット |
|
||||
| WeChat | ✅ | 個人・公式アカウント |
|
||||
| Lark | ✅ | 公式 |
|
||||
| DingTalk | ✅ | 公式 |
|
||||
| KOOK | ✅ | 公式 |
|
||||
| WeChat | ✅ | 個人 & 公式アカウント |
|
||||
| Lark | ✅ | |
|
||||
| DingTalk | ✅ | |
|
||||
| KOOK | ✅ | |
|
||||
| Satori | ✅ | |
|
||||
| Email | ✅ | Matrix、Satori |
|
||||
| Matrix | ✅ | Signal、WhatsApp、Messenger、iMessage、Mattermost、Google Chat、IRC、XMPP、Zulip など複数のブリッジ先プラットフォームに対応 |
|
||||
|
||||
---
|
||||
|
||||
@@ -124,7 +122,6 @@ docker compose up -d
|
||||
| [ShengSuanYun](https://www.shengsuanyun.com/?from=CH_KYIPP758) | GPUプラットフォーム | ✅ |
|
||||
| [接口 AI](https://jiekou.ai/) | ゲートウェイ | ✅ |
|
||||
| [302.AI](https://share.302.ai/SuTG99) | ゲートウェイ | ✅ |
|
||||
| [Qiniu](https://www.qiniu.com/ai/agent) | ゲートウェイ | ✅ |
|
||||
|
||||
[→ すべての統合を表示](https://link.langbot.app/en/docs/features)
|
||||
|
||||
|
||||
19
README_KO.md
@@ -83,19 +83,17 @@ docker compose up -d
|
||||
|
||||
| 플랫폼 | 상태 | 비고 |
|
||||
|--------|------|------|
|
||||
| Discord | ✅ | 공식 |
|
||||
| Telegram | ✅ | 공식 |
|
||||
| Slack | ✅ | 공식 |
|
||||
| LINE | ✅ | 공식 |
|
||||
| QQ | ✅ | 개인 및 공식 API (채널, DM, 그룹) |
|
||||
| Discord | ✅ | |
|
||||
| Telegram | ✅ | |
|
||||
| Slack | ✅ | |
|
||||
| LINE | ✅ | |
|
||||
| QQ | ✅ | 개인 및 공식 API |
|
||||
| WeCom | ✅ | 기업 WeChat, 외부 CS, AI Bot |
|
||||
| WeChat | ✅ | 개인 및 공식 계정 |
|
||||
| Lark | ✅ | 공식 |
|
||||
| DingTalk | ✅ | 공식 |
|
||||
| KOOK | ✅ | 공식 |
|
||||
| Lark | ✅ | |
|
||||
| DingTalk | ✅ | |
|
||||
| KOOK | ✅ | |
|
||||
| Satori | ✅ | |
|
||||
| Email | ✅ | Matrix, Satori |
|
||||
| Matrix | ✅ | Signal, WhatsApp, Messenger, iMessage, Mattermost, Google Chat, IRC, XMPP, Zulip 등 여러 브리지 플랫폼 지원 |
|
||||
|
||||
---
|
||||
|
||||
@@ -124,7 +122,6 @@ docker compose up -d
|
||||
| [ShengSuanYun](https://www.shengsuanyun.com/?from=CH_KYIPP758) | GPU 플랫폼 | ✅ |
|
||||
| [接口 AI](https://jiekou.ai/) | 게이트웨이 | ✅ |
|
||||
| [302.AI](https://share.302.ai/SuTG99) | 게이트웨이 | ✅ |
|
||||
| [Qiniu](https://www.qiniu.com/ai/agent) | 게이트웨이 | ✅ |
|
||||
|
||||
[→ 모든 통합 보기](https://link.langbot.app/en/docs/features)
|
||||
|
||||
|
||||
19
README_RU.md
@@ -83,19 +83,17 @@ docker compose up -d
|
||||
|
||||
| Платформа | Статус | Примечания |
|
||||
|-----------|--------|------------|
|
||||
| Discord | ✅ | Официальный |
|
||||
| Telegram | ✅ | Официальный |
|
||||
| Slack | ✅ | Официальный |
|
||||
| LINE | ✅ | Официальный |
|
||||
| QQ | ✅ | Личный и официальный API (Канал, ЛС, Группа) |
|
||||
| Discord | ✅ | |
|
||||
| Telegram | ✅ | |
|
||||
| Slack | ✅ | |
|
||||
| LINE | ✅ | |
|
||||
| QQ | ✅ | Личный и официальный API |
|
||||
| WeCom | ✅ | Корпоративный WeChat, внешний CS, AI-бот |
|
||||
| WeChat | ✅ | Личный и официальный аккаунт |
|
||||
| Lark | ✅ | Официальный |
|
||||
| DingTalk | ✅ | Официальный |
|
||||
| KOOK | ✅ | Официальный |
|
||||
| Lark | ✅ | |
|
||||
| DingTalk | ✅ | |
|
||||
| KOOK | ✅ | |
|
||||
| Satori | ✅ | |
|
||||
| Email | ✅ | Matrix, Satori |
|
||||
| Matrix | ✅ | Поддерживает несколько платформ через мосты, включая Signal, WhatsApp, Messenger, iMessage, Mattermost, Google Chat, IRC, XMPP, Zulip и другие |
|
||||
|
||||
---
|
||||
|
||||
@@ -124,7 +122,6 @@ docker compose up -d
|
||||
| [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_langbot) | Платформа GPU | ✅ |
|
||||
| [PPIO](https://ppinfra.com/user/register?invited_by=QJKFYD&utm_source=github_langbot) | Платформа GPU | ✅ |
|
||||
| [ShengSuanYun](https://www.shengsuanyun.com/?from=CH_KYIPP758) | Платформа GPU | ✅ |
|
||||
| [Qiniu](https://www.qiniu.com/ai/agent) | Шлюз | ✅ |
|
||||
|
||||
[→ Смотреть все интеграции](https://link.langbot.app/en/docs/features)
|
||||
|
||||
|
||||
19
README_TW.md
@@ -85,19 +85,17 @@ docker compose up -d
|
||||
|
||||
| 平台 | 狀態 | 備註 |
|
||||
|------|------|------|
|
||||
| Discord | ✅ | 官方 |
|
||||
| Telegram | ✅ | 官方 |
|
||||
| Slack | ✅ | 官方 |
|
||||
| LINE | ✅ | 官方 |
|
||||
| QQ | ✅ | 個人號、官方機器人(頻道、私聊、群聊) |
|
||||
| 企業微信 | ✅ | 應用訊息、對外客服、智能機器人 |
|
||||
| 微信 | ✅ | 個人微信、微信公眾號 |
|
||||
| 飛書 | ✅ | 官方 |
|
||||
| 釘釘 | ✅ | 官方 |
|
||||
| KOOK | ✅ | 官方 |
|
||||
| 企業微信 | ✅ | 應用訊息、對外客服、智能機器人 |
|
||||
| 飛書 | ✅ | |
|
||||
| 釘釘 | ✅ | |
|
||||
| Discord | ✅ | |
|
||||
| Telegram | ✅ | |
|
||||
| Slack | ✅ | |
|
||||
| LINE | ✅ | |
|
||||
| KOOK | ✅ | |
|
||||
| Satori | ✅ | |
|
||||
| Email | ✅ | 只 Matrix、Satori |
|
||||
| Matrix | ✅ | 支援多種橋接平台,如 Signal、WhatsApp、Messenger、iMessage、Mattermost、Google Chat、IRC、XMPP、Zulip 等 |
|
||||
|
||||
---
|
||||
|
||||
@@ -126,7 +124,6 @@ docker compose up -d
|
||||
| [PPIO](https://ppinfra.com/user/register?invited_by=QJKFYD&utm_source=github_langbot) | GPU 平台 | ✅ |
|
||||
| [接口 AI](https://jiekou.ai/) | 聚合平台 | ✅ |
|
||||
| [302.AI](https://share.302.ai/SuTG99) | 聚合平台 | ✅ |
|
||||
| [Qiniu](https://www.qiniu.com/ai/agent) | 聚合平台 | ✅ |
|
||||
|
||||
### TTS(語音合成)
|
||||
|
||||
|
||||
19
README_VI.md
@@ -83,19 +83,17 @@ docker compose up -d
|
||||
|
||||
| Nền tảng | Trạng thái | Ghi chú |
|
||||
|----------|--------|-------|
|
||||
| Discord | ✅ | Chính thức |
|
||||
| Telegram | ✅ | Chính thức |
|
||||
| Slack | ✅ | Chính thức |
|
||||
| LINE | ✅ | Chính thức |
|
||||
| QQ | ✅ | Cá nhân & API chính thức (Kênh, DM, Nhóm) |
|
||||
| Discord | ✅ | |
|
||||
| Telegram | ✅ | |
|
||||
| Slack | ✅ | |
|
||||
| LINE | ✅ | |
|
||||
| QQ | ✅ | Cá nhân & API chính thức |
|
||||
| WeCom | ✅ | WeChat doanh nghiệp, CS bên ngoài, AI Bot |
|
||||
| WeChat | ✅ | Cá nhân & Tài khoản công khai |
|
||||
| Lark | ✅ | Chính thức |
|
||||
| DingTalk | ✅ | Chính thức |
|
||||
| KOOK | ✅ | Chính thức |
|
||||
| Lark | ✅ | |
|
||||
| DingTalk | ✅ | |
|
||||
| KOOK | ✅ | |
|
||||
| Satori | ✅ | |
|
||||
| Email | ✅ | Matrix, Satori |
|
||||
| Matrix | ✅ | Hỗ trợ nhiều nền tảng qua bridge như Signal, WhatsApp, Messenger, iMessage, Mattermost, Google Chat, IRC, XMPP, Zulip và hơn thế nữa |
|
||||
|
||||
---
|
||||
|
||||
@@ -124,7 +122,6 @@ docker compose up -d
|
||||
| [ShengSuanYun](https://www.shengsuanyun.com/?from=CH_KYIPP758) | Nền tảng GPU | ✅ |
|
||||
| [接口 AI](https://jiekou.ai/) | Cổng | ✅ |
|
||||
| [302.AI](https://share.302.ai/SuTG99) | Cổng | ✅ |
|
||||
| [Qiniu](https://www.qiniu.com/ai/agent) | Cổng | ✅ |
|
||||
|
||||
[→ Xem tất cả tích hợp](https://link.langbot.app/en/docs/features)
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "langbot"
|
||||
version = "4.9.7"
|
||||
version = "4.9.5"
|
||||
description = "Production-grade platform for building agentic IM bots"
|
||||
readme = "README.md"
|
||||
license-files = ["LICENSE"]
|
||||
@@ -8,7 +8,7 @@ requires-python = ">=3.11,<4.0"
|
||||
dependencies = [
|
||||
"aiocqhttp>=1.4.4",
|
||||
"aiofiles>=24.1.0",
|
||||
"aiohttp>=3.13.4",
|
||||
"aiohttp>=3.11.18",
|
||||
"aioshutil>=1.5",
|
||||
"aiosqlite>=0.21.0",
|
||||
"anthropic>=0.51.0",
|
||||
@@ -16,18 +16,18 @@ dependencies = [
|
||||
"async-lru>=2.0.5",
|
||||
"certifi>=2025.4.26",
|
||||
"colorlog~=6.6.0",
|
||||
"cryptography>=46.0.7",
|
||||
"cryptography>=44.0.3",
|
||||
"dashscope>=1.25.10",
|
||||
"dingtalk-stream>=0.24.0",
|
||||
"discord-py>=2.5.2",
|
||||
"pynacl>=1.5.0", # Required for Discord voice support
|
||||
"gewechat-client>=0.1.5",
|
||||
"lark-oapi>=1.5.5",
|
||||
"lark-oapi>=1.4.15",
|
||||
"mcp>=1.25.0",
|
||||
"nakuru-project-idk>=0.0.2.1",
|
||||
"ollama>=0.4.8",
|
||||
"openai>1.0.0",
|
||||
"pillow>=12.2.0",
|
||||
"pillow>=11.2.1",
|
||||
"psutil>=7.0.0",
|
||||
"pycryptodome>=3.22.0",
|
||||
"pydantic>2.0",
|
||||
@@ -35,12 +35,10 @@ dependencies = [
|
||||
"python-telegram-bot>=22.0",
|
||||
"pyyaml>=6.0.2",
|
||||
"qq-botpy-rc>=1.2.1.6",
|
||||
"qrcode>=7.4",
|
||||
"quart>=0.20.0",
|
||||
"quart-cors>=0.8.0",
|
||||
"requests>=2.32.3",
|
||||
"slack-sdk>=3.35.0",
|
||||
"alembic>=1.15.0",
|
||||
"sqlalchemy[asyncio]>=2.0.40",
|
||||
"sqlmodel>=0.0.24",
|
||||
"telegramify-markdown>=0.5.1",
|
||||
@@ -51,7 +49,7 @@ dependencies = [
|
||||
"pip>=25.1.1",
|
||||
"ruff>=0.11.9",
|
||||
"pre-commit>=4.2.0",
|
||||
"uv>=0.11.6",
|
||||
"uv>=0.7.11",
|
||||
"mypy>=1.16.0",
|
||||
"PyPDF2>=3.0.1",
|
||||
"python-docx>=1.1.0",
|
||||
@@ -62,18 +60,13 @@ dependencies = [
|
||||
"ebooklib>=0.18",
|
||||
"html2text>=2024.2.26",
|
||||
"langchain>=0.2.0",
|
||||
"langchain-core>=1.2.28",
|
||||
"langsmith>=0.7.31",
|
||||
"python-multipart>=0.0.26",
|
||||
"Mako>=1.3.11",
|
||||
"langchain-text-splitters>=1.1.2",
|
||||
"langchain-text-splitters>=0.0.1",
|
||||
"chromadb>=1.0.0,<2.0.0",
|
||||
"qdrant-client (>=1.15.1,<2.0.0)",
|
||||
"pyseekdb==1.1.0.post3",
|
||||
"langbot-plugin==0.3.11",
|
||||
"langbot-plugin==0.3.7",
|
||||
"asyncpg>=0.30.0",
|
||||
"line-bot-sdk>=3.19.0",
|
||||
"matrix-nio>=0.25.2",
|
||||
"tboxsdk>=0.0.10",
|
||||
"boto3>=1.35.0",
|
||||
"pymilvus>=2.6.4",
|
||||
@@ -118,12 +111,12 @@ requires = ["setuptools>=61.0", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.setuptools]
|
||||
package-data = { "langbot" = ["templates/**", "pkg/provider/modelmgr/requesters/*", "pkg/platform/sources/*", "web/dist/**", "pkg/persistence/alembic/**"] }
|
||||
package-data = { "langbot" = ["templates/**", "pkg/provider/modelmgr/requesters/*", "pkg/platform/sources/*", "web/dist/**"] }
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pre-commit>=4.2.0",
|
||||
"pytest>=9.0.3",
|
||||
"pytest>=8.4.1",
|
||||
"pytest-asyncio>=1.0.0",
|
||||
"pytest-cov>=7.0.0",
|
||||
"ruff>=0.11.9",
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
"""LangBot - Production-grade platform for building agentic IM bots"""
|
||||
|
||||
__version__ = '4.9.7'
|
||||
__version__ = '4.9.5'
|
||||
|
||||
@@ -182,88 +182,6 @@ class DingTalkClient:
|
||||
for handler in self._message_handlers[msg_type]:
|
||||
await handler(event)
|
||||
|
||||
async def _parse_quoted_message(self, replied_msg: dict) -> dict:
|
||||
"""Parse the quoted/replied message and extract its content.
|
||||
|
||||
Args:
|
||||
replied_msg: The repliedMsg object from DingTalk message
|
||||
|
||||
Returns:
|
||||
A dict containing the quoted message info with keys:
|
||||
- message_id: The original message ID
|
||||
- msg_type: The message type (text, file, picture, audio, etc.)
|
||||
- content: The text content (if any)
|
||||
- file_url: The file download URL (if file type)
|
||||
- file_name: The file name (if file type)
|
||||
- picture: The picture base64 (if picture type)
|
||||
- audio: The audio base64 (if audio type)
|
||||
"""
|
||||
quote_info = {
|
||||
'message_id': replied_msg.get('msgId', ''),
|
||||
'msg_type': replied_msg.get('msgType', ''),
|
||||
'sender_id': replied_msg.get('senderId', ''),
|
||||
}
|
||||
|
||||
msg_type = replied_msg.get('msgType', '')
|
||||
content = replied_msg.get('content', {})
|
||||
|
||||
# Handle content as string (JSON) or dict
|
||||
if isinstance(content, str):
|
||||
try:
|
||||
content = json.loads(content)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
content = {}
|
||||
|
||||
if msg_type == 'text':
|
||||
# Text message
|
||||
if isinstance(content, dict):
|
||||
quote_info['content'] = content.get('content', '')
|
||||
else:
|
||||
quote_info['content'] = str(content)
|
||||
|
||||
elif msg_type == 'file':
|
||||
# File message
|
||||
download_code = content.get('downloadCode')
|
||||
file_name = content.get('fileName')
|
||||
if download_code and file_name:
|
||||
try:
|
||||
quote_info['file_url'] = await self.get_file_url(download_code)
|
||||
quote_info['file_name'] = file_name
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
await self.logger.error(f'Failed to get quoted file URL: {e}')
|
||||
|
||||
elif msg_type == 'picture':
|
||||
# Picture message
|
||||
download_code = content.get('downloadCode')
|
||||
if download_code:
|
||||
try:
|
||||
quote_info['picture'] = await self.download_image(download_code)
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
await self.logger.error(f'Failed to download quoted image: {e}')
|
||||
|
||||
elif msg_type == 'audio':
|
||||
# Audio message
|
||||
download_code = content.get('downloadCode')
|
||||
if download_code:
|
||||
try:
|
||||
quote_info['audio'] = await self.get_audio_url(download_code)
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
await self.logger.error(f'Failed to get quoted audio: {e}')
|
||||
|
||||
elif msg_type == 'richText':
|
||||
# Rich text message - extract text content
|
||||
rich_text = content.get('richText', [])
|
||||
texts = []
|
||||
for item in rich_text:
|
||||
if 'text' in item and item['text'] != '\n':
|
||||
texts.append(item['text'])
|
||||
quote_info['content'] = '\n'.join(texts)
|
||||
|
||||
return quote_info
|
||||
|
||||
async def get_message(self, incoming_message: dingtalk_stream.chatbot.ChatbotMessage):
|
||||
try:
|
||||
# print(json.dumps(incoming_message.to_dict(), indent=4, ensure_ascii=False))
|
||||
@@ -275,15 +193,6 @@ class DingTalkClient:
|
||||
elif str(incoming_message.conversation_type) == '2':
|
||||
message_data['conversation_type'] = 'GroupMessage'
|
||||
|
||||
# Check for quoted/replied message
|
||||
raw_data = incoming_message.to_dict()
|
||||
text_data = raw_data.get('text', {})
|
||||
if isinstance(text_data, dict) and text_data.get('isReplyMsg'):
|
||||
replied_msg = text_data.get('repliedMsg', {})
|
||||
if replied_msg:
|
||||
quote_info = await self._parse_quoted_message(replied_msg)
|
||||
message_data['QuotedMessage'] = quote_info
|
||||
|
||||
if incoming_message.message_type == 'richText':
|
||||
data = incoming_message.rich_text_content.to_dict()
|
||||
|
||||
@@ -359,25 +268,7 @@ class DingTalkClient:
|
||||
|
||||
message_data['Type'] = 'image'
|
||||
elif incoming_message.message_type == 'audio':
|
||||
raw_content = incoming_message.to_dict().get('content', {})
|
||||
# 兼容处理:如果 content 仍为 JSON 字符串则进行解析
|
||||
if isinstance(raw_content, str):
|
||||
try:
|
||||
raw_content = json.loads(raw_content)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
raw_content = {}
|
||||
|
||||
if self.logger:
|
||||
await self.logger.info(f'DingTalk audio raw content: {json.dumps(raw_content, ensure_ascii=False)}')
|
||||
|
||||
# 提取钉钉自带的语音转写文字(Powered by Qwen)
|
||||
recognition = raw_content.get('recognition', '')
|
||||
if recognition:
|
||||
message_data['Content'] = recognition
|
||||
|
||||
download_code = raw_content.get('downloadCode')
|
||||
if download_code:
|
||||
message_data['Audio'] = await self.get_audio_url(download_code)
|
||||
message_data['Audio'] = await self.get_audio_url(incoming_message.to_dict()['content']['downloadCode'])
|
||||
|
||||
message_data['Type'] = 'audio'
|
||||
elif incoming_message.message_type == 'file':
|
||||
@@ -481,12 +372,6 @@ class DingTalkClient:
|
||||
card_data['config'] = json.dumps({'autoLayout': card_auto_layout})
|
||||
card_data['content'] = ''
|
||||
|
||||
# 将用户的消息内容作为卡片的查询参数,方便后续处理
|
||||
if incoming_message.message_type == 'text':
|
||||
card_data['query'] = incoming_message.get_text_list()[0]
|
||||
else:
|
||||
card_data['query'] = '...'
|
||||
|
||||
card_instance = dingtalk_stream.AICardReplier(self.client, incoming_message)
|
||||
# print(card_instance)
|
||||
# 先投放卡片: https://open.dingtalk.com/document/orgapp/create-and-deliver-cards
|
||||
|
||||
@@ -47,22 +47,6 @@ class DingTalkEvent(dict):
|
||||
def conversation(self):
|
||||
return self.get('conversation_type', '')
|
||||
|
||||
@property
|
||||
def quoted_message(self) -> Optional[Dict[str, Any]]:
|
||||
"""Get the quoted/replied message info if this is a reply message.
|
||||
|
||||
Returns:
|
||||
A dict containing:
|
||||
- message_id: The original message ID
|
||||
- msg_type: The message type (text, file, picture, audio, etc.)
|
||||
- content: The text content (if any)
|
||||
- file_url: The file download URL (if file type)
|
||||
- file_name: The file name (if file type)
|
||||
- picture: The picture base64 (if picture type)
|
||||
- audio: The audio base64 (if audio type)
|
||||
"""
|
||||
return self.get('QuotedMessage')
|
||||
|
||||
def __getattr__(self, key: str) -> Optional[Any]:
|
||||
"""
|
||||
允许通过属性访问数据中的任意字段。
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import re
|
||||
import time
|
||||
import asyncio
|
||||
from quart import request
|
||||
import httpx
|
||||
from quart import Quart
|
||||
from typing import Callable, Dict, Any, Optional
|
||||
from typing import Callable, Dict, Any
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
from .qqofficialevent import QQOfficialEvent
|
||||
import json
|
||||
@@ -34,8 +32,6 @@ class QQOfficialClient:
|
||||
self.access_token = ''
|
||||
self.access_token_expiry_time = None
|
||||
self.logger = logger
|
||||
self._msg_seq_counter = 0
|
||||
self._token_refresh_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def check_access_token(self):
|
||||
"""检查access_token是否存在"""
|
||||
@@ -54,18 +50,18 @@ class QQOfficialClient:
|
||||
headers = {
|
||||
'content-type': 'application/json',
|
||||
}
|
||||
response = await client.post(url, json=params, headers=headers)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f'Failed to get access_token: HTTP {response.status_code} {response.text}')
|
||||
response_data = response.json()
|
||||
access_token = response_data.get('access_token')
|
||||
expires_in = int(response_data.get('expires_in', 7200))
|
||||
self.access_token_expiry_time = time.time() + expires_in - 60
|
||||
if access_token:
|
||||
self.access_token = access_token
|
||||
await self.logger.info(f'access_token obtained, expires_in={expires_in}s')
|
||||
else:
|
||||
raise Exception('Failed to get access_token: no access_token in response')
|
||||
try:
|
||||
response = await client.post(url, json=params, headers=headers)
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
access_token = response_data.get('access_token')
|
||||
expires_in = int(response_data.get('expires_in', 7200))
|
||||
self.access_token_expiry_time = time.time() + expires_in - 60
|
||||
if access_token:
|
||||
self.access_token = access_token
|
||||
except Exception as e:
|
||||
await self.logger.error(f'获取access_token失败: {response_data}')
|
||||
raise Exception(f'获取access_token失败: {e}')
|
||||
|
||||
async def handle_callback_request(self):
|
||||
"""处理回调请求(独立端口模式,使用全局 request)"""
|
||||
@@ -91,10 +87,10 @@ class QQOfficialClient:
|
||||
try:
|
||||
body = await req.get_data()
|
||||
|
||||
await self.logger.info(f'Received request, body length: {len(body)}')
|
||||
print(f'[QQ Official] Received request, body length: {len(body)}')
|
||||
|
||||
if not body or len(body) == 0:
|
||||
await self.logger.info('Received empty body, might be health check or GET request')
|
||||
print('[QQ Official] Received empty body, might be health check or GET request')
|
||||
return {'code': 0, 'message': 'ok'}, 200
|
||||
|
||||
payload = json.loads(body)
|
||||
@@ -115,6 +111,7 @@ class QQOfficialClient:
|
||||
return {'code': 0, 'message': 'success'}
|
||||
|
||||
except Exception as e:
|
||||
print(f'[QQ Official] ERROR: {traceback.format_exc()}')
|
||||
await self.logger.error(f'Error in handle_callback_request: {traceback.format_exc()}')
|
||||
return {'error': str(e)}, 400
|
||||
|
||||
@@ -142,24 +139,21 @@ class QQOfficialClient:
|
||||
|
||||
async def get_message(self, msg: dict) -> Dict[str, Any]:
|
||||
"""获取消息"""
|
||||
d = msg.get('d', {})
|
||||
if not isinstance(d, dict):
|
||||
return {}
|
||||
message_data = {
|
||||
't': msg.get('t', {}),
|
||||
'user_openid': d.get('author', {}).get('user_openid', {}),
|
||||
'timestamp': d.get('timestamp', {}),
|
||||
'd_author_id': d.get('author', {}).get('id', {}),
|
||||
'content': d.get('content', {}),
|
||||
'd_id': d.get('id', {}),
|
||||
'user_openid': msg.get('d', {}).get('author', {}).get('user_openid', {}),
|
||||
'timestamp': msg.get('d', {}).get('timestamp', {}),
|
||||
'd_author_id': msg.get('d', {}).get('author', {}).get('id', {}),
|
||||
'content': msg.get('d', {}).get('content', {}),
|
||||
'd_id': msg.get('d', {}).get('id', {}),
|
||||
'id': msg.get('id', {}),
|
||||
'channel_id': d.get('channel_id', {}),
|
||||
'username': d.get('author', {}).get('username', {}),
|
||||
'guild_id': d.get('guild_id', {}),
|
||||
'member_openid': d.get('author', {}).get('openid', {}),
|
||||
'group_openid': d.get('group_openid', {}),
|
||||
'channel_id': msg.get('d', {}).get('channel_id', {}),
|
||||
'username': msg.get('d', {}).get('author', {}).get('username', {}),
|
||||
'guild_id': msg.get('d', {}).get('guild_id', {}),
|
||||
'member_openid': msg.get('d', {}).get('author', {}).get('openid', {}),
|
||||
'group_openid': msg.get('d', {}).get('group_openid', {}),
|
||||
}
|
||||
attachments = d.get('attachments', [])
|
||||
attachments = msg.get('d', {}).get('attachments', [])
|
||||
image_attachments = [attachment['url'] for attachment in attachments if await self.is_image(attachment)]
|
||||
image_attachments_type = [
|
||||
attachment['content_type'] for attachment in attachments if await self.is_image(attachment)
|
||||
@@ -198,7 +192,7 @@ class QQOfficialClient:
|
||||
if response.status_code == 200:
|
||||
return
|
||||
else:
|
||||
await self.logger.error(f'Failed to send private message: {response_data}')
|
||||
await self.logger.error(f'发送私聊消息失败: {response_data}')
|
||||
raise ValueError(response)
|
||||
|
||||
async def send_group_text_msg(self, group_openid: str, content: str, msg_id: str):
|
||||
@@ -221,7 +215,7 @@ class QQOfficialClient:
|
||||
if response.status_code == 200:
|
||||
return
|
||||
else:
|
||||
await self.logger.error(f'Failed to send group message: {response.json()}')
|
||||
await self.logger.error(f'发送群聊消息失败:{response.json()}')
|
||||
raise Exception(response.read().decode())
|
||||
|
||||
async def send_channle_group_text_msg(self, channel_id: str, content: str, msg_id: str):
|
||||
@@ -244,7 +238,7 @@ class QQOfficialClient:
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
await self.logger.error(f'Failed to send channel group message: {response.json()}')
|
||||
await self.logger.error(f'发送频道群聊消息失败: {response.json()}')
|
||||
raise Exception(response)
|
||||
|
||||
async def send_channle_private_text_msg(self, guild_id: str, content: str, msg_id: str):
|
||||
@@ -267,224 +261,9 @@ class QQOfficialClient:
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
await self.logger.error(f'Failed to send channel private message: {response.json()}')
|
||||
await self.logger.error(f'发送频道私聊消息失败: {response.json()}')
|
||||
raise Exception(response)
|
||||
|
||||
# ---- 富媒体消息 ----
|
||||
|
||||
# 媒体文件类型
|
||||
MEDIA_TYPE_IMAGE = 1
|
||||
MEDIA_TYPE_VIDEO = 2
|
||||
MEDIA_TYPE_VOICE = 3
|
||||
MEDIA_TYPE_FILE = 4
|
||||
|
||||
async def upload_media(
|
||||
self,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
file_type: int,
|
||||
file_url: str = None,
|
||||
file_data: str = None,
|
||||
file_name: str = None,
|
||||
) -> str:
|
||||
"""上传媒体文件,返回 file_info。
|
||||
|
||||
Args:
|
||||
target_type: 'c2c' | 'group'
|
||||
target_id: 用户 openid 或群 openid
|
||||
file_type: 1=图片, 2=视频, 3=语音, 4=文件
|
||||
file_url: 在线 URL(与 file_data 二选一)
|
||||
file_data: base64 编码的文件数据或 data URL(与 file_url 二选一)
|
||||
file_name: 文件名(file_type=4 时必填)
|
||||
"""
|
||||
if not await self.check_access_token():
|
||||
await self.get_access_token()
|
||||
|
||||
if target_type == 'c2c':
|
||||
url = f'{self.base_url}/v2/users/{target_id}/files'
|
||||
elif target_type == 'group':
|
||||
url = f'{self.base_url}/v2/groups/{target_id}/files'
|
||||
else:
|
||||
raise ValueError(f'Unsupported target_type: {target_type}')
|
||||
|
||||
body = {
|
||||
'file_type': file_type,
|
||||
'srv_send_msg': False,
|
||||
}
|
||||
if file_url:
|
||||
body['url'] = file_url
|
||||
elif file_data:
|
||||
# 处理 data URL 格式: data:image/png;base64,xxxxx
|
||||
if file_data.startswith('data:'):
|
||||
match = re.match(r'^data:[^;]+;base64,(.+)$', file_data, re.DOTALL)
|
||||
if match:
|
||||
body['file_data'] = match.group(1)
|
||||
else:
|
||||
body['file_data'] = file_data
|
||||
else:
|
||||
body['file_data'] = file_data
|
||||
else:
|
||||
raise ValueError('file_url or file_data is required')
|
||||
|
||||
if file_type == self.MEDIA_TYPE_FILE and file_name:
|
||||
body['file_name'] = file_name
|
||||
|
||||
async with httpx.AsyncClient(timeout=120) as client:
|
||||
headers = {
|
||||
'Authorization': f'QQBot {self.access_token}',
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
response = await client.post(url, headers=headers, json=body)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
file_info = data.get('file_info', '')
|
||||
preview = file_info[:80] + '...' if len(file_info) > 80 else file_info
|
||||
await self.logger.info(f'Upload media success, file_info={preview}')
|
||||
return file_info
|
||||
else:
|
||||
raise Exception(f'Failed to upload media: HTTP {response.status_code} {response.text}')
|
||||
|
||||
async def _send_media_msg(
|
||||
self,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
file_info: str,
|
||||
msg_id: str = None,
|
||||
content: str = None,
|
||||
):
|
||||
"""发送富媒体消息(msg_type=7)"""
|
||||
if not await self.check_access_token():
|
||||
await self.get_access_token()
|
||||
|
||||
if target_type == 'c2c':
|
||||
url = f'{self.base_url}/v2/users/{target_id}/messages'
|
||||
elif target_type == 'group':
|
||||
url = f'{self.base_url}/v2/groups/{target_id}/messages'
|
||||
else:
|
||||
raise ValueError(f'Unsupported target_type: {target_type}')
|
||||
|
||||
self._msg_seq_counter += 1
|
||||
msg_seq = self._msg_seq_counter
|
||||
body = {
|
||||
'msg_type': 7,
|
||||
'media': {'file_info': file_info},
|
||||
'msg_seq': msg_seq,
|
||||
}
|
||||
if content:
|
||||
body['content'] = content
|
||||
if msg_id:
|
||||
body['msg_id'] = msg_id
|
||||
|
||||
async with httpx.AsyncClient(timeout=120) as client:
|
||||
headers = {
|
||||
'Authorization': f'QQBot {self.access_token}',
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
await self.logger.info(f'Sending rich media: {json.dumps(body, ensure_ascii=False)[:200]}')
|
||||
response = await client.post(url, headers=headers, json=body)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f'Failed to send rich media message: HTTP {response.status_code} {response.text}')
|
||||
|
||||
async def send_image_msg(
|
||||
self,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
file_url: str = None,
|
||||
file_data: str = None,
|
||||
msg_id: str = None,
|
||||
content: str = None,
|
||||
):
|
||||
"""发送图片消息"""
|
||||
file_info = await self.upload_media(
|
||||
target_type,
|
||||
target_id,
|
||||
self.MEDIA_TYPE_IMAGE,
|
||||
file_url=file_url,
|
||||
file_data=file_data,
|
||||
)
|
||||
await self._send_media_msg(target_type, target_id, file_info, msg_id, content)
|
||||
|
||||
async def send_voice_msg(
|
||||
self,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
file_url: str = None,
|
||||
file_data: str = None,
|
||||
msg_id: str = None,
|
||||
):
|
||||
"""发送语音消息"""
|
||||
file_info = await self.upload_media(
|
||||
target_type,
|
||||
target_id,
|
||||
self.MEDIA_TYPE_VOICE,
|
||||
file_url=file_url,
|
||||
file_data=file_data,
|
||||
)
|
||||
await self._send_media_msg(target_type, target_id, file_info, msg_id)
|
||||
|
||||
async def send_file_msg(
|
||||
self,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
file_url: str = None,
|
||||
file_data: str = None,
|
||||
file_name: str = None,
|
||||
msg_id: str = None,
|
||||
):
|
||||
"""发送文件消息(含视频)"""
|
||||
file_info = await self.upload_media(
|
||||
target_type,
|
||||
target_id,
|
||||
self.MEDIA_TYPE_FILE,
|
||||
file_url=file_url,
|
||||
file_data=file_data,
|
||||
file_name=file_name,
|
||||
)
|
||||
await self._send_media_msg(target_type, target_id, file_info, msg_id)
|
||||
|
||||
async def send_stream_msg(
|
||||
self,
|
||||
user_openid: str,
|
||||
content: str,
|
||||
event_id: str,
|
||||
msg_id: str,
|
||||
msg_seq: int = 1,
|
||||
index: int = 0,
|
||||
stream_msg_id: str = None,
|
||||
input_state: int = 1,
|
||||
):
|
||||
"""发送流式消息(C2C 私聊)。
|
||||
|
||||
Args:
|
||||
input_state: 1=生成中, 10=生成结束
|
||||
"""
|
||||
if not await self.check_access_token():
|
||||
await self.get_access_token()
|
||||
|
||||
url = f'{self.base_url}/v2/users/{user_openid}/stream_messages'
|
||||
body = {
|
||||
'input_mode': 'replace',
|
||||
'input_state': input_state,
|
||||
'content_type': 'markdown',
|
||||
'content_raw': content,
|
||||
'event_id': event_id,
|
||||
'msg_id': msg_id,
|
||||
'msg_seq': msg_seq,
|
||||
'index': index,
|
||||
}
|
||||
if stream_msg_id:
|
||||
body['stream_msg_id'] = stream_msg_id
|
||||
|
||||
async with httpx.AsyncClient(timeout=120) as client:
|
||||
headers = {
|
||||
'Authorization': f'QQBot {self.access_token}',
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
response = await client.post(url, headers=headers, json=body)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f'Failed to send stream message: HTTP {response.status_code} {response.text}')
|
||||
return response.json()
|
||||
|
||||
async def is_token_expired(self):
|
||||
"""检查token是否过期"""
|
||||
if self.access_token_expiry_time is None:
|
||||
@@ -513,325 +292,3 @@ class QQOfficialClient:
|
||||
'signature': signature,
|
||||
}
|
||||
return response
|
||||
|
||||
# ---- WebSocket Gateway ----
|
||||
# Reference: https://bot.q.qq.com/wiki/develop/api-v2/dev-prepare/interface-framework/event-emit.html
|
||||
|
||||
INTENT_GUILDS = 1 << 0
|
||||
INTENT_GUILD_MEMBERS = 1 << 1
|
||||
INTENT_PUBLIC_GUILD_MESSAGES = 1 << 30
|
||||
INTENT_DIRECT_MESSAGE = 1 << 12
|
||||
INTENT_GROUP_AND_C2C = 1 << 25
|
||||
INTENT_INTERACTION = 1 << 26
|
||||
|
||||
FULL_INTENTS = (
|
||||
INTENT_GUILDS
|
||||
| INTENT_GUILD_MEMBERS
|
||||
| INTENT_PUBLIC_GUILD_MESSAGES
|
||||
| INTENT_DIRECT_MESSAGE
|
||||
| INTENT_GROUP_AND_C2C
|
||||
| INTENT_INTERACTION
|
||||
)
|
||||
|
||||
async def get_gateway_url(self) -> str:
|
||||
"""获取 WebSocket 网关地址"""
|
||||
if not await self.check_access_token():
|
||||
await self.get_access_token()
|
||||
|
||||
url = f'{self.base_url}/gateway'
|
||||
async with httpx.AsyncClient() as client:
|
||||
headers = {
|
||||
'Authorization': f'QQBot {self.access_token}',
|
||||
}
|
||||
response = await client.get(url, headers=headers)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
ws_url = data.get('url', '')
|
||||
if not ws_url:
|
||||
raise Exception('Gateway URL is empty')
|
||||
return ws_url
|
||||
else:
|
||||
raise Exception(f'Failed to get Gateway URL: HTTP {response.status_code} {response.text}')
|
||||
|
||||
async def _background_token_refresh(self):
|
||||
"""在 token 到期前主动刷新"""
|
||||
try:
|
||||
while True:
|
||||
if self.access_token_expiry_time:
|
||||
remain = self.access_token_expiry_time - time.time()
|
||||
if remain > 120:
|
||||
await asyncio.sleep(remain - 60)
|
||||
continue
|
||||
self.access_token = ''
|
||||
self.access_token_expiry_time = None
|
||||
if await self.check_access_token():
|
||||
await asyncio.sleep(60)
|
||||
else:
|
||||
await self.get_access_token()
|
||||
await asyncio.sleep(60)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def connect_gateway(
|
||||
self,
|
||||
on_event: Callable[[str, dict], Any],
|
||||
on_ready: Optional[Callable[[], Any]] = None,
|
||||
on_error: Optional[Callable[[Exception], Any]] = None,
|
||||
):
|
||||
"""WebSocket 网关连接,含重连逻辑。持续重连直到达到最大次数或被取消。
|
||||
|
||||
Args:
|
||||
on_event: 收到 op=0 Dispatch 事件时的回调,参数为 (event_type, event_data)
|
||||
on_ready: 连接就绪 (收到 READY) 时的回调
|
||||
on_error: 发生错误时的回调
|
||||
"""
|
||||
import websockets
|
||||
|
||||
session_id = ''
|
||||
last_seq = 0
|
||||
reconnect_attempts = 0
|
||||
max_reconnect_attempts = 100
|
||||
backoff_delays = [1, 2, 5, 10, 30, 60]
|
||||
rate_limit_delay = 60
|
||||
|
||||
# Cancel previous token refresh task if any
|
||||
if self._token_refresh_task and not self._token_refresh_task.done():
|
||||
self._token_refresh_task.cancel()
|
||||
try:
|
||||
await self._token_refresh_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._token_refresh_task = None
|
||||
|
||||
while reconnect_attempts <= max_reconnect_attempts:
|
||||
heartbeat_interval = 45000
|
||||
should_refresh_token = False
|
||||
ws = None
|
||||
heartbeat_task = None
|
||||
|
||||
# Refresh token if needed
|
||||
if should_refresh_token:
|
||||
self.access_token = ''
|
||||
self.access_token_expiry_time = None
|
||||
|
||||
try:
|
||||
ws_url = await self.get_gateway_url()
|
||||
await self.logger.info(f'Gateway URL obtained: {ws_url[:60]}...')
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
await self.logger.error(f'Failed to get gateway URL: {e}')
|
||||
reconnect_attempts += 1
|
||||
if '100017' in error_msg or '频率' in error_msg or 'Too many' in error_msg:
|
||||
delay = rate_limit_delay
|
||||
else:
|
||||
delay = backoff_delays[min(reconnect_attempts - 1, len(backoff_delays) - 1)]
|
||||
await self.logger.info(f'Reconnecting in {delay}s (attempt {reconnect_attempts})')
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
|
||||
try:
|
||||
await self.logger.info('Connecting to WebSocket gateway...')
|
||||
ws = await websockets.connect(ws_url)
|
||||
await self.logger.info('WebSocket connected')
|
||||
except Exception as e:
|
||||
await self.logger.error(f'WebSocket connection failed: {e}')
|
||||
reconnect_attempts += 1
|
||||
delay = backoff_delays[min(reconnect_attempts - 1, len(backoff_delays) - 1)]
|
||||
await self.logger.info(f'Reconnecting in {delay}s (attempt {reconnect_attempts})')
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
|
||||
try:
|
||||
async for raw_msg in ws:
|
||||
try:
|
||||
payload = json.loads(raw_msg)
|
||||
except json.JSONDecodeError:
|
||||
await self.logger.error(f'Failed to parse message: {raw_msg}')
|
||||
continue
|
||||
|
||||
op = payload.get('op')
|
||||
d = payload.get('d', {})
|
||||
s = payload.get('s')
|
||||
t = payload.get('t')
|
||||
|
||||
if not isinstance(d, dict):
|
||||
d = {}
|
||||
|
||||
if op == 10: # Hello
|
||||
heartbeat_interval = d.get('heartbeat_interval', 45000)
|
||||
await self.logger.info(f'Received Hello, heartbeat_interval={heartbeat_interval}ms')
|
||||
|
||||
# Send Identify or Resume
|
||||
if session_id and last_seq > 0:
|
||||
resume_payload = {
|
||||
'op': 6,
|
||||
'd': {
|
||||
'token': f'QQBot {self.access_token}',
|
||||
'session_id': session_id,
|
||||
'seq': last_seq,
|
||||
},
|
||||
}
|
||||
await ws.send(json.dumps(resume_payload))
|
||||
await self.logger.info(f'Sent Resume, session_id={session_id}, seq={last_seq}')
|
||||
else:
|
||||
identify_payload = {
|
||||
'op': 2,
|
||||
'd': {
|
||||
'token': f'QQBot {self.access_token}',
|
||||
'intents': self.FULL_INTENTS,
|
||||
'shard': [0, 1],
|
||||
},
|
||||
}
|
||||
await ws.send(json.dumps(identify_payload))
|
||||
await self.logger.info(f'Sent Identify, intents={self.FULL_INTENTS}')
|
||||
|
||||
# Start heartbeat
|
||||
async def _heartbeat_loop(conn, interval_ms):
|
||||
interval_sec = interval_ms / 1000.0
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(interval_sec)
|
||||
try:
|
||||
hb_payload = {'op': 1, 'd': last_seq}
|
||||
await conn.send(json.dumps(hb_payload))
|
||||
except Exception:
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
heartbeat_task = asyncio.create_task(_heartbeat_loop(ws, heartbeat_interval))
|
||||
|
||||
elif op == 0: # Dispatch
|
||||
if s is not None:
|
||||
last_seq = s
|
||||
|
||||
if t == 'READY':
|
||||
session_id = d.get('session_id', '')
|
||||
reconnect_attempts = 0
|
||||
await self.logger.info(f'READY, session_id={session_id}')
|
||||
if on_ready:
|
||||
try:
|
||||
result = on_ready()
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
except Exception:
|
||||
pass
|
||||
# Track token refresh task to avoid leaks
|
||||
if self._token_refresh_task and not self._token_refresh_task.done():
|
||||
self._token_refresh_task.cancel()
|
||||
try:
|
||||
await self._token_refresh_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._token_refresh_task = asyncio.create_task(self._background_token_refresh())
|
||||
|
||||
elif t == 'RESUMED':
|
||||
reconnect_attempts = 0
|
||||
await self.logger.info('RESUMED')
|
||||
|
||||
else:
|
||||
await self.logger.debug(f'Received event: {t}, seq={s}')
|
||||
if on_event:
|
||||
try:
|
||||
result = on_event(t, d)
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
except Exception:
|
||||
await self.logger.error(f'Error handling event {t}: {traceback.format_exc()}')
|
||||
|
||||
elif op == 11: # Heartbeat ACK
|
||||
pass
|
||||
|
||||
elif op == 7: # Reconnect
|
||||
await self.logger.info('Received Reconnect directive')
|
||||
break
|
||||
|
||||
elif op == 9: # Invalid Session
|
||||
can_resume = d.get('can_resume', False)
|
||||
await self.logger.warning(f'Invalid Session, can_resume={can_resume}')
|
||||
if not can_resume:
|
||||
session_id = ''
|
||||
last_seq = 0
|
||||
should_refresh_token = True
|
||||
break
|
||||
|
||||
# Connection closed normally (end of async for)
|
||||
try:
|
||||
close_code = ws.close_code
|
||||
close_reason = ws.close_reason or ''
|
||||
except Exception:
|
||||
close_code = None
|
||||
close_reason = ''
|
||||
await self.logger.info(f'Connection closed, code={close_code}, reason={close_reason}')
|
||||
|
||||
if close_code == 4004:
|
||||
should_refresh_token = True
|
||||
elif close_code in (4006, 4007, 4009):
|
||||
session_id = ''
|
||||
last_seq = 0
|
||||
should_refresh_token = True
|
||||
elif close_code == 4008:
|
||||
reconnect_attempts += 1
|
||||
delay = rate_limit_delay
|
||||
await self.logger.info(
|
||||
f'Rate limited, waiting {delay}s before reconnect (attempt {reconnect_attempts})'
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
elif close_code in (4914, 4915):
|
||||
err = Exception(f'Bot disconnected/banned (close_code={close_code})')
|
||||
if on_error:
|
||||
await self._safe_callback(on_error, err)
|
||||
return
|
||||
elif close_code in (4900, 4901, 4902, 4903, 4904, 4905, 4906, 4907, 4908, 4909, 4910, 4911, 4912, 4913):
|
||||
session_id = ''
|
||||
last_seq = 0
|
||||
|
||||
if close_code == 1000:
|
||||
return
|
||||
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
await self.logger.error(f'Unexpected error in WebSocket loop: {traceback.format_exc()}')
|
||||
finally:
|
||||
if heartbeat_task:
|
||||
heartbeat_task.cancel()
|
||||
try:
|
||||
await heartbeat_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if ws:
|
||||
try:
|
||||
await ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# If we reach here, we need to reconnect
|
||||
reconnect_attempts += 1
|
||||
if reconnect_attempts > max_reconnect_attempts:
|
||||
await self.logger.error(f'Max reconnect attempts ({max_reconnect_attempts}) reached, stopping')
|
||||
if on_error:
|
||||
await self._safe_callback(on_error, Exception('Max reconnect attempts reached'))
|
||||
return
|
||||
delay = backoff_delays[min(reconnect_attempts - 1, len(backoff_delays) - 1)]
|
||||
await self.logger.info(f'Reconnecting in {delay}s (attempt {reconnect_attempts})')
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
async def _safe_callback(self, callback, *args):
|
||||
"""Safely invoke a callback, handling both sync and async functions."""
|
||||
try:
|
||||
result = callback(*args)
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def connect_gateway_loop(
|
||||
self,
|
||||
on_event: Callable[[str, dict], Any],
|
||||
on_ready: Optional[Callable[[], Any]] = None,
|
||||
on_error: Optional[Callable[[Exception], Any]] = None,
|
||||
):
|
||||
"""持续重连的网关循环。"""
|
||||
await self.connect_gateway(on_event, on_ready, on_error)
|
||||
|
||||
@@ -71,11 +71,6 @@ class StreamSession:
|
||||
class StreamSessionManager:
|
||||
"""管理 stream 会话的生命周期,并负责队列的生产消费。"""
|
||||
|
||||
# Sessions with registered feedback_ids use a longer TTL to survive the
|
||||
# full like → cancel → dislike feedback flow. Must align with the adapter's
|
||||
# _stream_to_monitoring_msg TTL (wecombot.py).
|
||||
_FEEDBACK_SESSION_TTL = 600 # 10 minutes
|
||||
|
||||
def __init__(self, logger: EventLogger, ttl: int = 60) -> None:
|
||||
self.logger = logger
|
||||
|
||||
@@ -219,17 +214,11 @@ class StreamSessionManager:
|
||||
session.last_access = time.time()
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""定期清理过期会话,防止队列与映射无上限累积。
|
||||
|
||||
已注册 feedback_id 的会话使用更长的 TTL,确保用户在点赞/取消/点踩流程中
|
||||
不会因为 session 被提前清除而丢失上下文信息。
|
||||
"""
|
||||
"""定期清理过期会话,防止队列与映射无上限累积。"""
|
||||
now = time.time()
|
||||
expired: list[str] = []
|
||||
for stream_id, session in self._sessions.items():
|
||||
# Sessions with registered feedback_ids use a longer TTL
|
||||
effective_ttl = self._FEEDBACK_SESSION_TTL if session.feedback_id else self.ttl
|
||||
if now - session.last_access > effective_ttl:
|
||||
if now - session.last_access > self.ttl:
|
||||
expired.append(stream_id)
|
||||
|
||||
for stream_id in expired:
|
||||
@@ -239,9 +228,6 @@ class StreamSessionManager:
|
||||
msg_id = session.msg_id
|
||||
if msg_id and self._msg_index.get(msg_id) == stream_id:
|
||||
self._msg_index.pop(msg_id, None)
|
||||
# Clean up feedback index for expired sessions
|
||||
if session.feedback_id:
|
||||
self._feedback_index.pop(session.feedback_id, None)
|
||||
|
||||
|
||||
def _decrypt_file(encrypted_data: bytes, aes_key_str: str) -> bytes:
|
||||
@@ -606,120 +592,6 @@ async def parse_wecom_bot_message(
|
||||
if msg_json.get('aibotid'):
|
||||
message_data['aibotid'] = msg_json.get('aibotid', '')
|
||||
|
||||
# Handle quote (referenced message) - important for group chat file references
|
||||
quote_info = msg_json.get('quote')
|
||||
if quote_info:
|
||||
quote_data: dict[str, Any] = {}
|
||||
quote_type = quote_info.get('msgtype', '')
|
||||
quote_data['msgtype'] = quote_type
|
||||
|
||||
if quote_type == 'text':
|
||||
quote_data['content'] = quote_info.get('text', {}).get('content', '')
|
||||
elif quote_type == 'image':
|
||||
img_info = quote_info.get('image', {})
|
||||
img_url = img_info.get('url', '')
|
||||
img_aeskey = img_info.get('aeskey', '')
|
||||
base64_data = await _safe_download_as_data_uri(img_url, img_aeskey)
|
||||
if base64_data:
|
||||
quote_data['picurl'] = base64_data
|
||||
quote_data['images'] = [base64_data]
|
||||
elif quote_type == 'file':
|
||||
file_info = quote_info.get('file', {}) or {}
|
||||
download_url = file_info.get('url') or file_info.get('fileurl')
|
||||
item_aeskey = file_info.get('aeskey', '')
|
||||
file_data = {
|
||||
'filename': file_info.get('filename') or file_info.get('name'),
|
||||
'filesize': file_info.get('filesize') or file_info.get('size'),
|
||||
'md5sum': file_info.get('md5sum') or file_info.get('md5'),
|
||||
'sdkfileid': file_info.get('sdkfileid') or file_info.get('fileid'),
|
||||
'download_url': download_url,
|
||||
'extra': file_info,
|
||||
}
|
||||
# Same as private chat: append aeskey to download_url for plugin processing
|
||||
if download_url and item_aeskey:
|
||||
file_data['download_url'] = download_url + f'?aeskey={item_aeskey}'
|
||||
quote_data['file'] = file_data
|
||||
elif quote_type == 'voice':
|
||||
voice_info = quote_info.get('voice', {}) or {}
|
||||
download_url = voice_info.get('url')
|
||||
item_aeskey = voice_info.get('aeskey', '')
|
||||
voice_data = {
|
||||
'url': download_url,
|
||||
'md5sum': voice_info.get('md5sum') or voice_info.get('md5'),
|
||||
'filesize': voice_info.get('filesize') or voice_info.get('size'),
|
||||
'sdkfileid': voice_info.get('sdkfileid') or voice_info.get('fileid'),
|
||||
}
|
||||
if voice_info.get('content'):
|
||||
quote_data['content'] = voice_info.get('content')
|
||||
# Same as private chat: append aeskey to url for plugin processing
|
||||
if download_url and item_aeskey:
|
||||
voice_data['url'] = download_url + f'?aeskey={item_aeskey}'
|
||||
quote_data['voice'] = voice_data
|
||||
elif quote_type == 'video':
|
||||
video_info = quote_info.get('video', {}) or {}
|
||||
download_url = video_info.get('url')
|
||||
item_aeskey = video_info.get('aeskey', '')
|
||||
video_data = {
|
||||
'url': download_url,
|
||||
'filesize': video_info.get('filesize') or video_info.get('size'),
|
||||
'sdkfileid': video_info.get('sdkfileid') or video_info.get('fileid'),
|
||||
'md5sum': video_info.get('md5sum') or video_info.get('md5'),
|
||||
'filename': video_info.get('filename') or video_info.get('name'),
|
||||
}
|
||||
# Same as private chat: append aeskey to download_url for plugin processing
|
||||
if download_url and item_aeskey:
|
||||
video_data['download_url'] = download_url + f'?aeskey={item_aeskey}'
|
||||
quote_data['video'] = video_data
|
||||
elif quote_type == 'link':
|
||||
quote_data['link'] = quote_info.get('link', {})
|
||||
link = quote_data['link']
|
||||
title = link.get('title', '')
|
||||
desc = link.get('description') or link.get('digest', '')
|
||||
quote_data['content'] = '\n'.join(filter(None, [title, desc]))
|
||||
elif quote_type == 'mixed':
|
||||
# Handle mixed type in quote (text + images + files etc.)
|
||||
items = quote_info.get('mixed', {}).get('msg_item', [])
|
||||
texts = []
|
||||
images = []
|
||||
files = []
|
||||
for item in items:
|
||||
item_type = item.get('msgtype')
|
||||
if item_type == 'text':
|
||||
texts.append(item.get('text', {}).get('content', ''))
|
||||
elif item_type == 'image':
|
||||
img_info = item.get('image', {})
|
||||
img_url = img_info.get('url')
|
||||
img_aeskey = img_info.get('aeskey', '')
|
||||
base64_data = await _safe_download_as_data_uri(img_url, img_aeskey)
|
||||
if base64_data:
|
||||
images.append(base64_data)
|
||||
elif item_type == 'file':
|
||||
file_info = item.get('file', {}) or {}
|
||||
download_url = file_info.get('url') or file_info.get('fileurl')
|
||||
item_aeskey = file_info.get('aeskey', '')
|
||||
file_data = {
|
||||
'filename': file_info.get('filename') or file_info.get('name'),
|
||||
'filesize': file_info.get('filesize') or file_info.get('size'),
|
||||
'md5sum': file_info.get('md5sum') or file_info.get('md5'),
|
||||
'sdkfileid': file_info.get('sdkfileid') or file_info.get('fileid'),
|
||||
'download_url': download_url,
|
||||
'extra': file_info,
|
||||
}
|
||||
# Same as private chat: append aeskey to download_url for plugin processing
|
||||
if download_url and item_aeskey:
|
||||
file_data['download_url'] = download_url + f'?aeskey={item_aeskey}'
|
||||
files.append(file_data)
|
||||
if texts:
|
||||
quote_data['content'] = ' '.join(texts)
|
||||
if images:
|
||||
quote_data['images'] = images
|
||||
quote_data['picurl'] = images[0]
|
||||
if files:
|
||||
quote_data['files'] = files
|
||||
quote_data['file'] = files[0]
|
||||
|
||||
message_data['quote'] = quote_data
|
||||
|
||||
return message_data
|
||||
|
||||
|
||||
@@ -1031,38 +903,35 @@ class WecomBotClient:
|
||||
)
|
||||
|
||||
session = self.stream_sessions.get_session_by_feedback_id(feedback_id)
|
||||
|
||||
if session:
|
||||
await self.logger.info(
|
||||
f'反馈关联到会话: stream_id={session.stream_id}, msg_id={session.msg_id}, user_id={session.user_id}'
|
||||
)
|
||||
for handler in self._message_handlers.get('feedback', []):
|
||||
try:
|
||||
await handler(
|
||||
feedback_id=feedback_id,
|
||||
feedback_type=feedback_type,
|
||||
feedback_content=feedback_content,
|
||||
inaccurate_reasons=inaccurate_reasons,
|
||||
session=session,
|
||||
)
|
||||
except Exception:
|
||||
await self.logger.error(traceback.format_exc())
|
||||
|
||||
if self._feedback_callback:
|
||||
try:
|
||||
await self._feedback_callback(
|
||||
feedback_id=feedback_id,
|
||||
feedback_type=feedback_type,
|
||||
feedback_content=feedback_content,
|
||||
inaccurate_reasons=inaccurate_reasons,
|
||||
session=session,
|
||||
)
|
||||
except Exception:
|
||||
await self.logger.error(traceback.format_exc())
|
||||
else:
|
||||
await self.logger.warning(f'未找到 feedback_id={feedback_id} 对应的会话,仍将记录反馈')
|
||||
|
||||
# Dispatch feedback event regardless of session availability
|
||||
for handler in self._message_handlers.get('feedback', []):
|
||||
try:
|
||||
await handler(
|
||||
feedback_id=feedback_id,
|
||||
feedback_type=feedback_type,
|
||||
feedback_content=feedback_content,
|
||||
inaccurate_reasons=inaccurate_reasons,
|
||||
session=session,
|
||||
)
|
||||
except Exception:
|
||||
await self.logger.error(traceback.format_exc())
|
||||
|
||||
if self._feedback_callback:
|
||||
try:
|
||||
await self._feedback_callback(
|
||||
feedback_id=feedback_id,
|
||||
feedback_type=feedback_type,
|
||||
feedback_content=feedback_content,
|
||||
inaccurate_reasons=inaccurate_reasons,
|
||||
session=session,
|
||||
)
|
||||
except Exception:
|
||||
await self.logger.error(traceback.format_exc())
|
||||
await self.logger.warning(f'未找到 feedback_id={feedback_id} 对应的会话')
|
||||
|
||||
except Exception:
|
||||
await self.logger.error(traceback.format_exc())
|
||||
|
||||
@@ -147,10 +147,3 @@ class WecomBotEvent(dict):
|
||||
流式消息 ID
|
||||
"""
|
||||
return self.get('stream_id', '')
|
||||
|
||||
@property
|
||||
def quote(self):
|
||||
"""
|
||||
引用消息信息(群聊中用户引用其他消息时返回)
|
||||
"""
|
||||
return self.get('quote', {})
|
||||
|
||||
97
src/langbot/pkg/api/http/controller/groups/human_takeover.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import quart
|
||||
|
||||
from .. import group
|
||||
|
||||
|
||||
@group.group_class('human-takeover', '/api/v1/human-takeover')
|
||||
class HumanTakeoverRouterGroup(group.RouterGroup):
|
||||
async def initialize(self) -> None:
|
||||
@self.route('/sessions', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def get_sessions():
|
||||
"""Get list of takeover sessions, optionally filtered by bot UUID."""
|
||||
bot_uuid = quart.request.args.get('botUuid')
|
||||
limit = int(quart.request.args.get('limit', 100))
|
||||
offset = int(quart.request.args.get('offset', 0))
|
||||
|
||||
sessions, total = await self.ap.human_takeover_service.get_active_sessions(
|
||||
bot_uuid=bot_uuid if bot_uuid else None,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'sessions': sessions,
|
||||
'total': total,
|
||||
'limit': limit,
|
||||
'offset': offset,
|
||||
}
|
||||
)
|
||||
|
||||
@self.route('/sessions/<session_id>', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def get_session_detail(session_id: str):
|
||||
"""Get detail for a specific takeover session."""
|
||||
detail = await self.ap.human_takeover_service.get_session_detail(session_id)
|
||||
if not detail:
|
||||
return self.success(data={'found': False, 'session_id': session_id})
|
||||
return self.success(data={'found': True, 'session': detail})
|
||||
|
||||
@self.route('/sessions/<session_id>/takeover', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def takeover_session(session_id: str, user_email: str = None):
|
||||
"""Take over a conversation session."""
|
||||
data = await quart.request.get_json(silent=True) or {}
|
||||
|
||||
bot_uuid = data.get('bot_uuid')
|
||||
if not bot_uuid:
|
||||
return self.fail(-1, 'bot_uuid is required')
|
||||
|
||||
platform = data.get('platform')
|
||||
user_id = data.get('user_id')
|
||||
user_name = data.get('user_name')
|
||||
|
||||
try:
|
||||
result = await self.ap.human_takeover_service.takeover_session(
|
||||
session_id=session_id,
|
||||
bot_uuid=bot_uuid,
|
||||
taken_by=user_email or data.get('taken_by'),
|
||||
platform=platform,
|
||||
user_id=user_id,
|
||||
user_name=user_name,
|
||||
)
|
||||
return self.success(data=result)
|
||||
except ValueError as e:
|
||||
return self.fail(-1, str(e))
|
||||
|
||||
@self.route('/sessions/<session_id>/release', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def release_session(session_id: str):
|
||||
"""Release a taken-over session back to AI pipeline."""
|
||||
try:
|
||||
result = await self.ap.human_takeover_service.release_session(session_id)
|
||||
return self.success(data=result)
|
||||
except ValueError as e:
|
||||
return self.fail(-1, str(e))
|
||||
|
||||
@self.route('/sessions/<session_id>/message', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def send_message(session_id: str, user_email: str = None):
|
||||
"""Send a message from the operator to the user."""
|
||||
data = await quart.request.get_json(silent=True) or {}
|
||||
|
||||
message_text = data.get('message')
|
||||
if not message_text:
|
||||
return self.fail(-1, 'message is required')
|
||||
|
||||
operator_name = user_email or data.get('operator_name', 'Operator')
|
||||
|
||||
try:
|
||||
result = await self.ap.human_takeover_service.send_message(
|
||||
session_id=session_id,
|
||||
message_text=message_text,
|
||||
operator_name=operator_name,
|
||||
)
|
||||
return self.success(data=result)
|
||||
except ValueError as e:
|
||||
return self.fail(-1, str(e))
|
||||
except RuntimeError as e:
|
||||
return self.fail(-2, str(e))
|
||||
@@ -1,384 +0,0 @@
|
||||
"""Embed widget routes - serve embeddable chat widget for external websites.
|
||||
|
||||
All user-facing URLs are keyed by **bot_uuid** (not pipeline_uuid) so that
|
||||
internal pipeline identifiers are never exposed to end-users. Each handler
|
||||
resolves the bot_uuid to the owning ``web_page_bot`` RuntimeBot and extracts
|
||||
the bound pipeline_uuid for internal routing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
import hmac
|
||||
import hashlib
|
||||
import time
|
||||
import re
|
||||
import httpx
|
||||
|
||||
import quart
|
||||
|
||||
from ... import group
|
||||
from ......utils import paths
|
||||
from ......platform.sources.websocket_manager import ws_connection_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cache the widget template content
|
||||
_widget_template_cache: str | None = None
|
||||
_logo_bytes_cache: bytes | None = None
|
||||
|
||||
|
||||
def _is_valid_uuid(s: str) -> bool:
|
||||
return bool(re.match(r'^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$', s))
|
||||
|
||||
|
||||
def _get_widget_template() -> str:
|
||||
"""Load and cache the widget JS template."""
|
||||
global _widget_template_cache
|
||||
if _widget_template_cache is None:
|
||||
template_path = paths.get_resource_path('templates/embed/widget.js')
|
||||
with open(template_path, 'r', encoding='utf-8') as f:
|
||||
_widget_template_cache = f.read()
|
||||
return _widget_template_cache
|
||||
|
||||
|
||||
def _get_logo_bytes() -> bytes:
|
||||
"""Load and cache the logo image."""
|
||||
global _logo_bytes_cache
|
||||
if _logo_bytes_cache is None:
|
||||
logo_path = paths.get_resource_path('templates/embed/logo.webp')
|
||||
with open(logo_path, 'rb') as f:
|
||||
_logo_bytes_cache = f.read()
|
||||
return _logo_bytes_cache
|
||||
|
||||
|
||||
@group.group_class('embed', '/api/v1/embed')
|
||||
class EmbedRouterGroup(group.RouterGroup):
|
||||
# -- helpers -------------------------------------------------------------
|
||||
|
||||
def _resolve_bot(self, bot_uuid: str):
|
||||
"""Resolve *bot_uuid* to ``(runtime_bot, pipeline_uuid)``.
|
||||
|
||||
Returns ``(None, None)`` when the bot does not exist, is not a
|
||||
``web_page_bot``, is disabled, or has no pipeline bound.
|
||||
"""
|
||||
for bot in self.ap.platform_mgr.bots:
|
||||
if (
|
||||
bot.bot_entity.uuid == bot_uuid
|
||||
and bot.bot_entity.adapter == 'web_page_bot'
|
||||
and bot.bot_entity.enable
|
||||
and bot.bot_entity.use_pipeline_uuid
|
||||
):
|
||||
return bot, bot.bot_entity.use_pipeline_uuid
|
||||
return None, None
|
||||
|
||||
def _get_bot_config(self, bot_uuid: str) -> dict:
|
||||
for bot in self.ap.platform_mgr.bots:
|
||||
if bot.bot_entity.uuid == bot_uuid and bot.bot_entity.adapter == 'web_page_bot':
|
||||
return bot.bot_entity.adapter_config
|
||||
return {}
|
||||
|
||||
async def _verify_session_token(self, request, bot_uuid: str) -> bool:
|
||||
config = self._get_bot_config(bot_uuid)
|
||||
secret = config.get('turnstile_secret_key', '')
|
||||
if not secret:
|
||||
return True
|
||||
auth_header = request.headers.get('Authorization', '')
|
||||
if not auth_header.startswith('Bearer '):
|
||||
return False
|
||||
token = auth_header[7:]
|
||||
try:
|
||||
ts_str, mac = token.split('.', 1)
|
||||
ts = float(ts_str)
|
||||
if time.time() - ts > 86400:
|
||||
return False
|
||||
expected_mac = hmac.new(secret.encode(), f'{ts_str}'.encode(), hashlib.sha256).hexdigest()
|
||||
return hmac.compare_digest(mac, expected_mac)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# -- routes --------------------------------------------------------------
|
||||
|
||||
async def initialize(self) -> None:
|
||||
@self.route('/<bot_uuid>/turnstile/verify', methods=['POST'], auth_type=group.AuthType.NONE)
|
||||
async def verify_turnstile(bot_uuid: str) -> str:
|
||||
if not _is_valid_uuid(bot_uuid):
|
||||
return self.http_status(400, -1, 'Invalid bot_uuid format')
|
||||
runtime_bot, pipeline_uuid = self._resolve_bot(bot_uuid)
|
||||
if runtime_bot is None:
|
||||
return self.http_status(404, -1, 'Bot not found or not available')
|
||||
try:
|
||||
data = await quart.request.get_json()
|
||||
token = data.get('token')
|
||||
if not token:
|
||||
return self.http_status(400, -1, 'Token is required')
|
||||
|
||||
config = self._get_bot_config(bot_uuid)
|
||||
secret = config.get('turnstile_secret_key', '')
|
||||
if not secret:
|
||||
ts = time.time()
|
||||
return self.success(data={'token': f'{ts}.dummy'})
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
'https://challenges.cloudflare.com/turnstile/v0/siteverify',
|
||||
data={'secret': secret, 'response': token},
|
||||
)
|
||||
result = resp.json()
|
||||
|
||||
if not result.get('success'):
|
||||
return self.http_status(403, -1, 'Turnstile verification failed')
|
||||
|
||||
ts = time.time()
|
||||
mac = hmac.new(secret.encode(), f'{ts}'.encode(), hashlib.sha256).hexdigest()
|
||||
session_token = f'{ts}.{mac}'
|
||||
|
||||
return self.success(data={'token': session_token})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'Turnstile verify failed: {e}', exc_info=True)
|
||||
return self.http_status(500, -1, 'Internal server error')
|
||||
|
||||
@self.route('/<bot_uuid>/widget.js', methods=['GET'], auth_type=group.AuthType.NONE)
|
||||
async def serve_widget(bot_uuid: str) -> quart.Response:
|
||||
"""Serve the embed widget JavaScript with injected configuration."""
|
||||
if not _is_valid_uuid(bot_uuid):
|
||||
return self.http_status(400, -1, 'Invalid bot_uuid format')
|
||||
runtime_bot, pipeline_uuid = self._resolve_bot(bot_uuid)
|
||||
if runtime_bot is None:
|
||||
return quart.Response(
|
||||
'// Bot not found or not available', status=404, content_type='application/javascript'
|
||||
)
|
||||
try:
|
||||
template = _get_widget_template()
|
||||
except FileNotFoundError:
|
||||
return quart.Response('// Widget template not found', status=404, content_type='application/javascript')
|
||||
|
||||
base_url = quart.request.host_url.rstrip('/')
|
||||
webhook_prefix = self.ap.instance_config.data.get('api', {}).get('webhook_prefix', '')
|
||||
if webhook_prefix:
|
||||
base_url = webhook_prefix.rstrip('/')
|
||||
|
||||
if not re.match(r'^https?://[a-zA-Z0-9._:/-]+$', base_url):
|
||||
base_url = quart.request.host_url.rstrip('/')
|
||||
|
||||
config = self._get_bot_config(bot_uuid)
|
||||
site_key = config.get('turnstile_site_key', '')
|
||||
locale = config.get('language', 'en_US') or 'en_US'
|
||||
bubble_icon = config.get('bubble_icon', 'logo') or 'logo'
|
||||
widget_js = template.replace('__LANGBOT_TURNSTILE_SITE_KEY__', site_key)
|
||||
widget_js = widget_js.replace('__LANGBOT_BOT_UUID__', bot_uuid)
|
||||
widget_js = widget_js.replace('__LANGBOT_BASE_URL__', base_url)
|
||||
widget_js = widget_js.replace('__LANGBOT_LOCALE__', locale)
|
||||
widget_js = widget_js.replace('__LANGBOT_BUBBLE_ICON__', bubble_icon)
|
||||
|
||||
response = quart.Response(widget_js, content_type='application/javascript; charset=utf-8')
|
||||
response.headers['Cache-Control'] = 'public, max-age=300'
|
||||
return response
|
||||
|
||||
@self.route('/logo', methods=['GET'], auth_type=group.AuthType.NONE)
|
||||
async def serve_logo() -> quart.Response:
|
||||
"""Serve the LangBot logo for the embed widget."""
|
||||
try:
|
||||
logo_data = _get_logo_bytes()
|
||||
except FileNotFoundError:
|
||||
return quart.Response('', status=404)
|
||||
|
||||
response = quart.Response(logo_data, content_type='image/webp')
|
||||
response.headers['Cache-Control'] = 'public, max-age=86400'
|
||||
return response
|
||||
|
||||
@self.route('/<bot_uuid>/messages/<session_type>', methods=['GET'], auth_type=group.AuthType.NONE)
|
||||
async def get_embed_messages(bot_uuid: str, session_type: str) -> str:
|
||||
if not _is_valid_uuid(bot_uuid):
|
||||
return self.http_status(400, -1, 'Invalid bot_uuid format')
|
||||
runtime_bot, pipeline_uuid = self._resolve_bot(bot_uuid)
|
||||
if runtime_bot is None:
|
||||
return self.http_status(404, -1, 'Bot not found or not available')
|
||||
if not await self._verify_session_token(quart.request, bot_uuid):
|
||||
return self.http_status(403, -1, 'Unauthorized or session expired')
|
||||
try:
|
||||
if session_type not in ['person', 'group']:
|
||||
return self.http_status(400, -1, 'session_type must be person or group')
|
||||
|
||||
websocket_adapter = self.ap.platform_mgr.websocket_proxy_bot.adapter
|
||||
if not websocket_adapter:
|
||||
return self.http_status(404, -1, 'WebSocket adapter not found')
|
||||
|
||||
messages = websocket_adapter.get_websocket_messages(pipeline_uuid, session_type)
|
||||
return self.success(data={'messages': messages})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'Failed to get embed messages: {e}', exc_info=True)
|
||||
return self.http_status(500, -1, 'Internal server error')
|
||||
|
||||
@self.route('/<bot_uuid>/reset/<session_type>', methods=['POST'], auth_type=group.AuthType.NONE)
|
||||
async def reset_embed_session(bot_uuid: str, session_type: str) -> str:
|
||||
if not _is_valid_uuid(bot_uuid):
|
||||
return self.http_status(400, -1, 'Invalid bot_uuid format')
|
||||
runtime_bot, pipeline_uuid = self._resolve_bot(bot_uuid)
|
||||
if runtime_bot is None:
|
||||
return self.http_status(404, -1, 'Bot not found or not available')
|
||||
if not await self._verify_session_token(quart.request, bot_uuid):
|
||||
return self.http_status(403, -1, 'Unauthorized or session expired')
|
||||
try:
|
||||
if session_type not in ['person', 'group']:
|
||||
return self.http_status(400, -1, 'session_type must be person or group')
|
||||
|
||||
websocket_adapter = self.ap.platform_mgr.websocket_proxy_bot.adapter
|
||||
if not websocket_adapter:
|
||||
return self.http_status(404, -1, 'WebSocket adapter not found')
|
||||
|
||||
websocket_adapter.reset_session(pipeline_uuid, session_type)
|
||||
return self.success(data={'message': 'Session reset successfully'})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'Failed to reset embed session: {e}', exc_info=True)
|
||||
return self.http_status(500, -1, 'Internal server error')
|
||||
|
||||
@self.route('/<bot_uuid>/feedback', methods=['POST'], auth_type=group.AuthType.NONE)
|
||||
async def submit_feedback(bot_uuid: str) -> str:
|
||||
if not _is_valid_uuid(bot_uuid):
|
||||
return self.http_status(400, -1, 'Invalid bot_uuid format')
|
||||
runtime_bot, pipeline_uuid = self._resolve_bot(bot_uuid)
|
||||
if runtime_bot is None:
|
||||
return self.http_status(404, -1, 'Bot not found or not available')
|
||||
if not await self._verify_session_token(quart.request, bot_uuid):
|
||||
return self.http_status(403, -1, 'Unauthorized or session expired')
|
||||
try:
|
||||
data = await quart.request.get_json()
|
||||
message_id = data.get('message_id', '')
|
||||
feedback_type = data.get('feedback_type')
|
||||
|
||||
if feedback_type not in (1, 2, 3):
|
||||
return self.http_status(400, -1, 'feedback_type must be 1 (like), 2 (dislike), or 3 (cancel)')
|
||||
|
||||
feedback_id = f'embed_{uuid.uuid4().hex[:12]}'
|
||||
|
||||
await self.ap.monitoring_service.record_feedback(
|
||||
feedback_id=feedback_id,
|
||||
feedback_type=feedback_type,
|
||||
bot_id=runtime_bot.bot_entity.uuid,
|
||||
bot_name=runtime_bot.bot_entity.name or bot_uuid,
|
||||
pipeline_id=pipeline_uuid,
|
||||
message_id=str(message_id),
|
||||
platform='web_page_bot',
|
||||
)
|
||||
|
||||
return self.success(data={'feedback_id': feedback_id})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'Failed to record feedback: {e}', exc_info=True)
|
||||
return self.http_status(500, -1, 'Internal server error')
|
||||
|
||||
# -- Embed WebSocket endpoint ----------------------------------------
|
||||
|
||||
@self.quart_app.websocket(self.path + '/<bot_uuid>/ws/connect')
|
||||
async def embed_websocket_connect(bot_uuid: str):
|
||||
"""WebSocket connection for embed widget, keyed by bot_uuid."""
|
||||
if not _is_valid_uuid(bot_uuid):
|
||||
await quart.websocket.send(json.dumps({'type': 'error', 'message': 'Invalid bot_uuid format'}))
|
||||
return
|
||||
|
||||
runtime_bot, pipeline_uuid = self._resolve_bot(bot_uuid)
|
||||
if runtime_bot is None:
|
||||
await quart.websocket.send(json.dumps({'type': 'error', 'message': 'Bot not found or not available'}))
|
||||
return
|
||||
|
||||
session_type = quart.websocket.args.get('session_type', 'person')
|
||||
if session_type not in ['person', 'group']:
|
||||
await quart.websocket.send(
|
||||
json.dumps({'type': 'error', 'message': 'session_type must be person or group'})
|
||||
)
|
||||
return
|
||||
|
||||
websocket_adapter = self.ap.platform_mgr.websocket_proxy_bot.adapter
|
||||
if not websocket_adapter:
|
||||
await quart.websocket.send(json.dumps({'type': 'error', 'message': 'WebSocket adapter not found'}))
|
||||
return
|
||||
|
||||
try:
|
||||
connection = await ws_connection_manager.add_connection(
|
||||
websocket=quart.websocket._get_current_object(),
|
||||
pipeline_uuid=pipeline_uuid,
|
||||
session_type=session_type,
|
||||
metadata={'user_agent': quart.websocket.headers.get('User-Agent', '')},
|
||||
)
|
||||
|
||||
await quart.websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
'type': 'connected',
|
||||
'connection_id': connection.connection_id,
|
||||
'bot_uuid': bot_uuid,
|
||||
'session_type': session_type,
|
||||
'timestamp': connection.created_at.isoformat(),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f'Embed WebSocket connected: {connection.connection_id} '
|
||||
f'(bot={bot_uuid}, pipeline={pipeline_uuid}, session_type={session_type})'
|
||||
)
|
||||
|
||||
receive_task = asyncio.create_task(self._handle_receive(connection, websocket_adapter, runtime_bot))
|
||||
send_task = asyncio.create_task(self._handle_send(connection))
|
||||
|
||||
try:
|
||||
await asyncio.gather(receive_task, send_task)
|
||||
except Exception as e:
|
||||
logger.error(f'Embed WebSocket task error: {e}')
|
||||
finally:
|
||||
await ws_connection_manager.remove_connection(connection.connection_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'Embed WebSocket connection error: {e}', exc_info=True)
|
||||
try:
|
||||
await quart.websocket.send(json.dumps({'type': 'error', 'message': 'Internal server error'}))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# -- WebSocket receive/send helpers --------------------------------------
|
||||
|
||||
async def _handle_receive(self, connection, websocket_adapter, owner_bot):
|
||||
try:
|
||||
while connection.is_active:
|
||||
message = await quart.websocket.receive()
|
||||
await ws_connection_manager.update_activity(connection.connection_id)
|
||||
|
||||
try:
|
||||
data = json.loads(message)
|
||||
message_type = data.get('type', 'message')
|
||||
|
||||
if message_type == 'ping':
|
||||
await connection.send_queue.put(
|
||||
{'type': 'pong', 'timestamp': datetime.datetime.now().isoformat()}
|
||||
)
|
||||
elif message_type == 'message':
|
||||
await websocket_adapter.handle_websocket_message(connection, data, owner_bot=owner_bot)
|
||||
elif message_type == 'disconnect':
|
||||
break
|
||||
|
||||
except json.JSONDecodeError:
|
||||
await connection.send_queue.put({'type': 'error', 'message': 'Invalid JSON format'})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'Embed receive error: {e}', exc_info=True)
|
||||
finally:
|
||||
connection.is_active = False
|
||||
|
||||
async def _handle_send(self, connection):
|
||||
try:
|
||||
while connection.is_active:
|
||||
try:
|
||||
message = await asyncio.wait_for(connection.send_queue.get(), timeout=1.0)
|
||||
await quart.websocket.send(json.dumps(message))
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f'Embed send error: {e}', exc_info=True)
|
||||
finally:
|
||||
connection.is_active = False
|
||||
@@ -43,9 +43,6 @@ class WebSocketChatRouterGroup(group.RouterGroup):
|
||||
await quart.websocket.send(json.dumps({'type': 'error', 'message': 'WebSocket adapter not found'}))
|
||||
return
|
||||
|
||||
# Find the owning bot for this pipeline (e.g. a web_page_bot)
|
||||
owner_bot = self._find_owner_bot(pipeline_uuid)
|
||||
|
||||
# 注册连接
|
||||
connection = await ws_connection_manager.add_connection(
|
||||
websocket=quart.websocket._get_current_object(),
|
||||
@@ -73,7 +70,7 @@ class WebSocketChatRouterGroup(group.RouterGroup):
|
||||
)
|
||||
|
||||
# 创建接收和发送任务
|
||||
receive_task = asyncio.create_task(self._handle_receive(connection, websocket_adapter, owner_bot))
|
||||
receive_task = asyncio.create_task(self._handle_receive(connection, websocket_adapter))
|
||||
send_task = asyncio.create_task(self._handle_send(connection))
|
||||
|
||||
# 等待任务完成
|
||||
@@ -181,14 +178,7 @@ class WebSocketChatRouterGroup(group.RouterGroup):
|
||||
except Exception as e:
|
||||
return self.http_status(500, -1, f'Internal server error: {str(e)}')
|
||||
|
||||
def _find_owner_bot(self, pipeline_uuid: str):
|
||||
"""Find a user-created bot (e.g. web_page_bot) that owns this pipeline."""
|
||||
for bot in self.ap.platform_mgr.bots:
|
||||
if bot.bot_entity.adapter == 'web_page_bot' and bot.bot_entity.use_pipeline_uuid == pipeline_uuid:
|
||||
return bot
|
||||
return None
|
||||
|
||||
async def _handle_receive(self, connection, websocket_adapter, owner_bot=None):
|
||||
async def _handle_receive(self, connection, websocket_adapter):
|
||||
"""处理接收消息的任务"""
|
||||
try:
|
||||
while connection.is_active:
|
||||
@@ -213,7 +203,7 @@ class WebSocketChatRouterGroup(group.RouterGroup):
|
||||
logger.debug(f'收到消息: {data} from {connection.connection_id}')
|
||||
|
||||
# 处理消息(不等待响应,响应会通过broadcast异步发送)
|
||||
await websocket_adapter.handle_websocket_message(connection, data, owner_bot=owner_bot)
|
||||
await websocket_adapter.handle_websocket_message(connection, data)
|
||||
|
||||
elif message_type == 'disconnect':
|
||||
# 客户端主动断开
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import quart
|
||||
import mimetypes
|
||||
import asyncio
|
||||
from ... import group
|
||||
from langbot.pkg.utils import importutil
|
||||
|
||||
@@ -36,640 +35,3 @@ class AdaptersRouterGroup(group.RouterGroup):
|
||||
return quart.Response(
|
||||
importutil.read_resource_file_bytes(icon_path), mimetype=mimetypes.guess_type(icon_path)[0]
|
||||
)
|
||||
|
||||
# In-memory session store for active registrations
|
||||
_create_app_sessions: dict = {}
|
||||
_SESSION_TTL = 900 # 15 minutes
|
||||
|
||||
def _cleanup_expired_sessions():
|
||||
"""Remove sessions that have exceeded their TTL."""
|
||||
import time
|
||||
|
||||
now = time.time()
|
||||
expired = [sid for sid, s in _create_app_sessions.items() if now - s.get('created_at', 0) > _SESSION_TTL]
|
||||
for sid in expired:
|
||||
session = _create_app_sessions.pop(sid, None)
|
||||
if session and session.get('task') and not session['task'].done():
|
||||
session['task'].cancel()
|
||||
|
||||
@self.route('/lark/create-app', methods=['POST'])
|
||||
async def _() -> str:
|
||||
"""Start Feishu one-click app registration. Returns session_id + QR code URL."""
|
||||
import uuid
|
||||
import time
|
||||
import lark_oapi as lark
|
||||
from lark_oapi.scene.registration.errors import AppAccessDeniedError, AppExpiredError
|
||||
|
||||
_cleanup_expired_sessions()
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
session = {
|
||||
'status': 'pending',
|
||||
'qr_url': None,
|
||||
'expire_at': None,
|
||||
'app_id': None,
|
||||
'app_secret': None,
|
||||
'error': None,
|
||||
'created_at': time.time(),
|
||||
}
|
||||
_create_app_sessions[session_id] = session
|
||||
|
||||
def on_qr_code(info):
|
||||
# May be called from a background thread by the SDK;
|
||||
# use call_soon_threadsafe to safely update session state.
|
||||
def _update():
|
||||
session['qr_url'] = info['url']
|
||||
session['expire_at'] = time.time() + 600 # 10 minutes
|
||||
session['status'] = 'waiting'
|
||||
|
||||
loop.call_soon_threadsafe(_update)
|
||||
|
||||
async def run_registration():
|
||||
try:
|
||||
result = await lark.aregister_app(
|
||||
on_qr_code=on_qr_code,
|
||||
source='langbot',
|
||||
)
|
||||
session['status'] = 'success'
|
||||
session['app_id'] = result['client_id']
|
||||
session['app_secret'] = result['client_secret']
|
||||
except AppAccessDeniedError:
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'User denied authorization'
|
||||
except AppExpiredError:
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'QR code expired'
|
||||
except Exception as e:
|
||||
session['status'] = 'error'
|
||||
session['error'] = str(e)
|
||||
|
||||
task = asyncio.create_task(run_registration())
|
||||
session['task'] = task
|
||||
|
||||
# Wait for QR code to be ready (max 10 seconds)
|
||||
for _ in range(20):
|
||||
if session['qr_url']:
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
if not session['qr_url']:
|
||||
task.cancel()
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'Timeout waiting for QR code'
|
||||
return self.http_status(504, -1, 'Timeout waiting for QR code')
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'session_id': session_id,
|
||||
'qr_url': session['qr_url'],
|
||||
'expire_at': session['expire_at'],
|
||||
}
|
||||
)
|
||||
|
||||
@self.route('/lark/create-app/status/<session_id>', methods=['GET'])
|
||||
async def _(session_id: str) -> str:
|
||||
"""Poll registration status."""
|
||||
session = _create_app_sessions.get(session_id)
|
||||
if not session:
|
||||
return self.http_status(404, -1, 'Session not found')
|
||||
|
||||
data = {'status': session['status']}
|
||||
|
||||
if session['status'] == 'success':
|
||||
data['app_id'] = session['app_id']
|
||||
data['app_secret'] = session['app_secret']
|
||||
_create_app_sessions.pop(session_id, None)
|
||||
elif session['status'] == 'error':
|
||||
data['error'] = session['error']
|
||||
_create_app_sessions.pop(session_id, None)
|
||||
|
||||
return self.success(data=data)
|
||||
|
||||
@self.route('/lark/create-app/<session_id>', methods=['DELETE'])
|
||||
async def _(session_id: str) -> str:
|
||||
"""Cancel and clean up a registration session."""
|
||||
session = _create_app_sessions.pop(session_id, None)
|
||||
if session and session.get('task') and not session['task'].done():
|
||||
session['task'].cancel()
|
||||
return self.success(data={})
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# WeChat QR Code Login
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
_weixin_login_sessions: dict = {}
|
||||
_WEIXIN_SESSION_TTL = 600 # 10 minutes (3 retries × 3 min QR validity)
|
||||
|
||||
def _cleanup_expired_weixin_sessions():
|
||||
import time
|
||||
|
||||
now = time.time()
|
||||
expired = [
|
||||
sid for sid, s in _weixin_login_sessions.items() if now - s.get('created_at', 0) > _WEIXIN_SESSION_TTL
|
||||
]
|
||||
for sid in expired:
|
||||
session = _weixin_login_sessions.pop(sid, None)
|
||||
if session and session.get('task') and not session['task'].done():
|
||||
session['task'].cancel()
|
||||
|
||||
@self.route('/weixin/login', methods=['POST'])
|
||||
async def _() -> str:
|
||||
"""Start WeChat QR code login. Returns session_id + QR code data URL."""
|
||||
import uuid
|
||||
import time
|
||||
import io
|
||||
import base64
|
||||
|
||||
from langbot.libs.openclaw_weixin_api.client import OpenClawWeixinClient, DEFAULT_BASE_URL
|
||||
|
||||
_cleanup_expired_weixin_sessions()
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
session = {
|
||||
'status': 'pending',
|
||||
'qr_data_url': None,
|
||||
'expire_at': None,
|
||||
'token': None,
|
||||
'base_url': None,
|
||||
'account_id': None,
|
||||
'error': None,
|
||||
'created_at': time.time(),
|
||||
}
|
||||
_weixin_login_sessions[session_id] = session
|
||||
|
||||
client = OpenClawWeixinClient(
|
||||
base_url=DEFAULT_BASE_URL,
|
||||
token='',
|
||||
)
|
||||
|
||||
async def run_login():
|
||||
try:
|
||||
import qrcode as qr_lib
|
||||
|
||||
for _attempt in range(3):
|
||||
qr_resp = await client.fetch_qrcode()
|
||||
if not qr_resp.qrcode or not qr_resp.qrcode_img_content:
|
||||
raise Exception('Failed to get QR code from server')
|
||||
|
||||
# Generate QR code image locally
|
||||
qr = qr_lib.QRCode(error_correction=qr_lib.constants.ERROR_CORRECT_L)
|
||||
qr.add_data(qr_resp.qrcode_img_content)
|
||||
qr.make(fit=True)
|
||||
img = qr.make_image(fill_color='black', back_color='white')
|
||||
buf = io.BytesIO()
|
||||
img.save(buf, format='PNG')
|
||||
b64 = base64.b64encode(buf.getvalue()).decode('utf-8')
|
||||
data_url = f'data:image/png;base64,{b64}'
|
||||
|
||||
def _update_qr():
|
||||
session['qr_data_url'] = data_url
|
||||
session['expire_at'] = time.time() + 480 # 8 minutes
|
||||
session['status'] = 'waiting'
|
||||
|
||||
loop.call_soon_threadsafe(_update_qr)
|
||||
|
||||
# Poll for scan status
|
||||
deadline = loop.time() + 180
|
||||
while loop.time() < deadline:
|
||||
try:
|
||||
status_resp = await client.poll_qrcode_status(qr_resp.qrcode)
|
||||
except Exception:
|
||||
await asyncio.sleep(2)
|
||||
continue
|
||||
|
||||
if status_resp.status == 'confirmed' and status_resp.bot_token:
|
||||
session['status'] = 'success'
|
||||
session['token'] = status_resp.bot_token
|
||||
session['base_url'] = status_resp.baseurl or client.base_url
|
||||
session['account_id'] = status_resp.ilink_bot_id or ''
|
||||
return
|
||||
|
||||
if status_resp.status == 'expired':
|
||||
break # retry with new QR code
|
||||
|
||||
await asyncio.sleep(1)
|
||||
else:
|
||||
pass # timeout, retry
|
||||
|
||||
# All retries exhausted
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'QR code login failed: max retries exceeded'
|
||||
|
||||
except Exception as e:
|
||||
session['status'] = 'error'
|
||||
session['error'] = str(e)
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
task = asyncio.create_task(run_login())
|
||||
session['task'] = task
|
||||
|
||||
# Wait for QR code to be ready (max 10 seconds)
|
||||
for _ in range(20):
|
||||
if session['qr_data_url']:
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
if not session['qr_data_url']:
|
||||
task.cancel()
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'Timeout waiting for QR code'
|
||||
return self.http_status(504, -1, 'Timeout waiting for QR code')
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'session_id': session_id,
|
||||
'qr_data_url': session['qr_data_url'],
|
||||
'expire_at': session['expire_at'],
|
||||
}
|
||||
)
|
||||
|
||||
@self.route('/weixin/login/status/<session_id>', methods=['GET'])
|
||||
async def _(session_id: str) -> str:
|
||||
"""Poll WeChat login status."""
|
||||
session = _weixin_login_sessions.get(session_id)
|
||||
if not session:
|
||||
return self.http_status(404, -1, 'Session not found')
|
||||
|
||||
data = {'status': session['status']}
|
||||
|
||||
if session['status'] == 'success':
|
||||
data['token'] = session['token']
|
||||
data['base_url'] = session['base_url']
|
||||
data['account_id'] = session['account_id']
|
||||
_weixin_login_sessions.pop(session_id, None)
|
||||
elif session['status'] == 'error':
|
||||
data['error'] = session['error']
|
||||
_weixin_login_sessions.pop(session_id, None)
|
||||
|
||||
return self.success(data=data)
|
||||
|
||||
@self.route('/weixin/login/<session_id>', methods=['DELETE'])
|
||||
async def _(session_id: str) -> str:
|
||||
"""Cancel and clean up a WeChat login session."""
|
||||
session = _weixin_login_sessions.pop(session_id, None)
|
||||
if session and session.get('task') and not session['task'].done():
|
||||
session['task'].cancel()
|
||||
return self.success(data={})
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# DingTalk Device Flow QR Code Login
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
_dingtalk_sessions: dict = {}
|
||||
_DINGTALK_SESSION_TTL = 600 # 10 minutes (QR code validity window)
|
||||
|
||||
def _cleanup_expired_dingtalk_sessions():
|
||||
import time
|
||||
|
||||
now = time.time()
|
||||
expired = [
|
||||
sid for sid, s in _dingtalk_sessions.items() if now - s.get('created_at', 0) > _DINGTALK_SESSION_TTL
|
||||
]
|
||||
for sid in expired:
|
||||
session = _dingtalk_sessions.pop(sid, None)
|
||||
if session and session.get('task') and not session['task'].done():
|
||||
session['task'].cancel()
|
||||
|
||||
@self.route('/dingtalk/create-app', methods=['POST'])
|
||||
async def _() -> str:
|
||||
"""Start DingTalk one-click app creation via Device Flow. Returns session_id + QR code URL."""
|
||||
import uuid
|
||||
import time
|
||||
import aiohttp
|
||||
|
||||
DINGTALK_BASE_URL = 'https://oapi.dingtalk.com'
|
||||
|
||||
_cleanup_expired_dingtalk_sessions()
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
session = {
|
||||
'status': 'pending',
|
||||
'qr_url': None,
|
||||
'expire_at': None,
|
||||
'client_id': None,
|
||||
'client_secret': None,
|
||||
'error': None,
|
||||
'created_at': time.time(),
|
||||
'device_code': None,
|
||||
'interval': 5,
|
||||
}
|
||||
_dingtalk_sessions[session_id] = session
|
||||
|
||||
async def run_device_flow():
|
||||
try:
|
||||
timeout = aiohttp.ClientTimeout(total=10)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as http:
|
||||
# Step 1: Init — get nonce
|
||||
async with http.post(
|
||||
f'{DINGTALK_BASE_URL}/app/registration/init',
|
||||
json={'source': 'langbot'},
|
||||
) as resp:
|
||||
try:
|
||||
data = await resp.json()
|
||||
except (aiohttp.ContentTypeError, ValueError):
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'Invalid response from DingTalk service'
|
||||
return
|
||||
if data.get('errcode', -1) != 0:
|
||||
session['status'] = 'error'
|
||||
session['error'] = data.get('errmsg', 'Failed to init')
|
||||
return
|
||||
nonce = data['nonce']
|
||||
|
||||
# Step 2: Begin — get device_code + QR URL
|
||||
async with http.post(
|
||||
f'{DINGTALK_BASE_URL}/app/registration/begin',
|
||||
json={'nonce': nonce},
|
||||
) as resp:
|
||||
try:
|
||||
data = await resp.json()
|
||||
except (aiohttp.ContentTypeError, ValueError):
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'Invalid response from DingTalk service'
|
||||
return
|
||||
if data.get('errcode', -1) != 0:
|
||||
session['status'] = 'error'
|
||||
session['error'] = data.get('errmsg', 'Failed to begin authorization')
|
||||
return
|
||||
|
||||
device_code = data['device_code']
|
||||
verification_uri_complete = data.get('verification_uri_complete', '')
|
||||
expires_in = data.get('expires_in', 7200)
|
||||
interval = data.get('interval', 5)
|
||||
|
||||
session['device_code'] = device_code
|
||||
session['interval'] = interval
|
||||
session['qr_url'] = verification_uri_complete
|
||||
session['expire_at'] = time.time() + 600 # QR code valid for ~10 min
|
||||
session['status'] = 'waiting'
|
||||
|
||||
# Step 3: Poll for authorization result
|
||||
deadline = time.time() + expires_in
|
||||
while time.time() < deadline:
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
async with http.post(
|
||||
f'{DINGTALK_BASE_URL}/app/registration/poll',
|
||||
json={'device_code': device_code},
|
||||
) as poll_resp:
|
||||
try:
|
||||
poll_data = await poll_resp.json()
|
||||
except (aiohttp.ContentTypeError, ValueError):
|
||||
continue
|
||||
|
||||
if poll_data.get('errcode', -1) != 0:
|
||||
session['status'] = 'error'
|
||||
session['error'] = poll_data.get('errmsg', 'Poll failed')
|
||||
return
|
||||
|
||||
status = poll_data.get('status', '')
|
||||
|
||||
if status == 'SUCCESS':
|
||||
session['status'] = 'success'
|
||||
session['client_id'] = poll_data.get('client_id', '')
|
||||
session['client_secret'] = poll_data.get('client_secret', '')
|
||||
return
|
||||
elif status == 'FAIL':
|
||||
session['status'] = 'error'
|
||||
session['error'] = poll_data.get('fail_reason', 'Authorization failed')
|
||||
return
|
||||
elif status == 'EXPIRED':
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'QR code expired'
|
||||
return
|
||||
# status == 'WAITING': continue polling
|
||||
|
||||
# Timeout
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'QR code expired'
|
||||
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
session['status'] = 'error'
|
||||
session['error'] = str(e)
|
||||
|
||||
task = asyncio.create_task(run_device_flow())
|
||||
session['task'] = task
|
||||
|
||||
# Wait for QR code to be ready (max 10 seconds)
|
||||
for _ in range(20):
|
||||
if session['qr_url'] or session['error']:
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
if session['error']:
|
||||
task.cancel()
|
||||
return self.http_status(502, -1, session['error'])
|
||||
|
||||
if not session['qr_url']:
|
||||
task.cancel()
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'Timeout waiting for QR code'
|
||||
return self.http_status(504, -1, 'Timeout waiting for QR code')
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'session_id': session_id,
|
||||
'qr_url': session['qr_url'],
|
||||
'expire_at': session['expire_at'],
|
||||
}
|
||||
)
|
||||
|
||||
@self.route('/dingtalk/create-app/status/<session_id>', methods=['GET'])
|
||||
async def _(session_id: str) -> str:
|
||||
"""Poll DingTalk Device Flow status."""
|
||||
_cleanup_expired_dingtalk_sessions()
|
||||
session = _dingtalk_sessions.get(session_id)
|
||||
if not session:
|
||||
return self.http_status(404, -1, 'Session not found')
|
||||
|
||||
data = {'status': session['status']}
|
||||
|
||||
if session['status'] == 'success':
|
||||
data['client_id'] = session['client_id']
|
||||
data['client_secret'] = session['client_secret']
|
||||
_dingtalk_sessions.pop(session_id, None)
|
||||
elif session['status'] == 'error':
|
||||
data['error'] = session['error']
|
||||
_dingtalk_sessions.pop(session_id, None)
|
||||
|
||||
return self.success(data=data)
|
||||
|
||||
@self.route('/dingtalk/create-app/<session_id>', methods=['DELETE'])
|
||||
async def _(session_id: str) -> str:
|
||||
"""Cancel and clean up a DingTalk Device Flow session."""
|
||||
session = _dingtalk_sessions.pop(session_id, None)
|
||||
if session and session.get('task') and not session['task'].done():
|
||||
session['task'].cancel()
|
||||
return self.success(data={})
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# WeComBot QR Code One-Click Create
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
_wecombot_sessions: dict = {}
|
||||
_WECOMBOT_SESSION_TTL = 300 # 5 minutes (WeCom QR validity window)
|
||||
|
||||
def _cleanup_expired_wecombot_sessions():
|
||||
import time
|
||||
|
||||
now = time.time()
|
||||
expired = [
|
||||
sid for sid, s in _wecombot_sessions.items() if now - s.get('created_at', 0) > _WECOMBOT_SESSION_TTL
|
||||
]
|
||||
for sid in expired:
|
||||
session = _wecombot_sessions.pop(sid, None)
|
||||
if session and session.get('task') and not session['task'].done():
|
||||
session['task'].cancel()
|
||||
|
||||
@self.route('/wecombot/create-bot', methods=['POST'])
|
||||
async def _() -> str:
|
||||
"""Start WeComBot one-click creation via QR code. Returns session_id + QR code URL."""
|
||||
import uuid
|
||||
import time
|
||||
import aiohttp
|
||||
|
||||
WECOM_QC_GENERATE_URL = 'https://work.weixin.qq.com/ai/qc/generate'
|
||||
WECOM_QC_QUERY_URL = 'https://work.weixin.qq.com/ai/qc/query_result'
|
||||
|
||||
_cleanup_expired_wecombot_sessions()
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
session = {
|
||||
'status': 'pending',
|
||||
'qr_url': None,
|
||||
'expire_at': None,
|
||||
'botid': None,
|
||||
'secret': None,
|
||||
'error': None,
|
||||
'created_at': time.time(),
|
||||
'scode': None,
|
||||
'task': None,
|
||||
}
|
||||
_wecombot_sessions[session_id] = session
|
||||
|
||||
async def run_qr_flow():
|
||||
try:
|
||||
timeout = aiohttp.ClientTimeout(total=10)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as http:
|
||||
# Step 1: Generate QR code
|
||||
async with http.get(
|
||||
f'{WECOM_QC_GENERATE_URL}?source=langbot&plat=0',
|
||||
) as resp:
|
||||
try:
|
||||
data = await resp.json()
|
||||
except (aiohttp.ContentTypeError, ValueError):
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'Invalid response from WeCom service'
|
||||
return
|
||||
if not data.get('data', {}).get('scode') or not data.get('data', {}).get('auth_url'):
|
||||
session['status'] = 'error'
|
||||
session['error'] = data.get('errmsg', 'Failed to generate QR code')
|
||||
return
|
||||
|
||||
scode = data['data']['scode']
|
||||
auth_url = data['data']['auth_url']
|
||||
|
||||
session['scode'] = scode
|
||||
session['qr_url'] = auth_url
|
||||
session['expire_at'] = time.time() + _WECOMBOT_SESSION_TTL
|
||||
session['status'] = 'waiting'
|
||||
|
||||
# Step 2: Poll for scan result
|
||||
deadline = time.time() + _WECOMBOT_SESSION_TTL
|
||||
while time.time() < deadline:
|
||||
await asyncio.sleep(3)
|
||||
|
||||
async with http.get(
|
||||
f'{WECOM_QC_QUERY_URL}?scode={scode}',
|
||||
) as poll_resp:
|
||||
try:
|
||||
poll_data = await poll_resp.json()
|
||||
except (aiohttp.ContentTypeError, ValueError):
|
||||
continue
|
||||
|
||||
status = poll_data.get('data', {}).get('status', '')
|
||||
if status == 'success':
|
||||
bot_info = poll_data.get('data', {}).get('bot_info', {})
|
||||
if bot_info.get('botid') and bot_info.get('secret'):
|
||||
session['status'] = 'success'
|
||||
session['botid'] = bot_info['botid']
|
||||
session['secret'] = bot_info['secret']
|
||||
return
|
||||
else:
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'Scan succeeded but bot info is incomplete'
|
||||
return
|
||||
|
||||
# Timeout
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'QR code expired'
|
||||
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
session['status'] = 'error'
|
||||
session['error'] = str(e)
|
||||
|
||||
task = asyncio.create_task(run_qr_flow())
|
||||
session['task'] = task
|
||||
|
||||
# Wait for QR code to be ready (max 10 seconds)
|
||||
for _ in range(20):
|
||||
if session['qr_url'] or session['error']:
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
if session['error']:
|
||||
task.cancel()
|
||||
return self.http_status(502, -1, session['error'])
|
||||
|
||||
if not session['qr_url']:
|
||||
task.cancel()
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'Timeout waiting for QR code'
|
||||
return self.http_status(504, -1, 'Timeout waiting for QR code')
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'session_id': session_id,
|
||||
'qr_url': session['qr_url'],
|
||||
'expire_at': session['expire_at'],
|
||||
}
|
||||
)
|
||||
|
||||
@self.route('/wecombot/create-bot/status/<session_id>', methods=['GET'])
|
||||
async def _(session_id: str) -> str:
|
||||
"""Poll WeComBot creation status."""
|
||||
_cleanup_expired_wecombot_sessions()
|
||||
session = _wecombot_sessions.get(session_id)
|
||||
if not session:
|
||||
return self.http_status(404, -1, 'Session not found')
|
||||
|
||||
data = {'status': session['status']}
|
||||
|
||||
if session['status'] == 'success':
|
||||
data['botid'] = session['botid']
|
||||
data['secret'] = session['secret']
|
||||
_wecombot_sessions.pop(session_id, None)
|
||||
elif session['status'] == 'error':
|
||||
data['error'] = session['error']
|
||||
_wecombot_sessions.pop(session_id, None)
|
||||
|
||||
return self.success(data=data)
|
||||
|
||||
@self.route('/wecombot/create-bot/<session_id>', methods=['DELETE'])
|
||||
async def _(session_id: str) -> str:
|
||||
"""Cancel and clean up a WeComBot creation session."""
|
||||
session = _wecombot_sessions.pop(session_id, None)
|
||||
if session and session.get('task') and not session['task'].done():
|
||||
session['task'].cancel()
|
||||
return self.success(data={})
|
||||
|
||||
@@ -6,48 +6,11 @@ import re
|
||||
import httpx
|
||||
import uuid
|
||||
import os
|
||||
import posixpath
|
||||
|
||||
from .....core import taskmgr
|
||||
from .. import group
|
||||
from langbot_plugin.runtime.plugin.mgr import PluginInstallSource
|
||||
|
||||
# Resolve the built-in page SDK JS from the langbot_plugin package
|
||||
_PAGE_SDK_PATH = None
|
||||
try:
|
||||
import langbot_plugin.assets as _assets_pkg
|
||||
|
||||
_candidate = os.path.join(os.path.dirname(_assets_pkg.__file__), 'langbot-page-sdk.js')
|
||||
if os.path.exists(_candidate):
|
||||
_PAGE_SDK_PATH = _candidate
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _normalize_plugin_asset_path(filepath: str) -> str | None:
|
||||
filepath = filepath.replace('\\', '/')
|
||||
if filepath.startswith('/'):
|
||||
return None
|
||||
|
||||
normalized = posixpath.normpath(filepath)
|
||||
if normalized == '.' or normalized.startswith('../') or normalized == '..':
|
||||
return None
|
||||
|
||||
if normalized.startswith('components/pages/'):
|
||||
return normalized
|
||||
|
||||
return f'assets/{normalized}'
|
||||
|
||||
|
||||
def _get_request_origin() -> str:
|
||||
"""Return the public request origin, respecting reverse-proxy headers."""
|
||||
forwarded_proto = quart.request.headers.get('X-Forwarded-Proto', '').split(',')[0].strip()
|
||||
forwarded_host = quart.request.headers.get('X-Forwarded-Host', '').split(',')[0].strip()
|
||||
|
||||
scheme = forwarded_proto or quart.request.scheme
|
||||
host = forwarded_host or quart.request.host
|
||||
return f'{scheme}://{host}'
|
||||
|
||||
|
||||
@group.group_class('plugins', '/api/v1/plugins')
|
||||
class PluginsRouterGroup(group.RouterGroup):
|
||||
@@ -64,15 +27,6 @@ class PluginsRouterGroup(group.RouterGroup):
|
||||
return None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
@self.route('/_sdk/page-sdk.js', methods=['GET'], auth_type=group.AuthType.NONE)
|
||||
async def _() -> quart.Response:
|
||||
"""Serve the built-in LangBot page SDK JavaScript."""
|
||||
if _PAGE_SDK_PATH and os.path.exists(_PAGE_SDK_PATH):
|
||||
with open(_PAGE_SDK_PATH, 'r') as f:
|
||||
content = f.read()
|
||||
return quart.Response(content, mimetype='application/javascript')
|
||||
return quart.Response('// SDK not found', status=404, mimetype='application/javascript')
|
||||
|
||||
@self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _() -> str:
|
||||
plugins = await self.ap.plugin_connector.list_plugins()
|
||||
@@ -181,62 +135,15 @@ class PluginsRouterGroup(group.RouterGroup):
|
||||
return quart.Response(icon_data, mimetype=mime_type)
|
||||
|
||||
@self.route(
|
||||
'/<author>/<plugin_name>/assets/<path:filepath>',
|
||||
'/<author>/<plugin_name>/assets/<filepath>',
|
||||
methods=['GET'],
|
||||
auth_type=group.AuthType.NONE,
|
||||
)
|
||||
async def _(author: str, plugin_name: str, filepath: str) -> quart.Response:
|
||||
asset_path = _normalize_plugin_asset_path(filepath)
|
||||
if asset_path is None:
|
||||
return quart.Response('Asset not found', status=404)
|
||||
|
||||
asset_data = await self.ap.plugin_connector.get_plugin_assets(author, plugin_name, asset_path)
|
||||
if not asset_data.get('asset_base64'):
|
||||
return quart.Response('Asset not found', status=404)
|
||||
asset_data = await self.ap.plugin_connector.get_plugin_assets(author, plugin_name, filepath)
|
||||
asset_bytes = base64.b64decode(asset_data['asset_base64'])
|
||||
mime_type = asset_data['mime_type']
|
||||
resp = quart.Response(asset_bytes, mimetype=mime_type)
|
||||
# CSP for HTML pages served to sandboxed iframes (opaque origin).
|
||||
# 'self' doesn't work in sandboxed iframes — use actual server origin.
|
||||
if mime_type and mime_type.startswith('text/html'):
|
||||
origin = _get_request_origin()
|
||||
resp.headers['Content-Security-Policy'] = (
|
||||
f'default-src {origin}; '
|
||||
f"script-src {origin} 'unsafe-inline'; "
|
||||
f"style-src {origin} 'unsafe-inline'; "
|
||||
f'img-src {origin} data:; '
|
||||
f'connect-src {origin}; '
|
||||
"frame-src 'none'; "
|
||||
"object-src 'none'"
|
||||
)
|
||||
return resp
|
||||
|
||||
@self.route(
|
||||
'/<author>/<plugin_name>/page-api',
|
||||
methods=['POST'],
|
||||
auth_type=group.AuthType.USER_TOKEN_OR_API_KEY,
|
||||
)
|
||||
async def _(author: str, plugin_name: str) -> str:
|
||||
"""Forward a page API request to the plugin."""
|
||||
data = await quart.request.json
|
||||
if not isinstance(data, dict):
|
||||
return self.http_status(400, -1, 'invalid request body')
|
||||
|
||||
page_id = data.get('page_id', '')
|
||||
endpoint = data.get('endpoint', '')
|
||||
method = data.get('method', 'POST')
|
||||
body = data.get('body')
|
||||
if not isinstance(page_id, str) or not isinstance(endpoint, str) or not isinstance(method, str):
|
||||
return self.http_status(400, -1, 'invalid page api request')
|
||||
if not endpoint.startswith('/') or '..' in endpoint:
|
||||
return self.http_status(400, -1, 'invalid endpoint')
|
||||
|
||||
result = await self.ap.plugin_connector.handle_page_api(
|
||||
author, plugin_name, page_id, endpoint, method.upper(), body
|
||||
)
|
||||
if result.get('error'):
|
||||
return self.http_status(400, -1, result['error'])
|
||||
return self.success(data=result.get('data'))
|
||||
return quart.Response(asset_bytes, mimetype=mime_type)
|
||||
|
||||
@self.route('/github/releases', methods=['POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _() -> str:
|
||||
|
||||
@@ -97,51 +97,3 @@ class EmbeddingModelsRouterGroup(group.RouterGroup):
|
||||
await self.ap.embedding_models_service.test_embedding_model(model_uuid, json_data)
|
||||
|
||||
return self.success()
|
||||
|
||||
|
||||
@group.group_class('models/rerank', '/api/v1/provider/models/rerank')
|
||||
class RerankModelsRouterGroup(group.RouterGroup):
|
||||
async def initialize(self) -> None:
|
||||
@self.route('', methods=['GET', 'POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _() -> str:
|
||||
if quart.request.method == 'GET':
|
||||
provider_uuid = quart.request.args.get('provider_uuid')
|
||||
if provider_uuid:
|
||||
return self.success(
|
||||
data={
|
||||
'models': await self.ap.rerank_models_service.get_rerank_models_by_provider(provider_uuid)
|
||||
}
|
||||
)
|
||||
return self.success(data={'models': await self.ap.rerank_models_service.get_rerank_models()})
|
||||
elif quart.request.method == 'POST':
|
||||
json_data = await quart.request.json
|
||||
model_uuid = await self.ap.rerank_models_service.create_rerank_model(json_data)
|
||||
return self.success(data={'uuid': model_uuid})
|
||||
|
||||
@self.route('/<model_uuid>', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _(model_uuid: str) -> str:
|
||||
if quart.request.method == 'GET':
|
||||
model = await self.ap.rerank_models_service.get_rerank_model(model_uuid)
|
||||
|
||||
if model is None:
|
||||
return self.http_status(404, -1, 'model not found')
|
||||
|
||||
return self.success(data={'model': model})
|
||||
elif quart.request.method == 'PUT':
|
||||
json_data = await quart.request.json
|
||||
|
||||
await self.ap.rerank_models_service.update_rerank_model(model_uuid, json_data)
|
||||
|
||||
return self.success()
|
||||
elif quart.request.method == 'DELETE':
|
||||
await self.ap.rerank_models_service.delete_rerank_model(model_uuid)
|
||||
|
||||
return self.success()
|
||||
|
||||
@self.route('/<model_uuid>/test', methods=['POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _(model_uuid: str) -> str:
|
||||
json_data = await quart.request.json
|
||||
|
||||
await self.ap.rerank_models_service.test_rerank_model(model_uuid, json_data)
|
||||
|
||||
return self.success()
|
||||
|
||||
@@ -15,7 +15,6 @@ class ModelProvidersRouterGroup(group.RouterGroup):
|
||||
counts = await self.ap.provider_service.get_provider_model_counts(provider['uuid'])
|
||||
provider['llm_count'] = counts['llm_count']
|
||||
provider['embedding_count'] = counts['embedding_count']
|
||||
provider['rerank_count'] = counts['rerank_count']
|
||||
return self.success(data={'providers': providers})
|
||||
elif quart.request.method == 'POST':
|
||||
json_data = await quart.request.json
|
||||
@@ -33,7 +32,6 @@ class ModelProvidersRouterGroup(group.RouterGroup):
|
||||
counts = await self.ap.provider_service.get_provider_model_counts(provider_uuid)
|
||||
provider['llm_count'] = counts['llm_count']
|
||||
provider['embedding_count'] = counts['embedding_count']
|
||||
provider['rerank_count'] = counts['rerank_count']
|
||||
return self.success(data={'provider': provider})
|
||||
elif quart.request.method == 'PUT':
|
||||
json_data = await quart.request.json
|
||||
@@ -45,12 +43,3 @@ class ModelProvidersRouterGroup(group.RouterGroup):
|
||||
return self.success()
|
||||
except ValueError as e:
|
||||
return self.http_status(400, -1, str(e))
|
||||
|
||||
@self.route('/<provider_uuid>/scan-models', methods=['GET'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _(provider_uuid: str) -> str:
|
||||
try:
|
||||
model_type = quart.request.args.get('type')
|
||||
result = await self.ap.provider_service.scan_provider_models(provider_uuid, model_type)
|
||||
return self.success(data=result)
|
||||
except ValueError as e:
|
||||
return self.http_status(400, -1, str(e))
|
||||
|
||||
@@ -136,10 +136,6 @@ class SystemRouterGroup(group.RouterGroup):
|
||||
|
||||
return self.success(data=task.to_dict())
|
||||
|
||||
@self.route('/storage-analysis', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
return self.success(data=await self.ap.maintenance_service.get_storage_analysis())
|
||||
|
||||
@self.route('/debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
if not constants.debug_mode:
|
||||
|
||||
@@ -146,7 +146,6 @@ class UserRouterGroup(group.RouterGroup):
|
||||
return self.fail(3, str(e))
|
||||
except ValueError as e:
|
||||
traceback.print_exc()
|
||||
self.ap.logger.warning(f'Space OAuth callback failed: {e}')
|
||||
return self.fail(1, str(e))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
||||
@@ -105,24 +105,23 @@ class HTTPController:
|
||||
):
|
||||
if os.path.exists(os.path.join(frontend_path, path + '.html')):
|
||||
path += '.html'
|
||||
elif not path.startswith('api/'):
|
||||
# SPA fallback: serve index.html for all non-API, non-static routes
|
||||
# so that React Router can handle client-side routing (Vite SPA).
|
||||
# For /home/* sub-routes, first try parent .html files (pre-rendered pages).
|
||||
if path.startswith('home/'):
|
||||
segments = path.rstrip('/').split('/')
|
||||
for i in range(len(segments) - 1, 0, -1):
|
||||
parent_path = '/'.join(segments[:i]) + '.html'
|
||||
if os.path.exists(os.path.join(frontend_path, parent_path)):
|
||||
response = await quart.send_from_directory(
|
||||
frontend_path, parent_path, mimetype='text/html'
|
||||
)
|
||||
response.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate'
|
||||
response.headers['Pragma'] = 'no-cache'
|
||||
response.headers['Expires'] = '0'
|
||||
return response
|
||||
elif path.startswith('home/'):
|
||||
# SPA fallback for /home/* sub-routes.
|
||||
# Entity detail views use query params (e.g. /home/bots?id=uuid),
|
||||
# so the pre-rendered list page is served directly via path + '.html'.
|
||||
# This fallback handles any remaining unmatched sub-paths.
|
||||
segments = path.rstrip('/').split('/')
|
||||
|
||||
# Fallback to index.html for SPA client-side routing
|
||||
# Walk up parent segments looking for matching .html files
|
||||
for i in range(len(segments) - 1, 0, -1):
|
||||
parent_path = '/'.join(segments[:i]) + '.html'
|
||||
if os.path.exists(os.path.join(frontend_path, parent_path)):
|
||||
response = await quart.send_from_directory(frontend_path, parent_path, mimetype='text/html')
|
||||
response.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate'
|
||||
response.headers['Pragma'] = 'no-cache'
|
||||
response.headers['Expires'] = '0'
|
||||
return response
|
||||
# Final fallback to index.html for /home/* routes
|
||||
response = await quart.send_from_directory(frontend_path, 'index.html', mimetype='text/html')
|
||||
response.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate'
|
||||
response.headers['Pragma'] = 'no-cache'
|
||||
|
||||
@@ -99,11 +99,11 @@ class BotService:
|
||||
# TODO: 检查配置信息格式
|
||||
bot_data['uuid'] = str(uuid.uuid4())
|
||||
|
||||
# bind the most recently updated pipeline if any exist
|
||||
# checkout the default pipeline
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_pipeline.LegacyPipeline)
|
||||
.order_by(persistence_pipeline.LegacyPipeline.updated_at.desc())
|
||||
.limit(1)
|
||||
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
|
||||
persistence_pipeline.LegacyPipeline.is_default == True
|
||||
)
|
||||
)
|
||||
pipeline = result.first()
|
||||
if pipeline is not None:
|
||||
|
||||
314
src/langbot/pkg/api/http/service/human_takeover.py
Normal file
@@ -0,0 +1,314 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
from ....core import app
|
||||
from ....entity.persistence import human_takeover as persistence_human_takeover
|
||||
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
|
||||
|
||||
class HumanTakeoverService:
|
||||
"""Human takeover service.
|
||||
|
||||
Manages operator takeover of user conversation sessions, bypassing
|
||||
the normal AI pipeline. Uses an in-memory cache for fast synchronous
|
||||
lookups on the hot message path, backed by database persistence.
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
# In-memory cache: session_id -> HumanTakeoverSession record id
|
||||
# Only contains sessions with status='active'
|
||||
_active_sessions: dict[str, str]
|
||||
|
||||
logger: logging.Logger
|
||||
|
||||
def __init__(self, ap: app.Application) -> None:
|
||||
self.ap = ap
|
||||
self._active_sessions = {}
|
||||
self.logger = logging.getLogger('human-takeover')
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Load active takeover sessions from DB into memory cache."""
|
||||
try:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_human_takeover.HumanTakeoverSession).where(
|
||||
persistence_human_takeover.HumanTakeoverSession.status == 'active'
|
||||
)
|
||||
)
|
||||
rows = result.all()
|
||||
for row in rows:
|
||||
session = row[0] if isinstance(row, tuple) else row
|
||||
self._active_sessions[session.session_id] = session.id
|
||||
self.logger.info(f'Loaded {len(self._active_sessions)} active takeover sessions from DB')
|
||||
except Exception as e:
|
||||
self.logger.warning(f'Failed to load active takeover sessions: {e}')
|
||||
|
||||
def is_taken_over(self, session_id: str) -> bool:
|
||||
"""Check if a session is currently under human takeover.
|
||||
|
||||
This is a synchronous in-memory lookup for performance, since it
|
||||
is called on every incoming message (hot path).
|
||||
"""
|
||||
return session_id in self._active_sessions
|
||||
|
||||
async def takeover_session(
|
||||
self,
|
||||
session_id: str,
|
||||
bot_uuid: str,
|
||||
taken_by: str | None = None,
|
||||
platform: str | None = None,
|
||||
user_id: str | None = None,
|
||||
user_name: str | None = None,
|
||||
) -> dict:
|
||||
"""Take over a conversation session.
|
||||
|
||||
Args:
|
||||
session_id: The session to take over (e.g. 'person_123' or 'group_456').
|
||||
bot_uuid: UUID of the bot whose session is being taken over.
|
||||
taken_by: Email/username of the admin performing the takeover.
|
||||
platform: Platform name.
|
||||
user_id: The end-user's ID in the session.
|
||||
user_name: The end-user's display name.
|
||||
|
||||
Returns:
|
||||
Dict with the created takeover session record.
|
||||
|
||||
Raises:
|
||||
ValueError: If the session is already taken over.
|
||||
"""
|
||||
if self.is_taken_over(session_id):
|
||||
raise ValueError(f'Session {session_id} is already taken over')
|
||||
|
||||
record_id = str(uuid.uuid4())
|
||||
now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
|
||||
record_data = {
|
||||
'id': record_id,
|
||||
'session_id': session_id,
|
||||
'bot_uuid': bot_uuid,
|
||||
'status': 'active',
|
||||
'taken_by': taken_by,
|
||||
'taken_at': now,
|
||||
'released_at': None,
|
||||
'platform': platform,
|
||||
'user_id': user_id,
|
||||
'user_name': user_name,
|
||||
}
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(persistence_human_takeover.HumanTakeoverSession).values(record_data)
|
||||
)
|
||||
|
||||
# Update in-memory cache
|
||||
self._active_sessions[session_id] = record_id
|
||||
|
||||
self.logger.info(f'Session {session_id} taken over by {taken_by}')
|
||||
|
||||
return record_data
|
||||
|
||||
async def release_session(self, session_id: str) -> dict:
|
||||
"""Release a taken-over session back to AI pipeline processing.
|
||||
|
||||
Args:
|
||||
session_id: The session to release.
|
||||
|
||||
Returns:
|
||||
Dict with the updated takeover session record.
|
||||
|
||||
Raises:
|
||||
ValueError: If the session is not currently taken over.
|
||||
"""
|
||||
if not self.is_taken_over(session_id):
|
||||
raise ValueError(f'Session {session_id} is not currently taken over')
|
||||
|
||||
now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_human_takeover.HumanTakeoverSession)
|
||||
.where(
|
||||
sqlalchemy.and_(
|
||||
persistence_human_takeover.HumanTakeoverSession.session_id == session_id,
|
||||
persistence_human_takeover.HumanTakeoverSession.status == 'active',
|
||||
)
|
||||
)
|
||||
.values(status='released', released_at=now)
|
||||
)
|
||||
|
||||
# Remove from in-memory cache
|
||||
self._active_sessions.pop(session_id, None)
|
||||
|
||||
self.logger.info(f'Session {session_id} released back to AI pipeline')
|
||||
|
||||
return {
|
||||
'session_id': session_id,
|
||||
'status': 'released',
|
||||
'released_at': now.isoformat(),
|
||||
}
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
session_id: str,
|
||||
message_text: str,
|
||||
operator_name: str | None = None,
|
||||
) -> dict:
|
||||
"""Send a message from the operator to the user via the platform adapter.
|
||||
|
||||
Args:
|
||||
session_id: The taken-over session ID (e.g. 'person_123' or 'group_456').
|
||||
message_text: The text message to send.
|
||||
operator_name: Name of the operator sending the message.
|
||||
|
||||
Returns:
|
||||
Dict with send result info.
|
||||
|
||||
Raises:
|
||||
ValueError: If the session is not currently taken over.
|
||||
RuntimeError: If the bot or adapter cannot be found.
|
||||
"""
|
||||
if not self.is_taken_over(session_id):
|
||||
raise ValueError(f'Session {session_id} is not currently taken over')
|
||||
|
||||
# Look up the takeover record to get bot_uuid
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_human_takeover.HumanTakeoverSession).where(
|
||||
sqlalchemy.and_(
|
||||
persistence_human_takeover.HumanTakeoverSession.session_id == session_id,
|
||||
persistence_human_takeover.HumanTakeoverSession.status == 'active',
|
||||
)
|
||||
)
|
||||
)
|
||||
row = result.first()
|
||||
if not row:
|
||||
raise RuntimeError(f'Active takeover record not found for session {session_id}')
|
||||
|
||||
takeover_record = row[0] if isinstance(row, tuple) else row
|
||||
bot_uuid = takeover_record.bot_uuid
|
||||
|
||||
# Get the runtime bot
|
||||
runtime_bot = await self.ap.platform_mgr.get_bot_by_uuid(bot_uuid)
|
||||
if not runtime_bot:
|
||||
raise RuntimeError(f'Bot {bot_uuid} not found or not running')
|
||||
|
||||
# Parse session_id to determine target_type and target_id
|
||||
# Format: 'person_{id}' or 'group_{id}'
|
||||
if session_id.startswith('person_'):
|
||||
target_type = 'person'
|
||||
target_id = session_id[len('person_') :]
|
||||
elif session_id.startswith('group_'):
|
||||
target_type = 'group'
|
||||
target_id = session_id[len('group_') :]
|
||||
else:
|
||||
raise ValueError(f'Invalid session_id format: {session_id}')
|
||||
|
||||
# Build message chain
|
||||
message_chain = platform_message.MessageChain([platform_message.Plain(text=message_text)])
|
||||
|
||||
# Send via adapter
|
||||
await runtime_bot.adapter.send_message(target_type, target_id, message_chain)
|
||||
|
||||
# Record the operator message in monitoring
|
||||
bot_name = runtime_bot.bot_entity.name or bot_uuid
|
||||
try:
|
||||
message_content = json.dumps(message_chain.model_dump(), ensure_ascii=False)
|
||||
except Exception:
|
||||
message_content = message_text
|
||||
|
||||
await self.ap.monitoring_service.record_message(
|
||||
bot_id=bot_uuid,
|
||||
bot_name=bot_name,
|
||||
pipeline_id='__human_takeover__',
|
||||
pipeline_name='Human Takeover',
|
||||
message_content=message_content,
|
||||
session_id=session_id,
|
||||
status='success',
|
||||
level='info',
|
||||
platform=takeover_record.platform,
|
||||
user_id=operator_name or 'operator',
|
||||
user_name=operator_name or 'Operator',
|
||||
role='operator',
|
||||
)
|
||||
|
||||
self.logger.info(f'Operator message sent to session {session_id}: {message_text[:50]}...')
|
||||
|
||||
return {
|
||||
'session_id': session_id,
|
||||
'message_sent': True,
|
||||
}
|
||||
|
||||
async def get_active_sessions(
|
||||
self,
|
||||
bot_uuid: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[dict], int]:
|
||||
"""Get list of active (or all) takeover sessions.
|
||||
|
||||
Args:
|
||||
bot_uuid: Optional filter by bot UUID.
|
||||
limit: Maximum number of results.
|
||||
offset: Pagination offset.
|
||||
|
||||
Returns:
|
||||
Tuple of (list of session dicts, total count).
|
||||
"""
|
||||
conditions = []
|
||||
|
||||
if bot_uuid:
|
||||
conditions.append(persistence_human_takeover.HumanTakeoverSession.bot_uuid == bot_uuid)
|
||||
|
||||
# Count
|
||||
count_query = sqlalchemy.select(sqlalchemy.func.count(persistence_human_takeover.HumanTakeoverSession.id))
|
||||
if conditions:
|
||||
count_query = count_query.where(sqlalchemy.and_(*conditions))
|
||||
|
||||
count_result = await self.ap.persistence_mgr.execute_async(count_query)
|
||||
total = count_result.scalar() or 0
|
||||
|
||||
# Fetch records
|
||||
query = sqlalchemy.select(persistence_human_takeover.HumanTakeoverSession).order_by(
|
||||
persistence_human_takeover.HumanTakeoverSession.taken_at.desc()
|
||||
)
|
||||
if conditions:
|
||||
query = query.where(sqlalchemy.and_(*conditions))
|
||||
|
||||
query = query.limit(limit).offset(offset)
|
||||
|
||||
result = await self.ap.persistence_mgr.execute_async(query)
|
||||
rows = result.all()
|
||||
|
||||
sessions = []
|
||||
for row in rows:
|
||||
session = row[0] if isinstance(row, tuple) else row
|
||||
sessions.append(
|
||||
self.ap.persistence_mgr.serialize_model(persistence_human_takeover.HumanTakeoverSession, session)
|
||||
)
|
||||
|
||||
return sessions, total
|
||||
|
||||
async def get_session_detail(self, session_id: str) -> dict | None:
|
||||
"""Get detail for a specific takeover session.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to look up.
|
||||
|
||||
Returns:
|
||||
Session dict or None if not found.
|
||||
"""
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_human_takeover.HumanTakeoverSession)
|
||||
.where(persistence_human_takeover.HumanTakeoverSession.session_id == session_id)
|
||||
.order_by(persistence_human_takeover.HumanTakeoverSession.taken_at.desc())
|
||||
)
|
||||
row = result.first()
|
||||
if not row:
|
||||
return None
|
||||
|
||||
session = row[0] if isinstance(row, tuple) else row
|
||||
return self.ap.persistence_mgr.serialize_model(persistence_human_takeover.HumanTakeoverSession, session)
|
||||
@@ -31,126 +31,15 @@ class KnowledgeService:
|
||||
if not knowledge_engine_plugin_id:
|
||||
raise ValueError('knowledge_engine_plugin_id is required')
|
||||
|
||||
creation_settings = kb_data.get('creation_settings', {})
|
||||
retrieval_settings = kb_data.get('retrieval_settings', {})
|
||||
|
||||
# Validate required fields based on plugin's creation_schema and retrieval_schema
|
||||
await self._validate_schema_required_fields(
|
||||
knowledge_engine_plugin_id,
|
||||
creation_settings,
|
||||
retrieval_settings,
|
||||
)
|
||||
|
||||
kb = await self.ap.rag_mgr.create_knowledge_base(
|
||||
name=kb_data.get('name', 'Untitled'),
|
||||
knowledge_engine_plugin_id=knowledge_engine_plugin_id,
|
||||
creation_settings=creation_settings,
|
||||
retrieval_settings=retrieval_settings,
|
||||
creation_settings=kb_data.get('creation_settings', {}),
|
||||
retrieval_settings=kb_data.get('retrieval_settings', {}),
|
||||
description=kb_data.get('description', ''),
|
||||
)
|
||||
return kb.uuid
|
||||
|
||||
async def _validate_schema_required_fields(
|
||||
self,
|
||||
plugin_id: str,
|
||||
creation_settings: dict,
|
||||
retrieval_settings: dict,
|
||||
) -> None:
|
||||
"""Validate required fields based on plugin's creation_schema and retrieval_schema.
|
||||
|
||||
This is a business-agnostic validation that checks all fields marked as
|
||||
required in the plugin's schema, regardless of field type.
|
||||
|
||||
Args:
|
||||
plugin_id: Knowledge Engine plugin ID.
|
||||
creation_settings: User-provided creation settings.
|
||||
retrieval_settings: User-provided retrieval settings.
|
||||
|
||||
Raises:
|
||||
ValueError: If any required field is missing or empty.
|
||||
"""
|
||||
# Validate creation_schema
|
||||
try:
|
||||
creation_schema = await self.ap.plugin_connector.get_rag_creation_schema(plugin_id)
|
||||
self._check_required_fields(creation_schema, creation_settings, 'creation_settings')
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(f'Failed to get creation_schema for validation: {e}')
|
||||
|
||||
# Validate retrieval_schema
|
||||
try:
|
||||
retrieval_schema = await self.ap.plugin_connector.get_rag_retrieval_schema(plugin_id)
|
||||
self._check_required_fields(retrieval_schema, retrieval_settings, 'retrieval_settings')
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(f'Failed to get retrieval_schema for validation: {e}')
|
||||
|
||||
def _check_required_fields(
|
||||
self,
|
||||
schema: dict | list,
|
||||
settings: dict,
|
||||
context: str,
|
||||
) -> None:
|
||||
"""Check required fields in schema against provided settings.
|
||||
|
||||
Args:
|
||||
schema: Plugin-defined schema (can be list or dict with 'schema' key).
|
||||
settings: User-provided settings values.
|
||||
context: Context name for error messages (e.g., 'creation_settings').
|
||||
|
||||
Raises:
|
||||
ValueError: If a required field is missing or empty.
|
||||
"""
|
||||
if not schema:
|
||||
return
|
||||
|
||||
# schema can be a list directly, or a dict with 'schema' key
|
||||
items = schema if isinstance(schema, list) else schema.get('schema', [])
|
||||
if not items:
|
||||
return
|
||||
|
||||
for item in items:
|
||||
field_name = item.get('name')
|
||||
if not field_name:
|
||||
continue
|
||||
|
||||
is_required = item.get('required', False)
|
||||
if not is_required:
|
||||
continue
|
||||
|
||||
# Check show_if condition - if field is conditionally shown, only validate when condition is met
|
||||
show_if = item.get('show_if')
|
||||
if show_if:
|
||||
depend_field = show_if.get('field')
|
||||
operator = show_if.get('operator')
|
||||
expected_value = show_if.get('value')
|
||||
|
||||
if depend_field and operator:
|
||||
depend_value = settings.get(depend_field)
|
||||
# If show_if condition is not met, skip validation for this field
|
||||
if operator == 'eq' and depend_value != expected_value:
|
||||
continue
|
||||
if operator == 'neq' and depend_value == expected_value:
|
||||
continue
|
||||
if operator == 'in' and isinstance(expected_value, list) and depend_value not in expected_value:
|
||||
continue
|
||||
|
||||
value = settings.get(field_name)
|
||||
|
||||
# Validate required field has a non-empty value
|
||||
if value is None or (isinstance(value, str) and value.strip() == ''):
|
||||
# Get field label for friendly error message
|
||||
label = item.get('label', {})
|
||||
field_label = (
|
||||
label.get('en_US', field_name)
|
||||
or label.get('zh_Hans', field_name)
|
||||
or label.get('zh_Hant', field_name)
|
||||
or field_name
|
||||
)
|
||||
raise ValueError(f'{field_label} is required ({context}.{field_name})')
|
||||
|
||||
async def update_knowledge_base(self, kb_uuid: str, kb_data: dict) -> None:
|
||||
"""更新知识库"""
|
||||
# Filter to only mutable fields
|
||||
|
||||
@@ -1,309 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
from ....core import app
|
||||
from ....entity.persistence import bstorage as persistence_bstorage
|
||||
from ....entity.persistence import monitoring as persistence_monitoring
|
||||
|
||||
|
||||
LOG_FILE_PATTERN = re.compile(r'^langbot-(\d{4}-\d{2}-\d{2})\.log(?:\.\d+)?$')
|
||||
DEFAULT_UPLOAD_FILE_RETENTION_DAYS = 7
|
||||
DEFAULT_LOG_RETENTION_DAYS = 3
|
||||
|
||||
|
||||
class MaintenanceService:
|
||||
"""Storage maintenance and diagnostics."""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application) -> None:
|
||||
self.ap = ap
|
||||
|
||||
async def cleanup_expired_files(self) -> dict[str, int]:
|
||||
cleanup_cfg = self.ap.instance_config.data.get('storage', {}).get('cleanup', {})
|
||||
upload_retention_days = self._positive_int(
|
||||
cleanup_cfg.get('uploaded_file_retention_days'),
|
||||
DEFAULT_UPLOAD_FILE_RETENTION_DAYS,
|
||||
'storage.cleanup.uploaded_file_retention_days',
|
||||
)
|
||||
log_retention_days = self._positive_int(
|
||||
cleanup_cfg.get('log_retention_days'),
|
||||
DEFAULT_LOG_RETENTION_DAYS,
|
||||
'storage.cleanup.log_retention_days',
|
||||
)
|
||||
|
||||
return {
|
||||
'uploaded_files': await self._cleanup_expired_uploaded_files(upload_retention_days),
|
||||
'log_files': self._cleanup_expired_log_files(log_retention_days),
|
||||
}
|
||||
|
||||
async def get_storage_analysis(self) -> dict[str, Any]:
|
||||
cleanup_cfg = self.ap.instance_config.data.get('storage', {}).get('cleanup', {})
|
||||
upload_retention_days = self._positive_int(
|
||||
cleanup_cfg.get('uploaded_file_retention_days'),
|
||||
DEFAULT_UPLOAD_FILE_RETENTION_DAYS,
|
||||
'storage.cleanup.uploaded_file_retention_days',
|
||||
)
|
||||
log_retention_days = self._positive_int(
|
||||
cleanup_cfg.get('log_retention_days'),
|
||||
DEFAULT_LOG_RETENTION_DAYS,
|
||||
'storage.cleanup.log_retention_days',
|
||||
)
|
||||
|
||||
database_cfg = self.ap.instance_config.data.get('database', {})
|
||||
database_type = database_cfg.get('use', 'sqlite')
|
||||
database_path = (
|
||||
Path(database_cfg.get('sqlite', {}).get('path', 'data/langbot.db')) if database_type == 'sqlite' else None
|
||||
)
|
||||
roots: list[tuple[str, Path | None]] = [
|
||||
('database', database_path),
|
||||
('logs', Path('data/logs')),
|
||||
('storage', Path('data/storage')),
|
||||
('vector_store', Path('data/chroma')),
|
||||
('plugins', Path('data/plugins')),
|
||||
('mcp', Path('data/mcp')),
|
||||
('temp', Path('data/temp')),
|
||||
]
|
||||
|
||||
sections = []
|
||||
for key, path in roots:
|
||||
sections.append(
|
||||
{
|
||||
'key': key,
|
||||
'path': str(path) if path else '',
|
||||
'exists': path.exists() if path else False,
|
||||
'size_bytes': self._path_size(path) if path else 0,
|
||||
'file_count': self._file_count(path) if path else 0,
|
||||
}
|
||||
)
|
||||
|
||||
monitoring_counts = await self._monitoring_counts()
|
||||
binary_storage = await self._binary_storage_stats()
|
||||
upload_candidates = await self._expired_uploaded_candidates(upload_retention_days)
|
||||
log_candidates = self._expired_log_candidates(log_retention_days)
|
||||
|
||||
return {
|
||||
'generated_at': datetime.datetime.now(datetime.timezone.utc).isoformat(),
|
||||
'cleanup_policy': {
|
||||
'uploaded_file_retention_days': upload_retention_days,
|
||||
'log_retention_days': log_retention_days,
|
||||
},
|
||||
'sections': sections,
|
||||
'database': {
|
||||
'type': database_type,
|
||||
'monitoring_counts': monitoring_counts,
|
||||
'binary_storage': binary_storage,
|
||||
},
|
||||
'cleanup_candidates': {
|
||||
'uploaded_files': upload_candidates,
|
||||
'log_files': log_candidates,
|
||||
},
|
||||
'tasks': self.ap.task_mgr.get_stats() if self.ap.task_mgr else {},
|
||||
}
|
||||
|
||||
async def _cleanup_expired_uploaded_files(self, retention_days: int) -> int:
|
||||
provider = self.ap.storage_mgr.storage_provider
|
||||
provider_name = provider.__class__.__name__
|
||||
if provider_name == 'LocalStorageProvider':
|
||||
candidates = self._expired_local_upload_candidates(retention_days, include_paths=True)
|
||||
deleted = 0
|
||||
for item in candidates:
|
||||
try:
|
||||
os.remove(item['path'])
|
||||
deleted += 1
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(f'Failed to delete expired uploaded file {item["key"]}: {e}')
|
||||
return deleted
|
||||
|
||||
if provider_name == 'S3StorageProvider':
|
||||
return await self._cleanup_expired_s3_uploaded_files(retention_days)
|
||||
|
||||
return 0
|
||||
|
||||
async def _expired_uploaded_candidates(self, retention_days: int) -> list[dict[str, Any]]:
|
||||
provider_name = self.ap.storage_mgr.storage_provider.__class__.__name__
|
||||
if provider_name == 'LocalStorageProvider':
|
||||
return self._expired_local_upload_candidates(retention_days)
|
||||
if provider_name == 'S3StorageProvider':
|
||||
return await self._expired_s3_upload_candidates(retention_days)
|
||||
return []
|
||||
|
||||
async def _cleanup_expired_s3_uploaded_files(self, retention_days: int) -> int:
|
||||
provider = self.ap.storage_mgr.storage_provider
|
||||
candidates = await self._expired_s3_upload_candidates(retention_days)
|
||||
deleted = 0
|
||||
for item in candidates:
|
||||
await provider.delete(item['key'])
|
||||
deleted += 1
|
||||
return deleted
|
||||
|
||||
async def _expired_s3_upload_candidates(self, retention_days: int) -> list[dict[str, Any]]:
|
||||
provider = self.ap.storage_mgr.storage_provider
|
||||
cutoff = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(days=retention_days)
|
||||
candidates = []
|
||||
paginator = provider.s3_client.get_paginator('list_objects_v2')
|
||||
|
||||
for page in paginator.paginate(Bucket=provider.bucket_name):
|
||||
for obj in page.get('Contents', []):
|
||||
key = obj.get('Key', '')
|
||||
last_modified = obj.get('LastModified')
|
||||
if not self._is_uploaded_file_key(key):
|
||||
continue
|
||||
if last_modified and last_modified < cutoff:
|
||||
candidates.append(
|
||||
{
|
||||
'key': key,
|
||||
'size_bytes': obj.get('Size', 0),
|
||||
'modified_at': last_modified.isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
return candidates
|
||||
|
||||
def _cleanup_expired_log_files(self, retention_days: int) -> int:
|
||||
deleted = 0
|
||||
for item in self._expired_log_candidates(retention_days, include_paths=True):
|
||||
try:
|
||||
os.remove(item['path'])
|
||||
deleted += 1
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(f'Failed to delete expired log file {item["name"]}: {e}')
|
||||
return deleted
|
||||
|
||||
def _expired_local_upload_candidates(
|
||||
self, retention_days: int, include_paths: bool = False
|
||||
) -> list[dict[str, Any]]:
|
||||
storage_root = Path('data/storage')
|
||||
if not storage_root.exists():
|
||||
return []
|
||||
|
||||
cutoff = datetime.datetime.now().timestamp() - retention_days * 86400
|
||||
candidates = []
|
||||
for entry in storage_root.iterdir():
|
||||
if not entry.is_file() or not self._is_uploaded_file_key(entry.name):
|
||||
continue
|
||||
stat = entry.stat()
|
||||
if stat.st_mtime >= cutoff:
|
||||
continue
|
||||
item = {
|
||||
'key': entry.name,
|
||||
'size_bytes': stat.st_size,
|
||||
'modified_at': datetime.datetime.fromtimestamp(stat.st_mtime, datetime.timezone.utc).isoformat(),
|
||||
}
|
||||
if include_paths:
|
||||
item['path'] = str(entry)
|
||||
candidates.append(item)
|
||||
return candidates
|
||||
|
||||
def _expired_log_candidates(self, retention_days: int, include_paths: bool = False) -> list[dict[str, Any]]:
|
||||
log_root = Path('data/logs')
|
||||
if not log_root.exists():
|
||||
return []
|
||||
|
||||
cutoff_date = datetime.date.today() - datetime.timedelta(days=retention_days - 1)
|
||||
candidates = []
|
||||
for entry in log_root.iterdir():
|
||||
if not entry.is_file():
|
||||
continue
|
||||
match = LOG_FILE_PATTERN.match(entry.name)
|
||||
if not match:
|
||||
continue
|
||||
try:
|
||||
file_date = datetime.date.fromisoformat(match.group(1))
|
||||
except ValueError:
|
||||
continue
|
||||
if file_date >= cutoff_date:
|
||||
continue
|
||||
stat = entry.stat()
|
||||
item = {
|
||||
'name': entry.name,
|
||||
'date': file_date.isoformat(),
|
||||
'size_bytes': stat.st_size,
|
||||
}
|
||||
if include_paths:
|
||||
item['path'] = str(entry)
|
||||
candidates.append(item)
|
||||
return candidates
|
||||
|
||||
def _is_uploaded_file_key(self, key: str) -> bool:
|
||||
return '/' not in key and not key.startswith('plugin_config_')
|
||||
|
||||
async def _monitoring_counts(self) -> dict[str, int]:
|
||||
tables = {
|
||||
'messages': persistence_monitoring.MonitoringMessage.id,
|
||||
'llm_calls': persistence_monitoring.MonitoringLLMCall.id,
|
||||
'embedding_calls': persistence_monitoring.MonitoringEmbeddingCall.id,
|
||||
'errors': persistence_monitoring.MonitoringError.id,
|
||||
'sessions': persistence_monitoring.MonitoringSession.session_id,
|
||||
'feedback': persistence_monitoring.MonitoringFeedback.id,
|
||||
}
|
||||
counts: dict[str, int] = {}
|
||||
for key, column in tables.items():
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(sqlalchemy.func.count(column)))
|
||||
counts[key] = result.scalar() or 0
|
||||
return counts
|
||||
|
||||
async def _binary_storage_stats(self) -> dict[str, Any]:
|
||||
count_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(sqlalchemy.func.count(persistence_bstorage.BinaryStorage.unique_key))
|
||||
)
|
||||
size_bytes = None
|
||||
try:
|
||||
size_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(sqlalchemy.func.sum(sqlalchemy.func.length(persistence_bstorage.BinaryStorage.value)))
|
||||
)
|
||||
size_bytes = size_result.scalar() or 0
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(f'Failed to estimate binary storage size: {e}')
|
||||
|
||||
return {
|
||||
'count': count_result.scalar() or 0,
|
||||
'size_bytes': size_bytes,
|
||||
}
|
||||
|
||||
def _path_size(self, path: Path) -> int:
|
||||
if not path.exists():
|
||||
return 0
|
||||
if path.is_file():
|
||||
return path.stat().st_size
|
||||
total = 0
|
||||
for root, _, files in os.walk(path):
|
||||
for file_name in files:
|
||||
file_path = Path(root) / file_name
|
||||
try:
|
||||
total += file_path.stat().st_size
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
return total
|
||||
|
||||
def _file_count(self, path: Path) -> int:
|
||||
if not path.exists():
|
||||
return 0
|
||||
if path.is_file():
|
||||
return 1
|
||||
count = 0
|
||||
for _, _, files in os.walk(path):
|
||||
count += len(files)
|
||||
return count
|
||||
|
||||
def _positive_int(self, value: Any, default: int, name: str) -> int:
|
||||
try:
|
||||
parsed = int(value)
|
||||
except (TypeError, ValueError):
|
||||
self.ap.logger.warning(f'Invalid {name}: {value!r}, using {default}')
|
||||
return default
|
||||
if parsed < 1:
|
||||
self.ap.logger.warning(f'Invalid {name}: {value!r}, using {default}')
|
||||
return default
|
||||
return parsed
|
||||
@@ -23,17 +23,6 @@ def _parse_provider_api_keys(provider_dict: dict) -> dict:
|
||||
return provider_dict
|
||||
|
||||
|
||||
def _runtime_model_data(model_uuid: str, model_data: dict) -> dict:
|
||||
"""Return model data for rebuilding runtime models after an update.
|
||||
|
||||
Update payloads intentionally omit uuid before writing to the database.
|
||||
Runtime model entities still need the stable uuid so pipeline configs can
|
||||
resolve the in-memory model immediately after an edit, without requiring a
|
||||
process restart.
|
||||
"""
|
||||
return {**model_data, 'uuid': model_uuid}
|
||||
|
||||
|
||||
class LLMModelsService:
|
||||
ap: app.Application
|
||||
|
||||
@@ -184,7 +173,7 @@ class LLMModelsService:
|
||||
raise Exception('provider not found')
|
||||
|
||||
runtime_llm_model = await self.ap.model_mgr.load_llm_model_with_provider(
|
||||
persistence_model.LLMModel(**_runtime_model_data(model_uuid, model_data)),
|
||||
persistence_model.LLMModel(**model_data),
|
||||
runtime_provider,
|
||||
)
|
||||
self.ap.model_mgr.llm_models.append(runtime_llm_model)
|
||||
@@ -345,7 +334,7 @@ class EmbeddingModelsService:
|
||||
raise Exception('provider not found')
|
||||
|
||||
runtime_embedding_model = await self.ap.model_mgr.load_embedding_model_with_provider(
|
||||
persistence_model.EmbeddingModel(**_runtime_model_data(model_uuid, model_data)),
|
||||
persistence_model.EmbeddingModel(**model_data),
|
||||
runtime_provider,
|
||||
)
|
||||
self.ap.model_mgr.embedding_models.append(runtime_embedding_model)
|
||||
@@ -378,162 +367,3 @@ class EmbeddingModelsService:
|
||||
input_text=['Hello, world!'],
|
||||
extra_args={},
|
||||
)
|
||||
|
||||
|
||||
class RerankModelsService:
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application) -> None:
|
||||
self.ap = ap
|
||||
|
||||
async def get_rerank_models(self) -> list[dict]:
|
||||
"""Get all rerank models with provider info"""
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.RerankModel))
|
||||
models = result.all()
|
||||
|
||||
providers_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.ModelProvider)
|
||||
)
|
||||
providers = {p.uuid: p for p in providers_result.all()}
|
||||
|
||||
models_list = []
|
||||
for model in models:
|
||||
model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.RerankModel, model)
|
||||
provider = providers.get(model.provider_uuid)
|
||||
if provider:
|
||||
provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
|
||||
model_dict['provider'] = _parse_provider_api_keys(provider_dict)
|
||||
models_list.append(model_dict)
|
||||
|
||||
return models_list
|
||||
|
||||
async def get_rerank_models_by_provider(self, provider_uuid: str) -> list[dict]:
|
||||
"""Get rerank models by provider UUID"""
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.RerankModel).where(
|
||||
persistence_model.RerankModel.provider_uuid == provider_uuid
|
||||
)
|
||||
)
|
||||
models = result.all()
|
||||
return [self.ap.persistence_mgr.serialize_model(persistence_model.RerankModel, m) for m in models]
|
||||
|
||||
async def create_rerank_model(self, model_data: dict, preserve_uuid: bool = False) -> str:
|
||||
"""Create a new rerank model"""
|
||||
if not preserve_uuid:
|
||||
model_data['uuid'] = str(uuid.uuid4())
|
||||
|
||||
if 'provider' in model_data:
|
||||
provider_data = model_data.pop('provider')
|
||||
if provider_data.get('uuid'):
|
||||
model_data['provider_uuid'] = provider_data['uuid']
|
||||
else:
|
||||
provider_uuid = await self.ap.provider_service.find_or_create_provider(
|
||||
requester=provider_data.get('requester', ''),
|
||||
base_url=provider_data.get('base_url', ''),
|
||||
api_keys=provider_data.get('api_keys', []),
|
||||
)
|
||||
model_data['provider_uuid'] = provider_uuid
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(persistence_model.RerankModel).values(**model_data)
|
||||
)
|
||||
|
||||
runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid'])
|
||||
if runtime_provider is None:
|
||||
raise Exception('provider not found')
|
||||
|
||||
runtime_rerank_model = await self.ap.model_mgr.load_rerank_model_with_provider(
|
||||
persistence_model.RerankModel(**model_data),
|
||||
runtime_provider,
|
||||
)
|
||||
self.ap.model_mgr.rerank_models.append(runtime_rerank_model)
|
||||
|
||||
return model_data['uuid']
|
||||
|
||||
async def get_rerank_model(self, model_uuid: str) -> dict | None:
|
||||
"""Get a single rerank model with provider info"""
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.RerankModel).where(persistence_model.RerankModel.uuid == model_uuid)
|
||||
)
|
||||
model = result.first()
|
||||
if model is None:
|
||||
return None
|
||||
|
||||
model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.RerankModel, model)
|
||||
|
||||
provider_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.ModelProvider).where(
|
||||
persistence_model.ModelProvider.uuid == model.provider_uuid
|
||||
)
|
||||
)
|
||||
provider = provider_result.first()
|
||||
if provider:
|
||||
provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
|
||||
model_dict['provider'] = _parse_provider_api_keys(provider_dict)
|
||||
|
||||
return model_dict
|
||||
|
||||
async def update_rerank_model(self, model_uuid: str, model_data: dict) -> None:
|
||||
"""Update an existing rerank model"""
|
||||
if 'uuid' in model_data:
|
||||
del model_data['uuid']
|
||||
|
||||
if 'provider' in model_data:
|
||||
provider_data = model_data.pop('provider')
|
||||
if provider_data.get('uuid'):
|
||||
model_data['provider_uuid'] = provider_data['uuid']
|
||||
else:
|
||||
provider_uuid = await self.ap.provider_service.find_or_create_provider(
|
||||
requester=provider_data.get('requester', ''),
|
||||
base_url=provider_data.get('base_url', ''),
|
||||
api_keys=provider_data.get('api_keys', []),
|
||||
)
|
||||
model_data['provider_uuid'] = provider_uuid
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_model.RerankModel)
|
||||
.where(persistence_model.RerankModel.uuid == model_uuid)
|
||||
.values(**model_data)
|
||||
)
|
||||
|
||||
await self.ap.model_mgr.remove_rerank_model(model_uuid)
|
||||
|
||||
runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid'])
|
||||
if runtime_provider is None:
|
||||
raise Exception('provider not found')
|
||||
|
||||
runtime_rerank_model = await self.ap.model_mgr.load_rerank_model_with_provider(
|
||||
persistence_model.RerankModel(**_runtime_model_data(model_uuid, model_data)),
|
||||
runtime_provider,
|
||||
)
|
||||
self.ap.model_mgr.rerank_models.append(runtime_rerank_model)
|
||||
|
||||
async def delete_rerank_model(self, model_uuid: str) -> None:
|
||||
"""Delete a rerank model"""
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_model.RerankModel).where(persistence_model.RerankModel.uuid == model_uuid)
|
||||
)
|
||||
await self.ap.model_mgr.remove_rerank_model(model_uuid)
|
||||
|
||||
async def test_rerank_model(self, model_uuid: str, model_data: dict) -> None:
|
||||
"""Test a rerank model"""
|
||||
runtime_rerank_model: model_requester.RuntimeRerankModel | None = None
|
||||
|
||||
if model_uuid != '_':
|
||||
for model in self.ap.model_mgr.rerank_models:
|
||||
if model.model_entity.uuid == model_uuid:
|
||||
runtime_rerank_model = model
|
||||
break
|
||||
if runtime_rerank_model is None:
|
||||
raise Exception('model not found')
|
||||
else:
|
||||
runtime_rerank_model = await self.ap.model_mgr.init_temporary_runtime_rerank_model(model_data)
|
||||
|
||||
await runtime_rerank_model.provider.invoke_rerank(
|
||||
model=runtime_rerank_model,
|
||||
query='What is artificial intelligence?',
|
||||
documents=[
|
||||
'Artificial intelligence is a branch of computer science.',
|
||||
'The weather is nice today.',
|
||||
],
|
||||
)
|
||||
|
||||
@@ -18,119 +18,55 @@ class MonitoringService:
|
||||
|
||||
# ========== Cleanup Methods ==========
|
||||
|
||||
async def cleanup_expired_records(self, retention_days: int, batch_size: int = 1000) -> dict[str, int]:
|
||||
async def cleanup_expired_records(self, retention_days: int) -> dict[str, int]:
|
||||
"""Delete monitoring records older than the specified retention period.
|
||||
|
||||
Args:
|
||||
retention_days: Number of days to retain records.
|
||||
batch_size: Maximum rows to delete per table batch.
|
||||
|
||||
Returns:
|
||||
A dict mapping table name to the number of deleted rows.
|
||||
"""
|
||||
if retention_days < 1:
|
||||
raise ValueError('retention_days must be >= 1')
|
||||
if batch_size < 1:
|
||||
raise ValueError('batch_size must be >= 1')
|
||||
|
||||
cutoff = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - datetime.timedelta(
|
||||
days=retention_days
|
||||
)
|
||||
|
||||
tables_and_columns: list[tuple[str, type, sqlalchemy.Column, sqlalchemy.Column]] = [
|
||||
tables_and_columns: list[tuple[str, type, sqlalchemy.Column]] = [
|
||||
(
|
||||
'monitoring_messages',
|
||||
persistence_monitoring.MonitoringMessage,
|
||||
persistence_monitoring.MonitoringMessage.timestamp,
|
||||
persistence_monitoring.MonitoringMessage.id,
|
||||
),
|
||||
(
|
||||
'monitoring_llm_calls',
|
||||
persistence_monitoring.MonitoringLLMCall,
|
||||
persistence_monitoring.MonitoringLLMCall.timestamp,
|
||||
persistence_monitoring.MonitoringLLMCall.id,
|
||||
),
|
||||
(
|
||||
'monitoring_embedding_calls',
|
||||
persistence_monitoring.MonitoringEmbeddingCall,
|
||||
persistence_monitoring.MonitoringEmbeddingCall.timestamp,
|
||||
persistence_monitoring.MonitoringEmbeddingCall.id,
|
||||
),
|
||||
(
|
||||
'monitoring_errors',
|
||||
persistence_monitoring.MonitoringError,
|
||||
persistence_monitoring.MonitoringError.timestamp,
|
||||
persistence_monitoring.MonitoringError.id,
|
||||
),
|
||||
(
|
||||
'monitoring_sessions',
|
||||
persistence_monitoring.MonitoringSession,
|
||||
persistence_monitoring.MonitoringSession.last_activity,
|
||||
persistence_monitoring.MonitoringSession.session_id,
|
||||
),
|
||||
(
|
||||
'monitoring_feedback',
|
||||
persistence_monitoring.MonitoringFeedback,
|
||||
persistence_monitoring.MonitoringFeedback.timestamp,
|
||||
persistence_monitoring.MonitoringFeedback.id,
|
||||
),
|
||||
]
|
||||
|
||||
deleted_counts: dict[str, int] = {}
|
||||
|
||||
for table_name, model_cls, ts_column, pk_column in tables_and_columns:
|
||||
deleted_counts[table_name] = await self._delete_expired_in_batches(
|
||||
model_cls=model_cls,
|
||||
ts_column=ts_column,
|
||||
pk_column=pk_column,
|
||||
cutoff=cutoff,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
if sum(deleted_counts.values()) > 0:
|
||||
await self._release_sqlite_space()
|
||||
for table_name, model_cls, ts_column in tables_and_columns:
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.delete(model_cls).where(ts_column < cutoff))
|
||||
deleted_counts[table_name] = result.rowcount
|
||||
|
||||
return deleted_counts
|
||||
|
||||
async def _delete_expired_in_batches(
|
||||
self,
|
||||
model_cls: type,
|
||||
ts_column: sqlalchemy.Column,
|
||||
pk_column: sqlalchemy.Column,
|
||||
cutoff: datetime.datetime,
|
||||
batch_size: int,
|
||||
) -> int:
|
||||
deleted_total = 0
|
||||
|
||||
while True:
|
||||
select_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(pk_column).where(ts_column < cutoff).limit(batch_size)
|
||||
)
|
||||
pk_values = list(select_result.scalars().all())
|
||||
if not pk_values:
|
||||
break
|
||||
|
||||
delete_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(model_cls).where(pk_column.in_(pk_values))
|
||||
)
|
||||
deleted = delete_result.rowcount or 0
|
||||
deleted_total += deleted
|
||||
|
||||
if len(pk_values) < batch_size:
|
||||
break
|
||||
|
||||
return deleted_total
|
||||
|
||||
async def _release_sqlite_space(self) -> None:
|
||||
database_type = self.ap.instance_config.data.get('database', {}).get('use', 'sqlite')
|
||||
if database_type != 'sqlite':
|
||||
return
|
||||
|
||||
async with self.ap.persistence_mgr.get_db_engine().connect() as conn:
|
||||
autocommit_conn = await conn.execution_options(isolation_level='AUTOCOMMIT')
|
||||
await autocommit_conn.execute(sqlalchemy.text('PRAGMA wal_checkpoint(TRUNCATE)'))
|
||||
await autocommit_conn.execute(sqlalchemy.text('VACUUM'))
|
||||
|
||||
# ========== Recording Methods ==========
|
||||
|
||||
async def record_message(
|
||||
@@ -1288,83 +1224,30 @@ class MonitoringService:
|
||||
"""
|
||||
import json
|
||||
|
||||
now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
reasons_json = json.dumps(inaccurate_reasons, ensure_ascii=False) if inaccurate_reasons else None
|
||||
record_id = str(uuid.uuid4())
|
||||
record_data = {
|
||||
'id': record_id,
|
||||
'timestamp': datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
|
||||
'feedback_id': feedback_id,
|
||||
'feedback_type': feedback_type,
|
||||
'feedback_content': feedback_content,
|
||||
'inaccurate_reasons': json.dumps(inaccurate_reasons, ensure_ascii=False) if inaccurate_reasons else None,
|
||||
'bot_id': bot_id,
|
||||
'bot_name': bot_name,
|
||||
'pipeline_id': pipeline_id,
|
||||
'pipeline_name': pipeline_name,
|
||||
'session_id': session_id,
|
||||
'message_id': message_id,
|
||||
'stream_id': stream_id,
|
||||
'user_id': user_id,
|
||||
'platform': platform,
|
||||
}
|
||||
|
||||
MonitoringFeedback = persistence_monitoring.MonitoringFeedback
|
||||
|
||||
# Handle cancel feedback (type=3): delete existing record
|
||||
if feedback_type == 3:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(MonitoringFeedback).where(MonitoringFeedback.feedback_id == feedback_id)
|
||||
)
|
||||
return None
|
||||
|
||||
# Check if record with this feedback_id already exists
|
||||
existing_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(MonitoringFeedback).where(MonitoringFeedback.feedback_id == feedback_id)
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(persistence_monitoring.MonitoringFeedback).values(record_data)
|
||||
)
|
||||
existing_row = existing_result.first()
|
||||
|
||||
if existing_row:
|
||||
# UPDATE existing record
|
||||
existing = existing_row[0] if isinstance(existing_row, tuple) else existing_row
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(MonitoringFeedback)
|
||||
.where(MonitoringFeedback.feedback_id == feedback_id)
|
||||
.values(
|
||||
timestamp=now,
|
||||
feedback_type=feedback_type,
|
||||
feedback_content=feedback_content,
|
||||
inaccurate_reasons=reasons_json,
|
||||
bot_id=bot_id or existing.bot_id,
|
||||
bot_name=bot_name or existing.bot_name,
|
||||
pipeline_id=pipeline_id or existing.pipeline_id,
|
||||
pipeline_name=pipeline_name or existing.pipeline_name,
|
||||
session_id=session_id or existing.session_id,
|
||||
message_id=message_id or existing.message_id,
|
||||
stream_id=stream_id or existing.stream_id,
|
||||
user_id=user_id or existing.user_id,
|
||||
platform=platform or existing.platform,
|
||||
)
|
||||
)
|
||||
return existing.id
|
||||
else:
|
||||
# INSERT new record with IntegrityError defense
|
||||
record_id = str(uuid.uuid4())
|
||||
record_data = {
|
||||
'id': record_id,
|
||||
'timestamp': now,
|
||||
'feedback_id': feedback_id,
|
||||
'feedback_type': feedback_type,
|
||||
'feedback_content': feedback_content,
|
||||
'inaccurate_reasons': reasons_json,
|
||||
'bot_id': bot_id,
|
||||
'bot_name': bot_name,
|
||||
'pipeline_id': pipeline_id,
|
||||
'pipeline_name': pipeline_name,
|
||||
'session_id': session_id,
|
||||
'message_id': message_id,
|
||||
'stream_id': stream_id,
|
||||
'user_id': user_id,
|
||||
'platform': platform,
|
||||
}
|
||||
try:
|
||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(MonitoringFeedback).values(record_data))
|
||||
return record_id
|
||||
except Exception:
|
||||
# UNIQUE constraint conflict (concurrent feedback for same feedback_id)
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(MonitoringFeedback)
|
||||
.where(MonitoringFeedback.feedback_id == feedback_id)
|
||||
.values(
|
||||
timestamp=now,
|
||||
feedback_type=feedback_type,
|
||||
feedback_content=feedback_content,
|
||||
inaccurate_reasons=reasons_json,
|
||||
)
|
||||
)
|
||||
return feedback_id
|
||||
return record_id
|
||||
|
||||
async def get_feedback_stats(
|
||||
self,
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
import traceback
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
@@ -17,24 +16,6 @@ class ModelProviderService:
|
||||
def __init__(self, ap: app.Application) -> None:
|
||||
self.ap = ap
|
||||
|
||||
@staticmethod
|
||||
def _normalize_api_keys(api_keys: str | list[str] | tuple[str, ...] | None) -> list[str]:
|
||||
if api_keys is None:
|
||||
return []
|
||||
|
||||
raw_keys = [api_keys] if isinstance(api_keys, str) else list(api_keys)
|
||||
normalized_keys = []
|
||||
seen_keys = set()
|
||||
|
||||
for raw_key in raw_keys:
|
||||
normalized_key = raw_key.strip() if isinstance(raw_key, str) else ''
|
||||
if not normalized_key or normalized_key in seen_keys:
|
||||
continue
|
||||
normalized_keys.append(normalized_key)
|
||||
seen_keys.add(normalized_key)
|
||||
|
||||
return normalized_keys
|
||||
|
||||
async def get_providers(self) -> list[dict]:
|
||||
"""Get all providers"""
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.ModelProvider))
|
||||
@@ -77,7 +58,6 @@ class ModelProviderService:
|
||||
async def create_provider(self, provider_data: dict) -> str:
|
||||
"""Create a new provider"""
|
||||
provider_data['uuid'] = str(uuid.uuid4())
|
||||
provider_data['api_keys'] = self._normalize_api_keys(provider_data.get('api_keys'))
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(persistence_model.ModelProvider).values(**provider_data)
|
||||
)
|
||||
@@ -91,8 +71,6 @@ class ModelProviderService:
|
||||
"""Update an existing provider"""
|
||||
if 'uuid' in provider_data:
|
||||
del provider_data['uuid']
|
||||
if 'api_keys' in provider_data:
|
||||
provider_data['api_keys'] = self._normalize_api_keys(provider_data.get('api_keys'))
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_model.ModelProvider)
|
||||
.where(persistence_model.ModelProvider.uuid == provider_uuid)
|
||||
@@ -119,14 +97,6 @@ class ModelProviderService:
|
||||
if embedding_result.first() is not None:
|
||||
raise ValueError('Cannot delete provider: Embedding models still reference it')
|
||||
|
||||
rerank_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.RerankModel).where(
|
||||
persistence_model.RerankModel.provider_uuid == provider_uuid
|
||||
)
|
||||
)
|
||||
if rerank_result.first() is not None:
|
||||
raise ValueError('Cannot delete provider: Rerank models still reference it')
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_model.ModelProvider).where(
|
||||
persistence_model.ModelProvider.uuid == provider_uuid
|
||||
@@ -151,19 +121,10 @@ class ModelProviderService:
|
||||
)
|
||||
embedding_count = embedding_result.scalar() or 0
|
||||
|
||||
rerank_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(sqlalchemy.func.count())
|
||||
.select_from(persistence_model.RerankModel)
|
||||
.where(persistence_model.RerankModel.provider_uuid == provider_uuid)
|
||||
)
|
||||
rerank_count = rerank_result.scalar() or 0
|
||||
|
||||
return {'llm_count': llm_count, 'embedding_count': embedding_count, 'rerank_count': rerank_count}
|
||||
return {'llm_count': llm_count, 'embedding_count': embedding_count}
|
||||
|
||||
async def find_or_create_provider(self, requester: str, base_url: str, api_keys: list) -> str:
|
||||
"""Find existing provider or create new one"""
|
||||
api_keys = self._normalize_api_keys(api_keys)
|
||||
|
||||
# Try to find existing provider with same config
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.ModelProvider).where(
|
||||
@@ -191,7 +152,7 @@ class ModelProviderService:
|
||||
'name': provider_name,
|
||||
'requester': requester,
|
||||
'base_url': base_url,
|
||||
'api_keys': api_keys,
|
||||
'api_keys': api_keys or [],
|
||||
}
|
||||
)
|
||||
|
||||
@@ -200,69 +161,6 @@ class ModelProviderService:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_model.ModelProvider)
|
||||
.where(persistence_model.ModelProvider.uuid == '00000000-0000-0000-0000-000000000000')
|
||||
.values(api_keys=self._normalize_api_keys(api_key))
|
||||
.values(api_keys=[api_key])
|
||||
)
|
||||
await self.ap.model_mgr.reload_provider('00000000-0000-0000-0000-000000000000')
|
||||
|
||||
async def scan_provider_models(self, provider_uuid: str, model_type: str | None = None) -> dict:
|
||||
provider = await self.get_provider(provider_uuid)
|
||||
if provider is None:
|
||||
raise ValueError('provider not found')
|
||||
|
||||
runtime_provider = await self.ap.model_mgr.load_provider(provider)
|
||||
|
||||
try:
|
||||
scan_result = await runtime_provider.requester.scan_models(
|
||||
runtime_provider.token_mgr.get_token() if runtime_provider.token_mgr.tokens else None
|
||||
)
|
||||
except NotImplementedError:
|
||||
raise ValueError('current provider does not support model scanning')
|
||||
except Exception as exc:
|
||||
self.ap.logger.warning(
|
||||
f'Failed to scan models for provider {provider_uuid}: {exc}\n{traceback.format_exc()}'
|
||||
)
|
||||
raise ValueError(str(exc)) from exc
|
||||
|
||||
if isinstance(scan_result, dict):
|
||||
scanned_models = scan_result.get('models', [])
|
||||
debug_info = scan_result.get('debug')
|
||||
else:
|
||||
scanned_models = scan_result
|
||||
debug_info = None
|
||||
|
||||
llm_models = await self.ap.llm_model_service.get_llm_models_by_provider(provider_uuid)
|
||||
embedding_models = await self.ap.embedding_models_service.get_embedding_models_by_provider(provider_uuid)
|
||||
existing_llm_names = {model['name'] for model in llm_models}
|
||||
existing_embedding_names = {model['name'] for model in embedding_models}
|
||||
|
||||
filtered_models = []
|
||||
for model in scanned_models:
|
||||
scanned_type = model.get('type', 'llm')
|
||||
if model_type and scanned_type != model_type:
|
||||
continue
|
||||
|
||||
model_name = model.get('name') or model.get('id')
|
||||
if not model_name:
|
||||
continue
|
||||
|
||||
filtered_models.append(
|
||||
{
|
||||
'id': model.get('id', model_name),
|
||||
'name': model_name,
|
||||
'type': scanned_type,
|
||||
'abilities': model.get('abilities', []),
|
||||
'display_name': model.get('display_name'),
|
||||
'description': model.get('description'),
|
||||
'context_length': model.get('context_length'),
|
||||
'owned_by': model.get('owned_by'),
|
||||
'input_modalities': model.get('input_modalities', []),
|
||||
'output_modalities': model.get('output_modalities', []),
|
||||
'already_added': (
|
||||
model_name in existing_embedding_names
|
||||
if scanned_type == 'embedding'
|
||||
else model_name in existing_llm_names
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
return {'models': filtered_models, 'debug': debug_info}
|
||||
|
||||
@@ -179,7 +179,7 @@ class SpaceService:
|
||||
space_url = space_config['url']
|
||||
|
||||
session = httpclient.get_session()
|
||||
async with session.get(f'{space_url}/api/v1/models', params={'page_size': 100}) as response:
|
||||
async with session.get(f'{space_url}/api/v1/models') as response:
|
||||
if response.status != 200:
|
||||
raise ValueError(f'Failed to get models: {await response.text()}')
|
||||
data = await response.json()
|
||||
|
||||
@@ -65,8 +65,8 @@ class UserService:
|
||||
|
||||
user_obj = result_list[0]
|
||||
|
||||
# Check if this user has a local password set
|
||||
if not user_obj.password:
|
||||
# Check if this is a Space account
|
||||
if user_obj.account_type == 'space':
|
||||
raise ValueError('请使用 Space 账户登录')
|
||||
|
||||
ph = argon2.PasswordHasher()
|
||||
@@ -108,8 +108,9 @@ class UserService:
|
||||
if user_obj is None:
|
||||
raise ValueError('User not found')
|
||||
|
||||
if not user_obj.password:
|
||||
raise ValueError('No local password set, please set a password first')
|
||||
# Space accounts cannot change password locally
|
||||
if user_obj.account_type == 'space':
|
||||
raise ValueError('Space account cannot change password locally')
|
||||
|
||||
ph.verify(user_obj.password, current_password)
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ from ..api.http.service import mcp as mcp_service
|
||||
from ..api.http.service import apikey as apikey_service
|
||||
from ..api.http.service import webhook as webhook_service
|
||||
from ..api.http.service import monitoring as monitoring_service
|
||||
from ..api.http.service import maintenance as maintenance_service
|
||||
from ..api.http.service import human_takeover as human_takeover_service
|
||||
|
||||
from ..discover import engine as discover_engine
|
||||
from ..storage import mgr as storagemgr
|
||||
@@ -134,8 +134,6 @@ class Application:
|
||||
|
||||
embedding_models_service: model_service.EmbeddingModelsService = None
|
||||
|
||||
rerank_models_service: model_service.RerankModelsService = None
|
||||
|
||||
provider_service: provider_service.ModelProviderService = None
|
||||
|
||||
pipeline_service: pipeline_service.PipelineService = None
|
||||
@@ -156,7 +154,7 @@ class Application:
|
||||
|
||||
monitoring_service: monitoring_service.MonitoringService = None
|
||||
|
||||
maintenance_service: maintenance_service.MaintenanceService = None
|
||||
human_takeover_service: human_takeover_service.HumanTakeoverService = None
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
@@ -197,30 +195,14 @@ class Application:
|
||||
monitoring_cfg = self.instance_config.data.get('monitoring', {})
|
||||
auto_cleanup_cfg = monitoring_cfg.get('auto_cleanup', {})
|
||||
if auto_cleanup_cfg.get('enabled', True):
|
||||
retention_days = self._get_positive_int_config(
|
||||
auto_cleanup_cfg.get('retention_days', 30),
|
||||
default=30,
|
||||
name='monitoring.auto_cleanup.retention_days',
|
||||
)
|
||||
delete_batch_size = self._get_positive_int_config(
|
||||
auto_cleanup_cfg.get('delete_batch_size', 1000),
|
||||
default=1000,
|
||||
name='monitoring.auto_cleanup.delete_batch_size',
|
||||
)
|
||||
check_interval_hours = self._get_positive_float_config(
|
||||
auto_cleanup_cfg.get('check_interval_hours', 1),
|
||||
default=1,
|
||||
name='monitoring.auto_cleanup.check_interval_hours',
|
||||
)
|
||||
retention_days = auto_cleanup_cfg.get('retention_days', 30)
|
||||
check_interval_hours = auto_cleanup_cfg.get('check_interval_hours', 1)
|
||||
|
||||
async def monitoring_cleanup_loop():
|
||||
check_interval_seconds = check_interval_hours * 3600
|
||||
while True:
|
||||
try:
|
||||
deleted = await self.monitoring_service.cleanup_expired_records(
|
||||
retention_days,
|
||||
batch_size=delete_batch_size,
|
||||
)
|
||||
deleted = await self.monitoring_service.cleanup_expired_records(retention_days)
|
||||
total_deleted = sum(deleted.values())
|
||||
if total_deleted > 0:
|
||||
self.logger.info(
|
||||
@@ -237,33 +219,6 @@ class Application:
|
||||
scopes=[core_entities.LifecycleControlScope.APPLICATION],
|
||||
)
|
||||
|
||||
# Start storage/log maintenance task if enabled
|
||||
storage_cleanup_cfg = self.instance_config.data.get('storage', {}).get('cleanup', {})
|
||||
if storage_cleanup_cfg.get('enabled', True) and self.maintenance_service is not None:
|
||||
check_interval_hours = self._get_positive_float_config(
|
||||
storage_cleanup_cfg.get('check_interval_hours', 1),
|
||||
default=1,
|
||||
name='storage.cleanup.check_interval_hours',
|
||||
)
|
||||
|
||||
async def storage_cleanup_loop():
|
||||
check_interval_seconds = check_interval_hours * 3600
|
||||
while True:
|
||||
try:
|
||||
deleted = await self.maintenance_service.cleanup_expired_files()
|
||||
total_deleted = sum(deleted.values())
|
||||
if total_deleted > 0:
|
||||
self.logger.info(f'Storage maintenance: deleted expired files: {deleted}')
|
||||
except Exception as e:
|
||||
self.logger.warning(f'Storage maintenance error: {e}')
|
||||
await asyncio.sleep(check_interval_seconds)
|
||||
|
||||
self.task_mgr.create_task(
|
||||
storage_cleanup_loop(),
|
||||
name='storage-maintenance',
|
||||
scopes=[core_entities.LifecycleControlScope.APPLICATION],
|
||||
)
|
||||
|
||||
self.task_mgr.create_task(
|
||||
never_ending(),
|
||||
name='never-ending-task',
|
||||
@@ -278,28 +233,6 @@ class Application:
|
||||
self.logger.error(f'Application runtime fatal exception: {e}')
|
||||
self.logger.debug(f'Traceback: {traceback.format_exc()}')
|
||||
|
||||
def _get_positive_int_config(self, value, default: int, name: str) -> int:
|
||||
try:
|
||||
parsed = int(value)
|
||||
except (TypeError, ValueError):
|
||||
self.logger.warning(f'Invalid {name}: {value!r}, using {default}')
|
||||
return default
|
||||
if parsed < 1:
|
||||
self.logger.warning(f'Invalid {name}: {value!r}, using {default}')
|
||||
return default
|
||||
return parsed
|
||||
|
||||
def _get_positive_float_config(self, value, default: float, name: str) -> float:
|
||||
try:
|
||||
parsed = float(value)
|
||||
except (TypeError, ValueError):
|
||||
self.logger.warning(f'Invalid {name}: {value!r}, using {default}')
|
||||
return default
|
||||
if parsed <= 0:
|
||||
self.logger.warning(f'Invalid {name}: {value!r}, using {default}')
|
||||
return default
|
||||
return parsed
|
||||
|
||||
def dispose(self):
|
||||
self.plugin_connector.dispose()
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ from ...api.http.service import mcp as mcp_service
|
||||
from ...api.http.service import apikey as apikey_service
|
||||
from ...api.http.service import webhook as webhook_service
|
||||
from ...api.http.service import monitoring as monitoring_service
|
||||
from ...api.http.service import maintenance as maintenance_service
|
||||
from ...api.http.service import human_takeover as human_takeover_service
|
||||
from ...discover import engine as discover_engine
|
||||
from ...storage import mgr as storagemgr
|
||||
from ...utils import logcache
|
||||
@@ -62,9 +62,6 @@ class BuildAppStage(stage.BootingStage):
|
||||
embedding_models_service_inst = model_service.EmbeddingModelsService(ap)
|
||||
ap.embedding_models_service = embedding_models_service_inst
|
||||
|
||||
rerank_models_service_inst = model_service.RerankModelsService(ap)
|
||||
ap.rerank_models_service = rerank_models_service_inst
|
||||
|
||||
provider_service_inst = provider_service.ModelProviderService(ap)
|
||||
ap.provider_service = provider_service_inst
|
||||
|
||||
@@ -168,8 +165,9 @@ class BuildAppStage(stage.BootingStage):
|
||||
monitoring_service_inst = monitoring_service.MonitoringService(ap)
|
||||
ap.monitoring_service = monitoring_service_inst
|
||||
|
||||
maintenance_service_inst = maintenance_service.MaintenanceService(ap)
|
||||
ap.maintenance_service = maintenance_service_inst
|
||||
human_takeover_service_inst = human_takeover_service.HumanTakeoverService(ap)
|
||||
await human_takeover_service_inst.initialize()
|
||||
ap.human_takeover_service = human_takeover_service_inst
|
||||
|
||||
async def runtime_disconnect_callback(connector: plugin_connector.PluginRuntimeConnector) -> None:
|
||||
await asyncio.sleep(3)
|
||||
|
||||
@@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import typing
|
||||
import datetime
|
||||
import time
|
||||
|
||||
from . import app
|
||||
from . import entities as core_entities
|
||||
@@ -120,7 +119,6 @@ class TaskWrapper:
|
||||
self.label = label if label != '' else name
|
||||
self.task.set_name(name)
|
||||
self.scopes = scopes
|
||||
self.created_at = time.time()
|
||||
|
||||
def assume_exception(self):
|
||||
try:
|
||||
@@ -156,7 +154,6 @@ class TaskWrapper:
|
||||
'name': self.name,
|
||||
'label': self.label,
|
||||
'scopes': [scope.value for scope in self.scopes],
|
||||
'created_at': self.created_at,
|
||||
'task_context': self.task_context.to_dict(),
|
||||
'runtime': {
|
||||
'done': self.task.done(),
|
||||
@@ -196,8 +193,6 @@ class AsyncTaskManager:
|
||||
) -> TaskWrapper:
|
||||
wrapper = TaskWrapper(self.ap, coro, task_type, kind, name, label, context, scopes)
|
||||
self.tasks.append(wrapper)
|
||||
wrapper.task.add_done_callback(lambda _: self._prune_completed_tasks())
|
||||
self._prune_completed_tasks()
|
||||
return wrapper
|
||||
|
||||
def create_user_task(
|
||||
@@ -231,15 +226,6 @@ class AsyncTaskManager:
|
||||
'id_index': TaskWrapper._id_index,
|
||||
}
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
completed = sum(1 for t in self.tasks if t.task.done())
|
||||
return {
|
||||
'total': len(self.tasks),
|
||||
'running': len(self.tasks) - completed,
|
||||
'completed': completed,
|
||||
'id_index': TaskWrapper._id_index,
|
||||
}
|
||||
|
||||
def get_task_by_id(self, id: int) -> TaskWrapper | None:
|
||||
for t in self.tasks:
|
||||
if t.id == id:
|
||||
@@ -257,27 +243,3 @@ class AsyncTaskManager:
|
||||
if not wrapper.task.done():
|
||||
wrapper.task.cancel()
|
||||
return
|
||||
|
||||
def _prune_completed_tasks(self):
|
||||
completed_limit = (
|
||||
self.ap.instance_config.data.get('system', {})
|
||||
.get('task_retention', {})
|
||||
.get(
|
||||
'completed_limit',
|
||||
200,
|
||||
)
|
||||
)
|
||||
try:
|
||||
completed_limit = int(completed_limit)
|
||||
except (TypeError, ValueError):
|
||||
completed_limit = 200
|
||||
if completed_limit < 1:
|
||||
completed_limit = 1
|
||||
|
||||
completed_tasks = [wrapper for wrapper in self.tasks if wrapper.task.done()]
|
||||
overflow = len(completed_tasks) - completed_limit
|
||||
if overflow <= 0:
|
||||
return
|
||||
|
||||
remove_ids = {wrapper.id for wrapper in completed_tasks[:overflow]}
|
||||
self.tasks = [wrapper for wrapper in self.tasks if wrapper.id not in remove_ids]
|
||||
|
||||
36
src/langbot/pkg/entity/persistence/human_takeover.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import sqlalchemy
|
||||
|
||||
from .base import Base
|
||||
|
||||
|
||||
class HumanTakeoverSession(Base):
|
||||
"""Human takeover session records.
|
||||
|
||||
Tracks which conversation sessions are currently under human operator control,
|
||||
bypassing the normal AI pipeline processing.
|
||||
"""
|
||||
|
||||
__tablename__ = 'human_takeover_sessions'
|
||||
|
||||
id = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True)
|
||||
session_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, unique=True, index=True)
|
||||
"""Corresponds to monitoring_sessions.session_id, format: 'person_{id}' or 'group_{id}'"""
|
||||
|
||||
bot_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, index=True)
|
||||
"""UUID of the bot whose session is being taken over"""
|
||||
|
||||
status = sqlalchemy.Column(sqlalchemy.String(50), nullable=False, default='active', index=True)
|
||||
"""Takeover status: 'active' or 'released'"""
|
||||
|
||||
taken_by = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
"""Email/username of the admin who took over the session"""
|
||||
|
||||
taken_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False)
|
||||
"""Timestamp when the takeover started"""
|
||||
|
||||
released_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
|
||||
"""Timestamp when the takeover was released (null if still active)"""
|
||||
|
||||
platform = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
user_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
user_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
@@ -59,22 +59,3 @@ class EmbeddingModel(Base):
|
||||
server_default=sqlalchemy.func.now(),
|
||||
onupdate=sqlalchemy.func.now(),
|
||||
)
|
||||
|
||||
|
||||
class RerankModel(Base):
|
||||
"""Rerank model"""
|
||||
|
||||
__tablename__ = 'rerank_models'
|
||||
|
||||
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
|
||||
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
provider_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
|
||||
prefered_ranking = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0)
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
||||
updated_at = sqlalchemy.Column(
|
||||
sqlalchemy.DateTime,
|
||||
nullable=False,
|
||||
server_default=sqlalchemy.func.now(),
|
||||
onupdate=sqlalchemy.func.now(),
|
||||
)
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
"""Alembic environment for LangBot.
|
||||
|
||||
This env.py is designed to be called programmatically (not via CLI).
|
||||
It supports both SQLite and PostgreSQL.
|
||||
|
||||
The sync connection is passed via config attributes by the runner.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy.engine import Connection
|
||||
|
||||
from langbot.pkg.entity.persistence.base import Base
|
||||
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode — emit SQL without a live connection."""
|
||||
url = context.config.get_main_option('sqlalchemy.url')
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={'paramstyle': 'named'},
|
||||
)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations with a live sync connection passed via config attributes."""
|
||||
connection: Connection = context.config.attributes.get('connection')
|
||||
if connection is None:
|
||||
raise RuntimeError('connection not provided in alembic config attributes')
|
||||
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
# render_as_batch=True is critical for SQLite ALTER TABLE support
|
||||
render_as_batch=True,
|
||||
)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
@@ -1,24 +0,0 @@
|
||||
# Alembic script.py.mako — template for auto-generated revisions
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers
|
||||
revision = ${repr(up_revision)}
|
||||
down_revision = ${repr(down_revision)}
|
||||
branch_labels = ${repr(branch_labels)}
|
||||
depends_on = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
@@ -1,24 +0,0 @@
|
||||
"""baseline: stamp existing schema (db version 25)
|
||||
|
||||
This is a no-op migration that marks the starting point for Alembic.
|
||||
All tables already exist via create_all() + legacy DBMigration system.
|
||||
|
||||
Revision ID: 0001_baseline
|
||||
Revises: None
|
||||
Create Date: 2026-04-08
|
||||
"""
|
||||
|
||||
revision = '0001_baseline'
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# No-op: existing schema is already at database_version=25
|
||||
# This revision serves as the Alembic baseline.
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
@@ -1,62 +0,0 @@
|
||||
"""example: sample migration demonstrating Alembic patterns
|
||||
|
||||
This is a SAMPLE showing how to write migrations that work
|
||||
seamlessly across SQLite and PostgreSQL. Delete or adapt as needed.
|
||||
|
||||
Revision ID: 0002_sample
|
||||
Revises: 0001_baseline
|
||||
Create Date: 2026-04-08
|
||||
|
||||
Patterns demonstrated:
|
||||
1. Schema change (add column) — works on both DBs via render_as_batch
|
||||
2. Data migration (read + modify JSON) — pure SQLAlchemy, no dialect branching
|
||||
"""
|
||||
|
||||
revision = '0002_sample'
|
||||
down_revision = '0001_baseline'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""
|
||||
EXAMPLE: Uncomment to use. This shows the patterns.
|
||||
|
||||
# --- Pattern 1: Schema change (add/drop column) ---
|
||||
# render_as_batch=True in env.py makes this work on SQLite too.
|
||||
#
|
||||
# op.add_column('pipelines', sa.Column('description', sa.String(512), server_default=''))
|
||||
|
||||
# --- Pattern 2: Data migration (read + modify JSON field) ---
|
||||
# No if/else for sqlite vs postgres needed!
|
||||
#
|
||||
# conn = op.get_bind()
|
||||
# rows = conn.execute(sa.text("SELECT uuid, config FROM pipelines")).fetchall()
|
||||
# for row in rows:
|
||||
# config = json.loads(row[1]) if isinstance(row[1], str) else row[1]
|
||||
# # Modify the config
|
||||
# config.setdefault('ai', {}).setdefault('some_new_key', 'default_value')
|
||||
# conn.execute(
|
||||
# sa.text("UPDATE pipelines SET config = :cfg WHERE uuid = :uuid"),
|
||||
# {"cfg": json.dumps(config), "uuid": row[0]}
|
||||
# )
|
||||
|
||||
# --- Pattern 3: Create a new table ---
|
||||
#
|
||||
# op.create_table(
|
||||
# 'audit_log',
|
||||
# sa.Column('id', sa.Integer, primary_key=True, autoincrement=True),
|
||||
# sa.Column('action', sa.String(255), nullable=False),
|
||||
# sa.Column('detail', sa.Text),
|
||||
# sa.Column('created_at', sa.DateTime, server_default=sa.func.now()),
|
||||
# )
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""
|
||||
# op.drop_column('pipelines', 'description')
|
||||
# op.drop_table('audit_log')
|
||||
"""
|
||||
pass
|
||||
@@ -1,35 +0,0 @@
|
||||
"""add rerank_models table
|
||||
|
||||
Revision ID: 0003_add_rerank_models
|
||||
Revises: 0002_sample
|
||||
Create Date: 2026-04-19
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = '0003_add_rerank_models'
|
||||
down_revision = '0002_sample'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Check if table already exists (may have been created by create_all())
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
if 'rerank_models' not in inspector.get_table_names():
|
||||
op.create_table(
|
||||
'rerank_models',
|
||||
sa.Column('uuid', sa.String(255), primary_key=True, unique=True),
|
||||
sa.Column('name', sa.String(255), nullable=False),
|
||||
sa.Column('provider_uuid', sa.String(255), nullable=False),
|
||||
sa.Column('extra_args', sa.JSON, nullable=False, server_default='{}'),
|
||||
sa.Column('prefered_ranking', sa.Integer, nullable=False, server_default='0'),
|
||||
sa.Column('created_at', sa.DateTime, nullable=False, server_default=sa.func.now()),
|
||||
sa.Column('updated_at', sa.DateTime, nullable=False, server_default=sa.func.now()),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table('rerank_models')
|
||||
@@ -1,150 +0,0 @@
|
||||
"""Programmatic Alembic runner for LangBot.
|
||||
|
||||
Usage from async code:
|
||||
from langbot.pkg.persistence.alembic_runner import run_alembic_upgrade
|
||||
await run_alembic_upgrade(async_engine)
|
||||
|
||||
CLI usage (autogenerate):
|
||||
python -m langbot.pkg.persistence.alembic_runner autogenerate "add description column"
|
||||
python -m langbot.pkg.persistence.alembic_runner upgrade
|
||||
python -m langbot.pkg.persistence.alembic_runner current
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from alembic.config import Config
|
||||
from alembic import command
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
from sqlalchemy.engine import Connection
|
||||
|
||||
|
||||
_ALEMBIC_DIR = os.path.join(os.path.dirname(__file__), 'alembic')
|
||||
|
||||
|
||||
def _build_config(connection: Connection) -> Config:
|
||||
"""Build an Alembic Config with sync connection attached."""
|
||||
cfg = Config()
|
||||
cfg.set_main_option('script_location', _ALEMBIC_DIR)
|
||||
cfg.attributes['connection'] = connection
|
||||
return cfg
|
||||
|
||||
|
||||
def _do_upgrade(connection: Connection, revision: str = 'head') -> None:
|
||||
"""Synchronous upgrade — runs inside run_sync."""
|
||||
cfg = _build_config(connection)
|
||||
command.upgrade(cfg, revision)
|
||||
|
||||
|
||||
def _do_stamp(connection: Connection, revision: str = 'head') -> None:
|
||||
"""Synchronous stamp — runs inside run_sync."""
|
||||
cfg = _build_config(connection)
|
||||
command.stamp(cfg, revision)
|
||||
|
||||
|
||||
def _do_get_current(connection: Connection) -> str | None:
|
||||
"""Get current alembic revision synchronously."""
|
||||
ctx = MigrationContext.configure(connection)
|
||||
return ctx.get_current_revision()
|
||||
|
||||
|
||||
def _do_autogenerate(connection: Connection, message: str = 'auto migration') -> None:
|
||||
"""Synchronous autogenerate — runs inside run_sync."""
|
||||
cfg = _build_config(connection)
|
||||
command.revision(cfg, message=message, autogenerate=True)
|
||||
|
||||
|
||||
async def run_alembic_upgrade(async_engine: AsyncEngine, revision: str = 'head') -> None:
|
||||
"""Run Alembic upgrade to the given revision."""
|
||||
async with async_engine.connect() as conn:
|
||||
await conn.run_sync(_do_upgrade, revision)
|
||||
await conn.commit()
|
||||
|
||||
|
||||
async def run_alembic_stamp(async_engine: AsyncEngine, revision: str = 'head') -> None:
|
||||
"""Stamp the database with a revision without running migrations."""
|
||||
async with async_engine.connect() as conn:
|
||||
await conn.run_sync(_do_stamp, revision)
|
||||
await conn.commit()
|
||||
|
||||
|
||||
async def get_alembic_current(async_engine: AsyncEngine) -> str | None:
|
||||
"""Get current alembic revision, or None if not stamped."""
|
||||
async with async_engine.connect() as conn:
|
||||
return await conn.run_sync(_do_get_current)
|
||||
|
||||
|
||||
async def run_alembic_autogenerate(async_engine: AsyncEngine, message: str = 'auto migration') -> None:
|
||||
"""Compare ORM models against DB schema and generate a migration script."""
|
||||
async with async_engine.connect() as conn:
|
||||
await conn.run_sync(_do_autogenerate, message)
|
||||
|
||||
|
||||
# CLI entrypoint: python -m langbot.pkg.persistence.alembic_runner <command> [args]
|
||||
if __name__ == '__main__':
|
||||
import sys
|
||||
import asyncio
|
||||
|
||||
def _get_engine():
|
||||
"""Create engine from data/config.yaml or default SQLite."""
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
try:
|
||||
import yaml
|
||||
|
||||
with open('data/config.yaml') as f:
|
||||
config = yaml.safe_load(f)
|
||||
db_cfg = config.get('database', {})
|
||||
db_type = db_cfg.get('use', 'sqlite')
|
||||
if db_type == 'postgresql':
|
||||
pg = db_cfg.get('postgresql', {})
|
||||
url = (
|
||||
f'postgresql+asyncpg://{pg.get("user", "postgres")}:{pg.get("password", "postgres")}'
|
||||
f'@{pg.get("host", "127.0.0.1")}:{pg.get("port", 5432)}/{pg.get("database", "postgres")}'
|
||||
)
|
||||
else:
|
||||
path = db_cfg.get('sqlite', {}).get('path', 'data/langbot.db')
|
||||
url = f'sqlite+aiosqlite:///{path}'
|
||||
except Exception:
|
||||
url = 'sqlite+aiosqlite:///data/langbot.db'
|
||||
|
||||
return create_async_engine(url)
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 2:
|
||||
print('Usage: python -m langbot.pkg.persistence.alembic_runner <command> [args]')
|
||||
print('Commands:')
|
||||
print(' autogenerate "message" — Generate migration from ORM model diff')
|
||||
print(' upgrade [revision] — Upgrade database (default: head)')
|
||||
print(' stamp [revision] — Stamp revision without running (default: head)')
|
||||
print(' current — Show current revision')
|
||||
sys.exit(1)
|
||||
|
||||
cmd = sys.argv[1]
|
||||
engine = _get_engine()
|
||||
|
||||
if cmd == 'autogenerate':
|
||||
msg = sys.argv[2] if len(sys.argv) > 2 else 'auto migration'
|
||||
asyncio.run(run_alembic_autogenerate(engine, msg))
|
||||
print(f'Migration generated: {msg}')
|
||||
elif cmd == 'upgrade':
|
||||
rev = sys.argv[2] if len(sys.argv) > 2 else 'head'
|
||||
asyncio.run(run_alembic_upgrade(engine, rev))
|
||||
print(f'Upgraded to: {rev}')
|
||||
elif cmd == 'stamp':
|
||||
rev = sys.argv[2] if len(sys.argv) > 2 else 'head'
|
||||
asyncio.run(run_alembic_stamp(engine, rev))
|
||||
print(f'Stamped: {rev}')
|
||||
elif cmd == 'current':
|
||||
rev = asyncio.run(get_alembic_current(engine))
|
||||
print(f'Current revision: {rev}')
|
||||
else:
|
||||
print(f'Unknown command: {cmd}')
|
||||
sys.exit(1)
|
||||
|
||||
main()
|
||||
@@ -76,9 +76,6 @@ class PersistenceManager:
|
||||
|
||||
self.ap.logger.info(f'Successfully upgraded database to version {last_migration_number}.')
|
||||
|
||||
# Run Alembic migrations (new migration system)
|
||||
await self._run_alembic_migrations()
|
||||
|
||||
await self.write_space_model_providers()
|
||||
|
||||
async def create_tables(self):
|
||||
@@ -138,28 +135,6 @@ class PersistenceManager:
|
||||
|
||||
# =================================
|
||||
|
||||
async def _run_alembic_migrations(self):
|
||||
"""Run Alembic-based migrations after legacy migrations complete."""
|
||||
from . import alembic_runner
|
||||
|
||||
engine = self.get_db_engine()
|
||||
|
||||
try:
|
||||
current_rev = await alembic_runner.get_alembic_current(engine)
|
||||
|
||||
if current_rev is None:
|
||||
# First time: stamp baseline so Alembic knows existing schema is up-to-date
|
||||
self.ap.logger.info('Alembic: no revision found, stamping baseline...')
|
||||
await alembic_runner.run_alembic_stamp(engine, '0001_baseline')
|
||||
current_rev = '0001_baseline'
|
||||
|
||||
# Upgrade to head
|
||||
await alembic_runner.run_alembic_upgrade(engine, 'head')
|
||||
self.ap.logger.info('Alembic migrations completed.')
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Alembic migration failed: {e}', exc_info=True)
|
||||
raise
|
||||
|
||||
async def execute_async(self, *args, **kwargs) -> sqlalchemy.engine.cursor.CursorResult:
|
||||
async with self.get_db_engine().connect() as conn:
|
||||
result = await conn.execute(*args, **kwargs)
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
import sqlalchemy
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class(26)
|
||||
class DBMigrateHumanTakeoverSessions(migration.DBMigration):
|
||||
"""Create human_takeover_sessions table for human operator takeover support"""
|
||||
|
||||
async def upgrade(self):
|
||||
sql_text = sqlalchemy.text("""
|
||||
CREATE TABLE IF NOT EXISTS human_takeover_sessions (
|
||||
id VARCHAR(255) PRIMARY KEY,
|
||||
session_id VARCHAR(255) NOT NULL UNIQUE,
|
||||
bot_uuid VARCHAR(255) NOT NULL,
|
||||
status VARCHAR(50) NOT NULL DEFAULT 'active',
|
||||
taken_by VARCHAR(255),
|
||||
taken_at DATETIME NOT NULL,
|
||||
released_at DATETIME,
|
||||
platform VARCHAR(255),
|
||||
user_id VARCHAR(255),
|
||||
user_name VARCHAR(255)
|
||||
)
|
||||
""")
|
||||
await self.ap.persistence_mgr.execute_async(sql_text)
|
||||
|
||||
# Create indexes
|
||||
for idx_sql in [
|
||||
'CREATE INDEX IF NOT EXISTS idx_hts_session_id ON human_takeover_sessions (session_id)',
|
||||
'CREATE INDEX IF NOT EXISTS idx_hts_bot_uuid ON human_takeover_sessions (bot_uuid)',
|
||||
'CREATE INDEX IF NOT EXISTS idx_hts_status ON human_takeover_sessions (status)',
|
||||
]:
|
||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.text(idx_sql))
|
||||
|
||||
async def downgrade(self):
|
||||
sql_text = sqlalchemy.text('DROP TABLE IF EXISTS human_takeover_sessions')
|
||||
await self.ap.persistence_mgr.execute_async(sql_text)
|
||||
@@ -297,9 +297,6 @@ class RuntimePipeline:
|
||||
)
|
||||
# Store message_id in query variables for LLM call monitoring
|
||||
query.variables['_monitoring_message_id'] = message_id
|
||||
# Notify adapter so it can map platform-specific IDs to monitoring message ID
|
||||
if hasattr(query.adapter, 'on_monitoring_message_created'):
|
||||
await query.adapter.on_monitoring_message_created(query, message_id)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Failed to record query start: {e}')
|
||||
|
||||
|
||||
@@ -75,27 +75,6 @@ class PreProcessor(stage.PipelineStage):
|
||||
query.bot_uuid,
|
||||
)
|
||||
|
||||
# Expire externally managed conversation ids after the conversation has
|
||||
# been idle for longer than the configured conversation expire time.
|
||||
# The idle window is measured from the last preprocess/update time, not
|
||||
# from the conversation creation time.
|
||||
conversation_expire_time = query.pipeline_config.get('ai', {}).get('runner', {}).get('expire-time', None)
|
||||
now = datetime.datetime.now()
|
||||
if conversation_expire_time is not None and conversation_expire_time > 0:
|
||||
last_update_time = getattr(conversation, 'update_time', None) or getattr(conversation, 'create_time', None)
|
||||
if last_update_time is not None:
|
||||
conversation_idle_time = now.timestamp() - last_update_time.timestamp()
|
||||
if conversation_idle_time > conversation_expire_time:
|
||||
self.ap.logger.info(
|
||||
f'Conversation({query.query_id}) is expired (idle: {conversation_idle_time}s), create new conversation'
|
||||
)
|
||||
conversation.uuid = None
|
||||
|
||||
# Treat every preprocess pass as a conversation activity update. This
|
||||
# makes future expiry checks use the latest incoming message/preprocess
|
||||
# time instead of the first message/creation time.
|
||||
conversation.update_time = now
|
||||
|
||||
# 设置query
|
||||
query.session = session
|
||||
query.prompt = conversation.prompt.copy()
|
||||
@@ -181,10 +160,8 @@ class PreProcessor(stage.PipelineStage):
|
||||
elif me.url:
|
||||
content_list.append(provider_message.ContentElement.from_file_url(me.url, 'voice'))
|
||||
elif isinstance(me, platform_message.File):
|
||||
if me.base64:
|
||||
content_list.append(provider_message.ContentElement.from_file_base64(me.base64, me.name))
|
||||
elif me.url:
|
||||
content_list.append(provider_message.ContentElement.from_file_url(me.url, me.name))
|
||||
# if me.url is not None:
|
||||
content_list.append(provider_message.ContentElement.from_file_url(me.url, me.name))
|
||||
elif isinstance(me, platform_message.Quote) and quote_msg:
|
||||
for msg in me.origin:
|
||||
if isinstance(msg, platform_message.Plain):
|
||||
@@ -195,18 +172,6 @@ class PreProcessor(stage.PipelineStage):
|
||||
):
|
||||
if msg.base64 is not None:
|
||||
content_list.append(provider_message.ContentElement.from_image_base64(msg.base64))
|
||||
elif isinstance(msg, platform_message.File):
|
||||
if msg.base64:
|
||||
content_list.append(provider_message.ContentElement.from_file_base64(msg.base64, msg.name))
|
||||
elif msg.url:
|
||||
content_list.append(provider_message.ContentElement.from_file_url(msg.url, msg.name))
|
||||
elif isinstance(msg, platform_message.Voice):
|
||||
if msg.base64:
|
||||
content_list.append(
|
||||
provider_message.ContentElement.from_file_base64(msg.base64, 'voice.silk')
|
||||
)
|
||||
elif msg.url:
|
||||
content_list.append(provider_message.ContentElement.from_file_url(msg.url, 'voice'))
|
||||
|
||||
query.variables['user_message_text'] = plain_text
|
||||
|
||||
|
||||
@@ -208,7 +208,6 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
'model_name': model_name,
|
||||
'version': constants.semantic_version,
|
||||
'instance_id': constants.instance_id,
|
||||
'edition': constants.edition,
|
||||
'pipeline_plugins': pipeline_plugins,
|
||||
'error': locals().get('error_info', None),
|
||||
'timestamp': datetime.utcnow().isoformat(),
|
||||
|
||||
@@ -220,6 +220,47 @@ class RuntimeBot:
|
||||
|
||||
# Only add to query pool if no webhook requested to skip pipeline
|
||||
if not skip_pipeline:
|
||||
# Check if session is under human takeover
|
||||
person_session_id = f'person_{event.sender.id}'
|
||||
if (
|
||||
hasattr(self.ap, 'human_takeover_service')
|
||||
and self.ap.human_takeover_service
|
||||
and self.ap.human_takeover_service.is_taken_over(person_session_id)
|
||||
):
|
||||
# Session is taken over: record message to monitoring then stop
|
||||
await self.logger.info(
|
||||
f'Person message intercepted by human takeover for session {person_session_id}'
|
||||
)
|
||||
try:
|
||||
if hasattr(event.message_chain, 'model_dump'):
|
||||
msg_content = json.dumps(event.message_chain.model_dump(), ensure_ascii=False)
|
||||
else:
|
||||
msg_content = str(event.message_chain)
|
||||
|
||||
sender_name = None
|
||||
if hasattr(event, 'sender') and hasattr(event.sender, 'nickname'):
|
||||
sender_name = event.sender.nickname
|
||||
|
||||
await self.ap.monitoring_service.record_message(
|
||||
bot_id=self.bot_entity.uuid,
|
||||
bot_name=self.bot_entity.name or self.bot_entity.uuid,
|
||||
pipeline_id='__human_takeover__',
|
||||
pipeline_name='Human Takeover',
|
||||
message_content=msg_content,
|
||||
session_id=person_session_id,
|
||||
status='success',
|
||||
level='info',
|
||||
platform=adapter.__class__.__name__,
|
||||
user_id=str(event.sender.id),
|
||||
user_name=sender_name,
|
||||
role='user',
|
||||
)
|
||||
|
||||
await self.ap.monitoring_service.update_session_activity(person_session_id)
|
||||
except Exception as e:
|
||||
await self.logger.error(f'Failed to record takeover message: {e}')
|
||||
return
|
||||
|
||||
launcher_id = event.sender.id
|
||||
|
||||
if hasattr(adapter, 'get_launcher_id'):
|
||||
@@ -281,6 +322,50 @@ class RuntimeBot:
|
||||
|
||||
# Only add to query pool if no webhook requested to skip pipeline
|
||||
if not skip_pipeline:
|
||||
# Check if session is under human takeover
|
||||
group_session_id = f'group_{event.group.id}'
|
||||
if (
|
||||
hasattr(self.ap, 'human_takeover_service')
|
||||
and self.ap.human_takeover_service
|
||||
and self.ap.human_takeover_service.is_taken_over(group_session_id)
|
||||
):
|
||||
# Session is taken over: record message to monitoring then stop
|
||||
await self.logger.info(
|
||||
f'Group message intercepted by human takeover for session {group_session_id}'
|
||||
)
|
||||
try:
|
||||
if hasattr(event.message_chain, 'model_dump'):
|
||||
msg_content = json.dumps(event.message_chain.model_dump(), ensure_ascii=False)
|
||||
else:
|
||||
msg_content = str(event.message_chain)
|
||||
|
||||
sender_name = None
|
||||
if hasattr(event, 'sender'):
|
||||
if hasattr(event.sender, 'member_name'):
|
||||
sender_name = event.sender.member_name
|
||||
elif hasattr(event.sender, 'nickname'):
|
||||
sender_name = event.sender.nickname
|
||||
|
||||
await self.ap.monitoring_service.record_message(
|
||||
bot_id=self.bot_entity.uuid,
|
||||
bot_name=self.bot_entity.name or self.bot_entity.uuid,
|
||||
pipeline_id='__human_takeover__',
|
||||
pipeline_name='Human Takeover',
|
||||
message_content=msg_content,
|
||||
session_id=group_session_id,
|
||||
status='success',
|
||||
level='info',
|
||||
platform=adapter.__class__.__name__,
|
||||
user_id=str(event.sender.id),
|
||||
user_name=sender_name,
|
||||
role='user',
|
||||
)
|
||||
|
||||
await self.ap.monitoring_service.update_session_activity(group_session_id)
|
||||
except Exception as e:
|
||||
await self.logger.error(f'Failed to record takeover message: {e}')
|
||||
return
|
||||
|
||||
launcher_id = event.group.id
|
||||
|
||||
if hasattr(adapter, 'get_launcher_id'):
|
||||
@@ -523,7 +608,7 @@ class PlatformManager:
|
||||
return None
|
||||
|
||||
async def remove_bot(self, bot_uuid: str):
|
||||
for bot in self.bots[:]:
|
||||
for bot in self.bots:
|
||||
if bot.bot_entity.uuid == bot_uuid:
|
||||
if bot.enable:
|
||||
await bot.shutdown()
|
||||
|
||||
@@ -71,8 +71,7 @@ class DingTalkMessageConverter(abstract_platform_adapter.AbstractMessageConverte
|
||||
yiri_msg_list.append(platform_message.Image(base64=element['Picture']))
|
||||
else:
|
||||
# 回退到原有简单逻辑
|
||||
# 对于音频消息,content 来自 recognition 转写文字,在下方音频处理块中统一处理
|
||||
if event.content and event.type != 'audio':
|
||||
if event.content:
|
||||
text_content = event.content.replace('@' + bot_name, '')
|
||||
yiri_msg_list.append(platform_message.Plain(text=text_content))
|
||||
if event.picture:
|
||||
@@ -82,38 +81,7 @@ class DingTalkMessageConverter(abstract_platform_adapter.AbstractMessageConverte
|
||||
if event.file:
|
||||
yiri_msg_list.append(platform_message.File(url=event.file, name=event.name))
|
||||
if event.audio:
|
||||
# 优先使用钉钉自带的语音转写文字(recognition字段)
|
||||
if event.content and event.type == 'audio':
|
||||
yiri_msg_list.append(platform_message.Plain(text=event.content))
|
||||
else:
|
||||
yiri_msg_list.append(platform_message.Voice(base64=event.audio))
|
||||
|
||||
# Handle quoted/replied message - extract content as top-level components
|
||||
# so that plugins like FileReader can process them the same way as direct messages
|
||||
if event.quoted_message:
|
||||
quote_info = event.quoted_message
|
||||
msg_type = quote_info.get('msg_type', '')
|
||||
|
||||
# Process quoted file - add as top-level File component (same as private chat)
|
||||
if msg_type == 'file' and quote_info.get('file_url'):
|
||||
file_name = quote_info.get('file_name', 'file')
|
||||
yiri_msg_list.append(platform_message.File(url=quote_info['file_url'], name=file_name))
|
||||
|
||||
# Process quoted image - add as top-level Image component
|
||||
elif msg_type == 'picture' and quote_info.get('picture'):
|
||||
yiri_msg_list.append(platform_message.Image(base64=quote_info['picture']))
|
||||
|
||||
# Process quoted audio - add as top-level Voice component
|
||||
elif msg_type == 'audio' and quote_info.get('audio'):
|
||||
yiri_msg_list.append(platform_message.Voice(base64=quote_info['audio']))
|
||||
|
||||
# Process quoted text - add as Plain text with context prefix
|
||||
elif msg_type == 'text' and quote_info.get('content'):
|
||||
yiri_msg_list.append(platform_message.Plain(text=f'[引用消息] {quote_info["content"]}'))
|
||||
|
||||
# Process quoted rich text - add as Plain text with context prefix
|
||||
elif msg_type == 'richText' and quote_info.get('content'):
|
||||
yiri_msg_list.append(platform_message.Plain(text=f'[引用消息] {quote_info["content"]}'))
|
||||
yiri_msg_list.append(platform_message.Voice(base64=event.audio))
|
||||
|
||||
chain = platform_message.MessageChain(yiri_msg_list)
|
||||
|
||||
|
||||
@@ -19,18 +19,6 @@ spec:
|
||||
en: https://link.langbot.app/en/platforms/dingtalk
|
||||
ja: https://link.langbot.app/ja/platforms/dingtalk
|
||||
config:
|
||||
- name: one-click-create
|
||||
label:
|
||||
en_US: One-Click Create App
|
||||
zh_Hans: 一键创建应用
|
||||
zh_Hant: 一鍵建立應用
|
||||
description:
|
||||
en_US: "Scan QR code with DingTalk to automatically create an app and fill in credentials. Note: Robot Code cannot be obtained automatically, you need to copy it from the DingTalk Developer Backend manually."
|
||||
zh_Hans: "使用钉钉扫码自动创建应用并填写凭据。注意:机器人代码无法自动获取,需前往钉钉开发者后台手动复制。"
|
||||
zh_Hant: "使用釘釘掃碼自動建立應用並填寫憑證。注意:機器人代碼無法自動取得,需前往釘釘開發者後台手動複製。"
|
||||
type: qr-code-login
|
||||
login_platform: dingtalk
|
||||
required: false
|
||||
- name: client_id
|
||||
label:
|
||||
en_US: Client ID
|
||||
@@ -52,10 +40,6 @@ spec:
|
||||
en_US: Robot Code
|
||||
zh_Hans: 机器人代码
|
||||
zh_Hant: 機器人代碼
|
||||
description:
|
||||
en_US: "Required for image recognition, file upload and other features. Get it from DingTalk Developer Backend > Robot Configuration."
|
||||
zh_Hans: "识图、上传文件等功能必填。请前往钉钉开发者后台 > 机器人配置中获取。"
|
||||
zh_Hant: "識圖、上傳檔案等功能必填。請前往釘釘開發者後台 > 機器人設定中取得。"
|
||||
type: string
|
||||
required: true
|
||||
default: ""
|
||||
|
||||
@@ -709,29 +709,21 @@ class LarkEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
message_chain = await LarkMessageConverter.target2yiri(event.event.message, api_client)
|
||||
|
||||
# Check for quote/reply message
|
||||
# Extract files/images/voice from quote and add them as top-level components
|
||||
# so that plugins like FileReader can process them the same way as direct messages
|
||||
quote_message_id = LarkEventConverter._extract_quote_message_id(event.event.message)
|
||||
if quote_message_id:
|
||||
quote_chain = await LarkEventConverter._fetch_quoted_message(quote_message_id, api_client)
|
||||
if quote_chain:
|
||||
# Filter out Source component from quoted chain, keep only content
|
||||
quote_components = [comp for comp in quote_chain if not isinstance(comp, platform_message.Source)]
|
||||
|
||||
# Add quoted content as top-level components instead of wrapping in Quote
|
||||
for comp in quote_components:
|
||||
if isinstance(comp, platform_message.File):
|
||||
# Add file as top-level component (same as direct message)
|
||||
message_chain.append(comp)
|
||||
elif isinstance(comp, platform_message.Image):
|
||||
# Add image as top-level component
|
||||
message_chain.append(comp)
|
||||
elif isinstance(comp, platform_message.Voice):
|
||||
# Add voice as top-level component
|
||||
message_chain.append(comp)
|
||||
elif isinstance(comp, platform_message.Plain):
|
||||
# Add text with context prefix
|
||||
message_chain.append(platform_message.Plain(text=f'[引用消息] {comp.text}'))
|
||||
quote_origin = platform_message.MessageChain(
|
||||
[comp for comp in quote_chain if not isinstance(comp, platform_message.Source)]
|
||||
)
|
||||
if quote_origin:
|
||||
message_chain.append(
|
||||
platform_message.Quote(
|
||||
message_id=quote_message_id,
|
||||
origin=quote_origin,
|
||||
)
|
||||
)
|
||||
|
||||
if event.event.message.chat_type == 'p2p':
|
||||
return platform_events.FriendMessage(
|
||||
@@ -787,13 +779,6 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
|
||||
card_id_dict: dict[str, str] # 消息id到卡片id的映射,便于创建卡片后的发送消息到指定卡片
|
||||
|
||||
# Monitoring message ID mapping for feedback correlation
|
||||
# Temp: user Lark message ID → monitoring_message_id (populated by on_monitoring_message_created, consumed by create_message_card)
|
||||
pending_monitoring_msg: dict[str, str]
|
||||
# Final: reply Lark message ID → (monitoring_message_id, timestamp) (used by feedback callbacks)
|
||||
reply_to_monitoring_msg: dict[str, tuple[str, float]]
|
||||
_MONITORING_MAPPING_TTL = 600 # 10 minutes
|
||||
|
||||
seq: int # 用于在发送卡片消息中识别消息顺序,直接以seq作为标识
|
||||
bot_uuid: str = None # 机器人UUID
|
||||
app_ticket: str = None # 商店应用用到
|
||||
@@ -840,11 +825,6 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
else:
|
||||
session_id = None
|
||||
|
||||
# Resolve monitoring message ID from reply message mapping
|
||||
monitoring_msg_id = None
|
||||
if open_message_id and open_message_id in self.reply_to_monitoring_msg:
|
||||
monitoring_msg_id = self.reply_to_monitoring_msg[open_message_id][0]
|
||||
|
||||
feedback_event = platform_events.FeedbackEvent(
|
||||
feedback_id=getattr(event.header, 'event_id', str(uuid.uuid4())),
|
||||
feedback_type=feedback_type,
|
||||
@@ -852,7 +832,6 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
message_id=open_message_id,
|
||||
stream_id=monitoring_msg_id,
|
||||
source_platform_object=event,
|
||||
)
|
||||
|
||||
@@ -891,8 +870,6 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
logger=logger,
|
||||
lark_tenant_key=config.get('lark_tenant_key', ''),
|
||||
card_id_dict={},
|
||||
pending_monitoring_msg={},
|
||||
reply_to_monitoring_msg={},
|
||||
seq=1,
|
||||
listeners={},
|
||||
quart_app=quart_app,
|
||||
@@ -1025,90 +1002,7 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
return api_client
|
||||
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
text_elements, media_items = await self.message_converter.yiri2target(message, self.api_client)
|
||||
|
||||
# Map standard target_type to Feishu receive_id_type
|
||||
if target_type == 'person':
|
||||
receive_id_type = 'open_id'
|
||||
elif target_type == 'group':
|
||||
receive_id_type = 'chat_id'
|
||||
else:
|
||||
receive_id_type = target_type
|
||||
|
||||
# Send text message if there are text elements
|
||||
if text_elements:
|
||||
needs_post = any(ele['tag'] == 'at' for paragraph in text_elements for ele in paragraph)
|
||||
|
||||
if needs_post:
|
||||
msg_type = 'post'
|
||||
final_content = json.dumps(
|
||||
{
|
||||
'zh_Hans': {
|
||||
'title': '',
|
||||
'content': text_elements,
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
msg_type = 'text'
|
||||
parts = []
|
||||
for paragraph in text_elements:
|
||||
para_text = ''.join(ele.get('text', '') for ele in paragraph)
|
||||
if para_text:
|
||||
parts.append(para_text)
|
||||
final_content = json.dumps({'text': '\n\n'.join(parts)})
|
||||
|
||||
request: CreateMessageRequest = (
|
||||
CreateMessageRequest.builder()
|
||||
.receive_id_type(receive_id_type)
|
||||
.request_body(
|
||||
CreateMessageRequestBody.builder()
|
||||
.receive_id(target_id)
|
||||
.content(final_content)
|
||||
.msg_type(msg_type)
|
||||
.uuid(str(uuid.uuid4()))
|
||||
.build()
|
||||
)
|
||||
.build()
|
||||
)
|
||||
|
||||
app_access_token = self.get_app_access_token()
|
||||
req_opt: RequestOption = (
|
||||
RequestOption.builder().app_ticket(self.app_ticket).app_access_token(app_access_token).build()
|
||||
)
|
||||
response: CreateMessageResponse = self.api_client.im.v1.message.create(request, req_opt)
|
||||
|
||||
if not response.success():
|
||||
raise Exception(
|
||||
f'client.im.v1.message.create failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}'
|
||||
)
|
||||
|
||||
# Send media messages separately (image, audio, file, etc.)
|
||||
for media in media_items:
|
||||
request: CreateMessageRequest = (
|
||||
CreateMessageRequest.builder()
|
||||
.receive_id_type(receive_id_type)
|
||||
.request_body(
|
||||
CreateMessageRequestBody.builder()
|
||||
.receive_id(target_id)
|
||||
.content(json.dumps(media['content']))
|
||||
.msg_type(media['msg_type'])
|
||||
.uuid(str(uuid.uuid4()))
|
||||
.build()
|
||||
)
|
||||
.build()
|
||||
)
|
||||
|
||||
app_access_token = self.get_app_access_token()
|
||||
req_opt: RequestOption = (
|
||||
RequestOption.builder().app_ticket(self.app_ticket).app_access_token(app_access_token).build()
|
||||
)
|
||||
response: CreateMessageResponse = self.api_client.im.v1.message.create(request, req_opt)
|
||||
|
||||
if not response.success():
|
||||
raise Exception(
|
||||
f'client.im.v1.message.create ({media["msg_type"]}) failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}'
|
||||
)
|
||||
pass
|
||||
|
||||
async def is_stream_output_supported(self) -> bool:
|
||||
is_stream = False
|
||||
@@ -1116,22 +1010,6 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
is_stream = True
|
||||
return is_stream
|
||||
|
||||
async def on_monitoring_message_created(self, query, monitoring_message_id: str):
|
||||
"""Called by pipeline after monitoring message is created, to map user message ID to monitoring message ID."""
|
||||
try:
|
||||
user_msg_id = query.message_event.message_chain.message_id
|
||||
if user_msg_id:
|
||||
self.pending_monitoring_msg[user_msg_id] = monitoring_message_id
|
||||
except Exception as e:
|
||||
await self.logger.debug(f'Failed to map message to monitoring message: {e}')
|
||||
|
||||
def _cleanup_monitoring_mapping(self):
|
||||
"""Remove entries older than TTL from the reply-to-monitoring mapping."""
|
||||
now = time.time()
|
||||
expired = [k for k, (_, ts) in self.reply_to_monitoring_msg.items() if now - ts > self._MONITORING_MAPPING_TTL]
|
||||
for k in expired:
|
||||
del self.reply_to_monitoring_msg[k]
|
||||
|
||||
async def create_card_id(self, message_id):
|
||||
try:
|
||||
# self.logger.debug('飞书支持stream输出,创建卡片......')
|
||||
@@ -1371,18 +1249,6 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
raise Exception(
|
||||
f'client.im.v1.message.reply failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}'
|
||||
)
|
||||
|
||||
# Transfer monitoring message mapping: user msg ID → reply msg ID
|
||||
try:
|
||||
user_msg_id = event.message_chain.message_id
|
||||
reply_msg_id = getattr(response.data, 'message_id', None)
|
||||
monitoring_msg_id = self.pending_monitoring_msg.pop(user_msg_id, None)
|
||||
if reply_msg_id and monitoring_msg_id:
|
||||
self.reply_to_monitoring_msg[reply_msg_id] = (monitoring_msg_id, time.time())
|
||||
self._cleanup_monitoring_mapping()
|
||||
except Exception as e:
|
||||
asyncio.create_task(self.logger.debug(f'Failed to transfer monitoring mapping in create_message_card: {e}'))
|
||||
|
||||
return True
|
||||
|
||||
async def reply_message(
|
||||
@@ -1693,11 +1559,6 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
else:
|
||||
session_id = None
|
||||
|
||||
# Resolve monitoring message ID from reply message mapping
|
||||
monitoring_msg_id = None
|
||||
if open_message_id and open_message_id in self.reply_to_monitoring_msg:
|
||||
monitoring_msg_id = self.reply_to_monitoring_msg[open_message_id][0]
|
||||
|
||||
feedback_event = platform_events.FeedbackEvent(
|
||||
feedback_id=data.get('header', {}).get('event_id', str(uuid.uuid4())),
|
||||
feedback_type=feedback_type,
|
||||
@@ -1705,7 +1566,6 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
message_id=open_message_id,
|
||||
stream_id=monitoring_msg_id,
|
||||
source_platform_object=data,
|
||||
)
|
||||
|
||||
|
||||
@@ -23,20 +23,6 @@ spec:
|
||||
en: https://link.langbot.app/en/platforms/lark
|
||||
ja: https://link.langbot.app/ja/platforms/lark
|
||||
config:
|
||||
- name: one-click-create
|
||||
label:
|
||||
en_US: One-Click Create App
|
||||
zh_Hans: 一键创建应用
|
||||
zh_Hant: 一鍵建立應用
|
||||
ja_JP: ワンクリックでアプリ作成
|
||||
description:
|
||||
en_US: Scan QR code to automatically create a Feishu app and fill in credentials
|
||||
zh_Hans: 扫码自动创建飞书应用并填写凭据
|
||||
zh_Hant: 掃碼自動建立飛書應用並填寫憑證
|
||||
ja_JP: QRコードをスキャンしてFeishuアプリを自動作成し、認証情報を入力
|
||||
type: qr-code-login
|
||||
login_platform: feishu
|
||||
required: false
|
||||
- name: app_id
|
||||
label:
|
||||
en_US: App ID
|
||||
|
||||
|
Before Width: | Height: | Size: 3.4 KiB |
@@ -1,693 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import asyncio
|
||||
import traceback
|
||||
import base64
|
||||
import json
|
||||
|
||||
import nio
|
||||
|
||||
from langbot.pkg.utils import httpclient
|
||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
|
||||
|
||||
|
||||
class MatrixMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
@staticmethod
|
||||
async def yiri2target(message_chain: platform_message.MessageChain, client: nio.AsyncClient) -> list[dict]:
|
||||
components = []
|
||||
for component in message_chain:
|
||||
if isinstance(component, platform_message.Plain):
|
||||
components.append({'type': 'text', 'text': component.text})
|
||||
elif isinstance(component, platform_message.Image):
|
||||
image_bytes = None
|
||||
if component.base64:
|
||||
b64_data = component.base64
|
||||
if ';base64,' in b64_data:
|
||||
b64_data = b64_data.split(';base64,', 1)[1]
|
||||
image_bytes = base64.b64decode(b64_data)
|
||||
elif component.url:
|
||||
session = httpclient.get_session()
|
||||
async with session.get(component.url) as response:
|
||||
image_bytes = await response.read()
|
||||
elif component.path:
|
||||
with open(component.path, 'rb') as f:
|
||||
image_bytes = f.read()
|
||||
if image_bytes:
|
||||
resp = await client.upload(image_bytes, content_type='image/png')
|
||||
if isinstance(resp, nio.UploadResponse):
|
||||
components.append({'type': 'image', 'mxc_url': resp.content_uri})
|
||||
elif isinstance(component, platform_message.File):
|
||||
file_bytes = None
|
||||
if component.base64:
|
||||
b64_data = component.base64
|
||||
if ';base64,' in b64_data:
|
||||
b64_data = b64_data.split(';base64,', 1)[1]
|
||||
file_bytes = base64.b64decode(b64_data)
|
||||
elif component.url:
|
||||
session = httpclient.get_session()
|
||||
async with session.get(component.url) as response:
|
||||
file_bytes = await response.read()
|
||||
elif component.path:
|
||||
with open(component.path, 'rb') as f:
|
||||
file_bytes = f.read()
|
||||
if file_bytes:
|
||||
file_name = getattr(component, 'name', None) or 'file'
|
||||
resp = await client.upload(file_bytes, content_type='application/octet-stream', filename=file_name)
|
||||
if isinstance(resp, nio.UploadResponse):
|
||||
components.append(
|
||||
{
|
||||
'type': 'file',
|
||||
'mxc_url': resp.content_uri,
|
||||
'filename': file_name,
|
||||
'size': len(file_bytes),
|
||||
}
|
||||
)
|
||||
elif isinstance(component, platform_message.Forward):
|
||||
for node in component.node_list:
|
||||
components.extend(await MatrixMessageConverter.yiri2target(node.message_chain, client))
|
||||
return components
|
||||
|
||||
@staticmethod
|
||||
async def target2yiri(event: nio.RoomMessageText | nio.RoomMessageImage, client: nio.AsyncClient, bot_user_id: str):
|
||||
message_components = []
|
||||
|
||||
if isinstance(event, nio.RoomMessageText):
|
||||
text = event.body
|
||||
if bot_user_id and bot_user_id in text:
|
||||
message_components.append(platform_message.At(target=bot_user_id))
|
||||
text = text.replace(bot_user_id, '').strip()
|
||||
message_components.append(platform_message.Plain(text=text))
|
||||
|
||||
elif isinstance(event, nio.RoomMessageImage):
|
||||
mxc_url = event.url
|
||||
if mxc_url:
|
||||
resp = await client.download(mxc_url)
|
||||
if isinstance(resp, nio.DownloadResponse):
|
||||
b64 = base64.b64encode(resp.body).decode('utf-8')
|
||||
content_type = resp.content_type or 'image/png'
|
||||
message_components.append(platform_message.Image(base64=f'data:{content_type};base64,{b64}'))
|
||||
if event.body:
|
||||
message_components.append(platform_message.Plain(text=event.body))
|
||||
|
||||
return platform_message.MessageChain(message_components)
|
||||
|
||||
|
||||
class MatrixEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
@staticmethod
|
||||
async def yiri2target(event: platform_events.MessageEvent):
|
||||
return event.source_platform_object
|
||||
|
||||
@staticmethod
|
||||
async def target2yiri(
|
||||
event: nio.RoomMessageText | nio.RoomMessageImage,
|
||||
room: nio.MatrixRoom,
|
||||
client: nio.AsyncClient,
|
||||
bot_user_id: str,
|
||||
bridge_user_ids: list[str] | None = None,
|
||||
):
|
||||
lb_message = await MatrixMessageConverter.target2yiri(event, client, bot_user_id)
|
||||
|
||||
# Determine if this is a direct/private chat or a group chat.
|
||||
# Exclude bot itself and bridge bots, count remaining real users.
|
||||
exclude_ids = {bot_user_id}
|
||||
if bridge_user_ids:
|
||||
exclude_ids.update(bridge_user_ids)
|
||||
real_users = [uid for uid in room.users if uid not in exclude_ids]
|
||||
is_direct = len(real_users) <= 1
|
||||
|
||||
if is_direct:
|
||||
return platform_events.FriendMessage(
|
||||
sender=platform_entities.Friend(
|
||||
id=event.sender,
|
||||
nickname=room.user_name(event.sender) or event.sender,
|
||||
remark='',
|
||||
),
|
||||
message_chain=lb_message,
|
||||
time=event.server_timestamp / 1000.0,
|
||||
source_platform_object={'event': event, 'room': room},
|
||||
)
|
||||
else:
|
||||
return platform_events.GroupMessage(
|
||||
sender=platform_entities.GroupMember(
|
||||
id=event.sender,
|
||||
member_name=room.user_name(event.sender) or event.sender,
|
||||
permission=platform_entities.Permission.Member,
|
||||
group=platform_entities.Group(
|
||||
id=room.room_id,
|
||||
name=room.display_name or room.room_id,
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title='',
|
||||
),
|
||||
message_chain=lb_message,
|
||||
time=event.server_timestamp / 1000.0,
|
||||
source_platform_object={'event': event, 'room': room},
|
||||
)
|
||||
|
||||
|
||||
class BridgeState:
|
||||
"""Per-bridge runtime state."""
|
||||
|
||||
def __init__(self, user_id: str, login_command: str, logout_command: str, success_keyword: str, check_command: str):
|
||||
self.user_id = user_id
|
||||
self.login_command = login_command
|
||||
self.logout_command = logout_command
|
||||
self.success_keyword = success_keyword
|
||||
self.check_command = check_command or login_command
|
||||
self.logged_in = False
|
||||
self.dm_room_id: str | None = None
|
||||
self.login_task: asyncio.Task | None = None
|
||||
self.check_task: asyncio.Task | None = None
|
||||
self.check_responded = False
|
||||
|
||||
|
||||
class MatrixAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
client: typing.Any = None
|
||||
message_converter: MatrixMessageConverter = MatrixMessageConverter()
|
||||
event_converter: MatrixEventConverter = MatrixEventConverter()
|
||||
config: dict
|
||||
listeners: typing.Dict[typing.Type[platform_events.Event], typing.Callable] = {}
|
||||
_running: bool = False
|
||||
_initial_sync_done: bool = False
|
||||
_bridges: list = []
|
||||
|
||||
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger):
|
||||
homeserver_url = config.get('homeserver_url', '')
|
||||
access_token = config.get('access_token', '')
|
||||
user_id = config.get('user_id', '')
|
||||
|
||||
if not homeserver_url or not access_token or not user_id:
|
||||
raise ValueError('Matrix 机器人缺少必要配置项 (homeserver_url, user_id, access_token)')
|
||||
|
||||
client = nio.AsyncClient(homeserver_url, user_id)
|
||||
client.access_token = access_token
|
||||
client.user_id = user_id
|
||||
|
||||
super().__init__(
|
||||
config=config,
|
||||
logger=logger,
|
||||
bot_account_id=user_id,
|
||||
client=client,
|
||||
listeners={},
|
||||
)
|
||||
|
||||
# Parse bridges config AFTER super().__init__() to avoid Pydantic resetting _bridges
|
||||
self._bridges = []
|
||||
bridges_raw = config.get('bridges', '')
|
||||
if bridges_raw:
|
||||
if isinstance(bridges_raw, str):
|
||||
try:
|
||||
bridges_list = json.loads(bridges_raw)
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
raise ValueError(f'bridges 配置 JSON 解析失败: {e}\n原始值: {bridges_raw}')
|
||||
else:
|
||||
bridges_list = bridges_raw
|
||||
for b in bridges_list:
|
||||
if isinstance(b, dict) and b.get('user_id', '').strip():
|
||||
self._bridges.append(
|
||||
BridgeState(
|
||||
user_id=b['user_id'].strip(),
|
||||
login_command=b.get('login_command', '').strip(),
|
||||
logout_command=b.get('logout_command', '').strip(),
|
||||
success_keyword=b.get('success_keyword', 'Successfully logged in').strip(),
|
||||
check_command=b.get('check_command', '').strip(),
|
||||
)
|
||||
)
|
||||
# Backward compatibility: old single-bridge config
|
||||
if not self._bridges:
|
||||
old_user_id = config.get('bridge_user_id', '').strip()
|
||||
old_command = config.get('bridge_login_command', '').strip()
|
||||
old_keyword = config.get('bridge_login_success_keyword', 'Successfully logged in').strip()
|
||||
old_check = config.get('bridge_check_command', '').strip()
|
||||
old_logout = config.get('bridge_logout_command', '').strip()
|
||||
if old_user_id:
|
||||
self._bridges.append(
|
||||
BridgeState(
|
||||
user_id=old_user_id,
|
||||
login_command=old_command,
|
||||
logout_command=old_logout,
|
||||
success_keyword=old_keyword,
|
||||
check_command=old_check,
|
||||
)
|
||||
)
|
||||
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
components = await self.message_converter.yiri2target(message, self.client)
|
||||
for component in components:
|
||||
await self._send_component(target_id, component)
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
):
|
||||
source_obj = message_source.source_platform_object
|
||||
room_id = source_obj['room'].room_id
|
||||
components = await self.message_converter.yiri2target(message, self.client)
|
||||
|
||||
for component in components:
|
||||
if quote_origin:
|
||||
original_event = source_obj['event']
|
||||
await self._send_component(room_id, component, reply_to=original_event.event_id)
|
||||
else:
|
||||
await self._send_component(room_id, component)
|
||||
|
||||
async def _send_component(self, room_id: str, component: dict, reply_to: str | None = None):
|
||||
content = {}
|
||||
if component['type'] == 'text':
|
||||
content = {
|
||||
'msgtype': 'm.text',
|
||||
'body': component['text'],
|
||||
}
|
||||
elif component['type'] == 'image':
|
||||
content = {
|
||||
'msgtype': 'm.image',
|
||||
'body': 'image.png',
|
||||
'url': component['mxc_url'],
|
||||
}
|
||||
elif component['type'] == 'file':
|
||||
content = {
|
||||
'msgtype': 'm.file',
|
||||
'body': component.get('filename', 'file'),
|
||||
'url': component['mxc_url'],
|
||||
'info': {'size': component.get('size', 0)},
|
||||
}
|
||||
|
||||
if reply_to and content:
|
||||
content['m.relates_to'] = {
|
||||
'm.in_reply_to': {'event_id': reply_to},
|
||||
}
|
||||
|
||||
if content:
|
||||
await self.client.room_send(
|
||||
room_id=room_id,
|
||||
message_type='m.room.message',
|
||||
content=content,
|
||||
)
|
||||
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
self.listeners[event_type] = callback
|
||||
|
||||
async def run_async(self):
|
||||
self._running = True
|
||||
await self.logger.info('Matrix adapter starting...')
|
||||
|
||||
# Debug: log bridge parsing result
|
||||
bridges_raw = self.config.get('bridges', '')
|
||||
await self.logger.debug(f'bridges config raw: type={type(bridges_raw).__name__}, repr={repr(bridges_raw)}')
|
||||
await self.logger.debug(
|
||||
f'parsed _bridges count: {len(self._bridges)}, ids: {[b.user_id for b in self._bridges]}'
|
||||
)
|
||||
|
||||
# Collect all bridge bot user IDs for filtering
|
||||
_bridge_user_ids = [b.user_id for b in self._bridges]
|
||||
_bridge_user_id_set = set(_bridge_user_ids)
|
||||
|
||||
# Auto-join invited rooms
|
||||
async def on_invite(room: nio.MatrixRoom, event: nio.InviteMemberEvent):
|
||||
if event.membership == 'invite' and event.state_key == self.client.user_id:
|
||||
await self.client.join(room.room_id)
|
||||
await self.logger.debug(f'Auto-joined room: {room.display_name or room.room_id}')
|
||||
|
||||
self.client.add_event_callback(on_invite, nio.InviteMemberEvent)
|
||||
|
||||
# Handle text messages
|
||||
async def on_message(room: nio.MatrixRoom, event: nio.RoomMessageText):
|
||||
if not self._initial_sync_done:
|
||||
return
|
||||
if event.sender == self.client.user_id:
|
||||
return
|
||||
|
||||
# Admin commands (from any non-bridge user)
|
||||
if event.sender not in _bridge_user_id_set:
|
||||
body = (event.body or '').strip()
|
||||
if body == '!relogin':
|
||||
await self._handle_relogin_command(room.room_id)
|
||||
return
|
||||
if body == '!status':
|
||||
await self._handle_status_command(room.room_id)
|
||||
return
|
||||
|
||||
if event.sender in _bridge_user_id_set:
|
||||
return
|
||||
try:
|
||||
lb_event = await self.event_converter.target2yiri(
|
||||
event, room, self.client, self.bot_account_id, _bridge_user_ids
|
||||
)
|
||||
if type(lb_event) in self.listeners:
|
||||
result = self.listeners[type(lb_event)](lb_event, self)
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
except Exception:
|
||||
await self.logger.error(f'Error handling Matrix message: {traceback.format_exc()}')
|
||||
|
||||
self.client.add_event_callback(on_message, nio.RoomMessageText)
|
||||
|
||||
# Handle image messages
|
||||
async def on_image(room: nio.MatrixRoom, event: nio.RoomMessageImage):
|
||||
if not self._initial_sync_done:
|
||||
return
|
||||
if event.sender == self.client.user_id:
|
||||
return
|
||||
if event.sender in _bridge_user_id_set:
|
||||
return
|
||||
try:
|
||||
lb_event = await self.event_converter.target2yiri(
|
||||
event, room, self.client, self.bot_account_id, _bridge_user_ids
|
||||
)
|
||||
if type(lb_event) in self.listeners:
|
||||
result = self.listeners[type(lb_event)](lb_event, self)
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
except Exception:
|
||||
await self.logger.error(f'Error handling Matrix image: {traceback.format_exc()}')
|
||||
|
||||
self.client.add_event_callback(on_image, nio.RoomMessageImage)
|
||||
|
||||
# Set up bridge-specific callbacks for each bridge
|
||||
_disconnect_keywords = ['disconnected', 'logged out', 'connection lost', 'session expired', 'token expired']
|
||||
|
||||
for bridge in self._bridges:
|
||||
# Login success detection (notice)
|
||||
async def on_bridge_notice(room: nio.MatrixRoom, event: nio.RoomMessageNotice, _b=bridge):
|
||||
if not self._initial_sync_done:
|
||||
return
|
||||
if event.sender != _b.user_id:
|
||||
return
|
||||
_b.check_responded = True
|
||||
if _b.success_keyword in (event.body or ''):
|
||||
_b.logged_in = True
|
||||
await self.logger.info(f'[{_b.user_id}] Bridge login succeeded.')
|
||||
# Disconnect detection
|
||||
body_lower = (event.body or '').lower()
|
||||
for kw in _disconnect_keywords:
|
||||
if kw in body_lower and _b.logged_in:
|
||||
_b.logged_in = False
|
||||
await self.logger.info(f'[{_b.user_id}] Bridge 账号掉线 (检测到: "{kw}"), 将自动重新登录...')
|
||||
self._restart_bridge_login(_b)
|
||||
break
|
||||
|
||||
self.client.add_event_callback(on_bridge_notice, nio.RoomMessageNotice)
|
||||
|
||||
# Login success + disconnect detection (text)
|
||||
async def on_bridge_text(room: nio.MatrixRoom, event: nio.RoomMessageText, _b=bridge):
|
||||
if not self._initial_sync_done:
|
||||
return
|
||||
if event.sender != _b.user_id:
|
||||
return
|
||||
_b.check_responded = True
|
||||
if _b.success_keyword in (event.body or ''):
|
||||
_b.logged_in = True
|
||||
await self.logger.info(f'[{_b.user_id}] Bridge login succeeded.')
|
||||
body_lower = (event.body or '').lower()
|
||||
for kw in _disconnect_keywords:
|
||||
if kw in body_lower and _b.logged_in:
|
||||
_b.logged_in = False
|
||||
await self.logger.info(f'[{_b.user_id}] Bridge 账号掉线 (检测到: "{kw}"), 将自动重新登录...')
|
||||
self._restart_bridge_login(_b)
|
||||
break
|
||||
|
||||
self.client.add_event_callback(on_bridge_text, nio.RoomMessageText)
|
||||
|
||||
# QR code image forwarding
|
||||
async def on_bridge_image(room: nio.MatrixRoom, event: nio.RoomMessageImage, _b=bridge):
|
||||
if not self._initial_sync_done:
|
||||
return
|
||||
if event.sender != _b.user_id:
|
||||
return
|
||||
mxc_url = event.url
|
||||
if not mxc_url:
|
||||
return
|
||||
try:
|
||||
resp = await self.client.download(mxc_url)
|
||||
if isinstance(resp, nio.DownloadResponse):
|
||||
b64 = base64.b64encode(resp.body).decode('utf-8')
|
||||
content_type = resp.content_type or 'image/png'
|
||||
await self.logger.info(
|
||||
f'[{_b.user_id}] Bridge 发送了二维码,请扫码登录:',
|
||||
images=[platform_message.Image(base64=f'data:{content_type};base64,{b64}')],
|
||||
)
|
||||
except Exception:
|
||||
await self.logger.error(
|
||||
f'[{_b.user_id}] Failed to download bridge QR image: {traceback.format_exc()}'
|
||||
)
|
||||
|
||||
self.client.add_event_callback(on_bridge_image, nio.RoomMessageImage)
|
||||
|
||||
await self.logger.debug('Matrix adapter running, starting sync...')
|
||||
|
||||
# Initial sync to skip old messages
|
||||
resp = await self.client.sync(timeout=10000)
|
||||
if isinstance(resp, nio.SyncResponse):
|
||||
await self.logger.debug(f'Matrix initial sync done, next_batch: {resp.next_batch}')
|
||||
self._initial_sync_done = True
|
||||
|
||||
# Display account info
|
||||
display_name = self.client.user_id
|
||||
try:
|
||||
profile_resp = await self.client.get_displayname(self.client.user_id)
|
||||
if isinstance(profile_resp, nio.ProfileGetDisplayNameResponse) and profile_resp.displayname:
|
||||
display_name = profile_resp.displayname
|
||||
except Exception:
|
||||
pass
|
||||
joined_rooms = len(self.client.rooms)
|
||||
homeserver = self.config.get('homeserver_url', '')
|
||||
bridge_info = ''
|
||||
if self._bridges:
|
||||
bridge_names = ', '.join(b.user_id for b in self._bridges)
|
||||
bridge_info = f' | 桥接: [{bridge_names}]'
|
||||
await self.logger.info(
|
||||
f'Matrix 账号: {display_name} ({self.client.user_id}) | '
|
||||
f'服务器: {homeserver} | 已加入 {joined_rooms} 个房间{bridge_info}'
|
||||
)
|
||||
|
||||
# Start bridge login and status check tasks for each bridge
|
||||
for bridge in self._bridges:
|
||||
if bridge.login_command:
|
||||
await self.logger.info(
|
||||
f'[{bridge.user_id}] Bridge login enabled (命令: "{bridge.login_command}", '
|
||||
f'关键词: "{bridge.success_keyword}")'
|
||||
)
|
||||
bridge.login_task = asyncio.create_task(self._periodic_bridge_login(bridge))
|
||||
bridge.check_task = asyncio.create_task(self._periodic_bridge_check(bridge))
|
||||
else:
|
||||
await self.logger.debug(f'[{bridge.user_id}] Bridge login not configured (no login_command)')
|
||||
|
||||
# Main sync loop
|
||||
while self._running:
|
||||
try:
|
||||
await self.client.sync(timeout=30000)
|
||||
except Exception:
|
||||
await self.logger.error(f'Matrix sync error: {traceback.format_exc()}')
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def _periodic_bridge_login(self, bridge: BridgeState):
|
||||
"""Periodically send login command to a bridge bot until login succeeds."""
|
||||
try:
|
||||
await self.logger.info(f'[{bridge.user_id}] Bridge login task started, looking for DM room...')
|
||||
dm_room_id = None
|
||||
for room_id, room in self.client.rooms.items():
|
||||
if room.member_count == 2 and bridge.user_id in [m for m in room.users]:
|
||||
dm_room_id = room_id
|
||||
break
|
||||
|
||||
if not dm_room_id:
|
||||
resp = await self.client.room_create(
|
||||
is_direct=True,
|
||||
invite=[bridge.user_id],
|
||||
)
|
||||
if isinstance(resp, nio.RoomCreateResponse):
|
||||
dm_room_id = resp.room_id
|
||||
await self.logger.debug(f'[{bridge.user_id}] Created DM room: {dm_room_id}')
|
||||
else:
|
||||
await self.logger.error(f'[{bridge.user_id}] Failed to create DM room: {resp}')
|
||||
return
|
||||
|
||||
bridge.dm_room_id = dm_room_id
|
||||
|
||||
# Force logout first on every adapter start
|
||||
logout_cmd = bridge.logout_command or bridge.login_command.replace('login', 'logout')
|
||||
await self.logger.info(f'[{bridge.user_id}] 强制登出: "{logout_cmd}"')
|
||||
await self.client.room_send(
|
||||
room_id=dm_room_id,
|
||||
message_type='m.room.message',
|
||||
content={'msgtype': 'm.text', 'body': logout_cmd},
|
||||
)
|
||||
await asyncio.sleep(3)
|
||||
|
||||
while self._running and not bridge.logged_in:
|
||||
await self.logger.debug(f'[{bridge.user_id}] Sending "{bridge.login_command}" in room {dm_room_id}')
|
||||
await self.client.room_send(
|
||||
room_id=dm_room_id,
|
||||
message_type='m.room.message',
|
||||
content={'msgtype': 'm.text', 'body': bridge.login_command},
|
||||
)
|
||||
for _ in range(60):
|
||||
if not self._running or bridge.logged_in:
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if bridge.logged_in:
|
||||
await self.logger.debug(f'[{bridge.user_id}] Bridge login confirmed, periodic login stopped.')
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
await self.logger.error(f'[{bridge.user_id}] Bridge periodic login error: {traceback.format_exc()}')
|
||||
|
||||
def _restart_bridge_login(self, bridge: BridgeState):
|
||||
"""Cancel existing login task and start a new one."""
|
||||
if bridge.login_task and not bridge.login_task.done():
|
||||
bridge.login_task.cancel()
|
||||
bridge.login_task = asyncio.create_task(self._periodic_bridge_login(bridge))
|
||||
|
||||
async def _periodic_bridge_check(self, bridge: BridgeState):
|
||||
"""Periodically check a bridge's login status."""
|
||||
try:
|
||||
while self._running and not bridge.logged_in:
|
||||
await asyncio.sleep(5)
|
||||
|
||||
check_interval = 300 # 5 minutes
|
||||
response_timeout = 30
|
||||
await self.logger.debug(f'[{bridge.user_id}] Bridge status check started (interval: {check_interval}s)')
|
||||
|
||||
while self._running:
|
||||
for _ in range(check_interval):
|
||||
if not self._running:
|
||||
return
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if not bridge.logged_in or not bridge.dm_room_id:
|
||||
continue
|
||||
|
||||
try:
|
||||
bridge.check_responded = False
|
||||
await self.client.room_send(
|
||||
room_id=bridge.dm_room_id,
|
||||
message_type='m.room.message',
|
||||
content={'msgtype': 'm.text', 'body': bridge.check_command},
|
||||
)
|
||||
await self.logger.debug(f'[{bridge.user_id}] Bridge status check: sent "{bridge.check_command}"')
|
||||
|
||||
for _ in range(response_timeout):
|
||||
if bridge.check_responded or not self._running:
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if bridge.check_responded:
|
||||
await self.logger.debug(f'[{bridge.user_id}] Bridge status check: OK')
|
||||
else:
|
||||
await self.logger.info(
|
||||
f'[{bridge.user_id}] Bridge status check: 无响应, 可能已掉线, 尝试重新登录...'
|
||||
)
|
||||
bridge.logged_in = False
|
||||
self._restart_bridge_login(bridge)
|
||||
except Exception:
|
||||
await self.logger.error(f'[{bridge.user_id}] Bridge status check error: {traceback.format_exc()}')
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
await self.logger.error(f'[{bridge.user_id}] Bridge status check fatal error: {traceback.format_exc()}')
|
||||
|
||||
async def _handle_relogin_command(self, room_id: str):
|
||||
"""Handle !relogin command: logout then re-login all bridges."""
|
||||
if not self._bridges:
|
||||
await self.client.room_send(
|
||||
room_id=room_id,
|
||||
message_type='m.room.message',
|
||||
content={'msgtype': 'm.text', 'body': '没有配置任何桥。'},
|
||||
)
|
||||
return
|
||||
|
||||
lines = ['开始重新登录所有桥...']
|
||||
for bridge in self._bridges:
|
||||
if not bridge.login_command or not bridge.dm_room_id:
|
||||
lines.append(f'[{bridge.user_id}] 跳过(未配置登录命令或无DM房间)')
|
||||
continue
|
||||
|
||||
# Use configured logout command, fallback to deriving from login command
|
||||
logout_cmd = bridge.logout_command or bridge.login_command.replace('login', 'logout')
|
||||
lines.append(f'[{bridge.user_id}] 发送 "{logout_cmd}"...')
|
||||
|
||||
# Cancel existing tasks
|
||||
if bridge.login_task and not bridge.login_task.done():
|
||||
bridge.login_task.cancel()
|
||||
if bridge.check_task and not bridge.check_task.done():
|
||||
bridge.check_task.cancel()
|
||||
|
||||
# Send logout
|
||||
try:
|
||||
await self.client.room_send(
|
||||
room_id=bridge.dm_room_id,
|
||||
message_type='m.room.message',
|
||||
content={'msgtype': 'm.text', 'body': logout_cmd},
|
||||
)
|
||||
except Exception as e:
|
||||
lines.append(f'[{bridge.user_id}] logout 发送失败: {e}')
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Reset state and restart login
|
||||
bridge.logged_in = False
|
||||
self._restart_bridge_login(bridge)
|
||||
lines.append(f'[{bridge.user_id}] 已触发重新登录')
|
||||
|
||||
await self.client.room_send(
|
||||
room_id=room_id,
|
||||
message_type='m.room.message',
|
||||
content={'msgtype': 'm.text', 'body': '\n'.join(lines)},
|
||||
)
|
||||
|
||||
async def _handle_status_command(self, room_id: str):
|
||||
"""Handle !status command: show bridge states."""
|
||||
if not self._bridges:
|
||||
await self.client.room_send(
|
||||
room_id=room_id,
|
||||
message_type='m.room.message',
|
||||
content={'msgtype': 'm.text', 'body': '没有配置任何桥。'},
|
||||
)
|
||||
return
|
||||
|
||||
lines = ['桥状态:']
|
||||
for bridge in self._bridges:
|
||||
status = '已登录 ✓' if bridge.logged_in else '未登录 ✗'
|
||||
dm = bridge.dm_room_id or '无'
|
||||
lines.append(f'• {bridge.user_id}: {status} (DM: {dm})')
|
||||
await self.client.room_send(
|
||||
room_id=room_id,
|
||||
message_type='m.room.message',
|
||||
content={'msgtype': 'm.text', 'body': '\n'.join(lines)},
|
||||
)
|
||||
|
||||
async def kill(self) -> bool:
|
||||
self._running = False
|
||||
for bridge in self._bridges:
|
||||
if bridge.login_task and not bridge.login_task.done():
|
||||
bridge.login_task.cancel()
|
||||
if bridge.check_task and not bridge.check_task.done():
|
||||
bridge.check_task.cancel()
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
await self.logger.debug('Matrix adapter stopped')
|
||||
return True
|
||||
|
||||
async def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
if event_type in self.listeners:
|
||||
del self.listeners[event_type]
|
||||
@@ -1,123 +0,0 @@
|
||||
apiVersion: v1
|
||||
kind: MessagePlatformAdapter
|
||||
metadata:
|
||||
name: matrix
|
||||
label:
|
||||
en_US: Matrix
|
||||
zh_Hans: Matrix
|
||||
zh_Hant: Matrix
|
||||
ja_JP: Matrix
|
||||
th_TH: Matrix
|
||||
vi_VN: Matrix
|
||||
es_ES: Matrix
|
||||
description:
|
||||
en_US: Matrix protocol adapter, supports self-hosted Synapse servers and any Matrix-compatible homeserver
|
||||
zh_Hans: Matrix 协议适配器,支持自建 Synapse 服务器及任何 Matrix 兼容的 Homeserver
|
||||
zh_Hant: Matrix 協議適配器,支持自建 Synapse 伺服器及任何 Matrix 相容的 Homeserver
|
||||
ja_JP: Matrix プロトコルアダプター、セルフホストの Synapse サーバーおよび Matrix 互換のホームサーバーをサポート
|
||||
th_TH: อะแดปเตอร์โปรโตคอล Matrix รองรับเซิร์ฟเวอร์ Synapse ที่โฮสต์เองและ Homeserver ที่เข้ากันได้กับ Matrix
|
||||
vi_VN: Bộ điều hợp giao thức Matrix, hỗ trợ máy chủ Synapse tự lưu trữ và bất kỳ Homeserver tương thích Matrix nào
|
||||
es_ES: Adaptador del protocolo Matrix, compatible con servidores Synapse autoalojados y cualquier Homeserver compatible con Matrix
|
||||
icon: matrix.png
|
||||
spec:
|
||||
categories:
|
||||
- global
|
||||
- protocol
|
||||
config:
|
||||
- name: homeserver_url
|
||||
label:
|
||||
en_US: Homeserver URL
|
||||
zh_Hans: Homeserver 地址
|
||||
zh_Hant: Homeserver 地址
|
||||
ja_JP: Homeserver URL
|
||||
th_TH: URL ของ Homeserver
|
||||
vi_VN: URL Homeserver
|
||||
es_ES: URL del Homeserver
|
||||
description:
|
||||
en_US: "The URL of the Matrix homeserver, e.g. http://localhost:8008"
|
||||
zh_Hans: "Matrix Homeserver 的地址,例如 http://localhost:8008"
|
||||
type: string
|
||||
required: true
|
||||
default: "http://localhost:8008"
|
||||
- name: user_id
|
||||
label:
|
||||
en_US: Bot User ID
|
||||
zh_Hans: 机器人用户 ID
|
||||
zh_Hant: 機器人用戶 ID
|
||||
ja_JP: ボットユーザー ID
|
||||
th_TH: ID ผู้ใช้บอท
|
||||
vi_VN: ID người dùng bot
|
||||
es_ES: ID de usuario del bot
|
||||
description:
|
||||
en_US: "The full Matrix user ID, e.g. @bot:localhost"
|
||||
zh_Hans: "完整的 Matrix 用户 ID,例如 @bot:localhost"
|
||||
type: string
|
||||
required: true
|
||||
default: "@langbot:localhost"
|
||||
- name: access_token
|
||||
label:
|
||||
en_US: Access Token
|
||||
zh_Hans: 访问令牌
|
||||
zh_Hant: 訪問令牌
|
||||
ja_JP: アクセストークン
|
||||
th_TH: โทเค็นการเข้าถึง
|
||||
vi_VN: Mã truy cập
|
||||
es_ES: Token de acceso
|
||||
description:
|
||||
en_US: "Access token obtained by logging in via the Matrix client API"
|
||||
zh_Hans: "通过 Matrix Client API 登录获取的访问令牌"
|
||||
type: string
|
||||
required: true
|
||||
default: ""
|
||||
- name: bridge_user_id
|
||||
label:
|
||||
en_US: Bridge Bot User ID (single bridge, legacy)
|
||||
zh_Hans: 桥机器人用户 ID(单桥兼容)
|
||||
description:
|
||||
en_US: "Single bridge bot user ID (legacy). Prefer 'bridges' for multi-bridge. e.g. @discordbot:localhost"
|
||||
zh_Hans: "单桥机器人用户 ID(旧格式兼容)。推荐使用 bridges 配置多桥。例如 @discordbot:localhost"
|
||||
type: string
|
||||
required: false
|
||||
default: ""
|
||||
- name: bridge_login_command
|
||||
label:
|
||||
en_US: Bridge Login Command (single bridge, legacy)
|
||||
zh_Hans: 桥登录命令(单桥兼容)
|
||||
description:
|
||||
en_US: "Login command for single bridge (legacy). e.g. !discord login"
|
||||
zh_Hans: "单桥登录命令(旧格式兼容)。例如 !discord login"
|
||||
type: string
|
||||
required: false
|
||||
default: ""
|
||||
- name: bridge_login_success_keyword
|
||||
label:
|
||||
en_US: Bridge Login Success Keyword (single bridge, legacy)
|
||||
zh_Hans: 桥登录成功关键词(单桥兼容)
|
||||
description:
|
||||
en_US: "Success keyword for single bridge (legacy). e.g. Successfully logged in"
|
||||
zh_Hans: "单桥登录成功关键词(旧格式兼容)。例如 Successfully logged in"
|
||||
type: string
|
||||
required: false
|
||||
default: "Successfully logged in"
|
||||
- name: bridges
|
||||
label:
|
||||
en_US: Bridges Config (Multi-bridge)
|
||||
zh_Hans: 桥配置(多桥)
|
||||
description:
|
||||
en_US: >
|
||||
JSON array of bridge configs. Each bridge: {"user_id": "@bot:host", "login_command": "!xx login",
|
||||
"success_keyword": "logged in", "check_command": "!xx ping"}.
|
||||
Example: [{"user_id":"@discordbot:localhost","login_command":"!discord login","success_keyword":"logged in"},
|
||||
{"user_id":"@telegrambot:localhost","login_command":"!tg login","success_keyword":"logged in"}]
|
||||
zh_Hans: >
|
||||
JSON 数组格式的多桥配置。每个桥: {"user_id": "@bot:host", "login_command": "!xx login",
|
||||
"success_keyword": "logged in", "check_command": "!xx ping"}。
|
||||
示例: [{"user_id":"@discordbot:localhost","login_command":"!discord login","success_keyword":"logged in"},
|
||||
{"user_id":"@telegrambot:localhost","login_command":"!tg login","success_keyword":"logged in"}]
|
||||
type: string
|
||||
required: false
|
||||
default: ""
|
||||
execution:
|
||||
python:
|
||||
path: ./matrix.py
|
||||
attr: MatrixAdapter
|
||||
@@ -32,20 +32,6 @@ spec:
|
||||
type: string
|
||||
required: true
|
||||
default: "https://ilinkai.weixin.qq.com"
|
||||
- name: qr-login
|
||||
label:
|
||||
en_US: Scan QR Login
|
||||
zh_Hans: 扫码登录
|
||||
zh_Hant: 掃碼登入
|
||||
ja_JP: QRコードでログイン
|
||||
description:
|
||||
en_US: Scan QR code with WeChat to authorize and automatically fill in the token
|
||||
zh_Hans: 使用微信扫码授权,自动填写令牌
|
||||
zh_Hant: 使用微信掃碼授權,自動填寫令牌
|
||||
ja_JP: WeChatでQRコードをスキャンし、トークンを自動入力
|
||||
type: qr-code-login
|
||||
login_platform: weixin
|
||||
required: false
|
||||
- name: token
|
||||
label:
|
||||
en_US: Token
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from __future__ import annotations
|
||||
import typing
|
||||
import re
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
import datetime
|
||||
import time
|
||||
|
||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
@@ -17,25 +15,11 @@ from ...utils import image
|
||||
from ..logger import EventLogger
|
||||
|
||||
|
||||
def _is_base64_data(value: str) -> bool:
|
||||
"""Check if a string contains base64-encoded data rather than a URL."""
|
||||
if not value:
|
||||
return False
|
||||
# data: URI scheme (e.g. data:image/png;base64,xxx)
|
||||
if value.startswith('data:'):
|
||||
return True
|
||||
# Only treat as base64 if it doesn't look like a URL/path and has valid base64 chars
|
||||
if value.startswith(('http://', 'https://', '/', './', '../')):
|
||||
return False
|
||||
# Check if it looks like base64 (only valid chars, reasonable length)
|
||||
return bool(re.fullmatch(r'[A-Za-z0-9+/=\s]{20,}', value))
|
||||
|
||||
|
||||
class QQOfficialMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
@staticmethod
|
||||
async def yiri2target(message_chain: platform_message.MessageChain):
|
||||
"""将 LangBot 消息链转换为 QQ Official 消息格式列表。"""
|
||||
content_list = []
|
||||
# 只实现了发文字
|
||||
for msg in message_chain:
|
||||
if type(msg) is platform_message.Plain:
|
||||
content_list.append(
|
||||
@@ -44,49 +28,6 @@ class QQOfficialMessageConverter(abstract_platform_adapter.AbstractMessageConver
|
||||
'content': msg.text,
|
||||
}
|
||||
)
|
||||
elif type(msg) is platform_message.Image:
|
||||
url = msg.url if hasattr(msg, 'url') and msg.url else None
|
||||
b64 = msg.base64 if hasattr(msg, 'base64') and msg.base64 else None
|
||||
# Some plugins (e.g. MimoTTS) store base64 data in the url field
|
||||
if url and not b64 and _is_base64_data(url):
|
||||
b64 = url
|
||||
url = None
|
||||
content_list.append(
|
||||
{
|
||||
'type': 'image',
|
||||
'url': url,
|
||||
'base64': b64,
|
||||
}
|
||||
)
|
||||
elif type(msg) is platform_message.Voice:
|
||||
url = msg.url if hasattr(msg, 'url') and msg.url else None
|
||||
b64 = msg.base64 if hasattr(msg, 'base64') and msg.base64 else None
|
||||
# Some plugins (e.g. MimoTTS) store base64 data in the url field
|
||||
if url and not b64 and _is_base64_data(url):
|
||||
b64 = url
|
||||
url = None
|
||||
content_list.append(
|
||||
{
|
||||
'type': 'voice',
|
||||
'url': url,
|
||||
'base64': b64,
|
||||
}
|
||||
)
|
||||
elif type(msg) is platform_message.File:
|
||||
url = msg.url if hasattr(msg, 'url') and msg.url else None
|
||||
b64 = msg.base64 if hasattr(msg, 'base64') and msg.base64 else None
|
||||
# Some plugins store base64 data in the url field
|
||||
if url and not b64 and _is_base64_data(url):
|
||||
b64 = url
|
||||
url = None
|
||||
content_list.append(
|
||||
{
|
||||
'type': 'file',
|
||||
'url': url,
|
||||
'base64': b64,
|
||||
'name': msg.name if hasattr(msg, 'name') else 'file',
|
||||
}
|
||||
)
|
||||
|
||||
return content_list
|
||||
|
||||
@@ -188,19 +129,12 @@ class QQOfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter
|
||||
config: dict
|
||||
bot_account_id: str
|
||||
bot_uuid: str = None
|
||||
enable_webhook: bool = False
|
||||
message_converter: QQOfficialMessageConverter = QQOfficialMessageConverter()
|
||||
event_converter: QQOfficialEventConverter = QQOfficialEventConverter()
|
||||
|
||||
def __init__(self, config: dict, logger: EventLogger):
|
||||
enable_webhook = config.get('enable-webhook', False)
|
||||
|
||||
bot = QQOfficialClient(
|
||||
app_id=config['appid'],
|
||||
secret=config['secret'],
|
||||
token=config['token'],
|
||||
logger=logger,
|
||||
unified_mode=enable_webhook,
|
||||
app_id=config['appid'], secret=config['secret'], token=config['token'], logger=logger, unified_mode=True
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
@@ -210,13 +144,6 @@ class QQOfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter
|
||||
bot_account_id=config['appid'],
|
||||
)
|
||||
|
||||
self.enable_webhook = enable_webhook
|
||||
self._ws_task: asyncio.Task = None
|
||||
self._stream_ctx: dict = {}
|
||||
self._stream_ctx_ts: dict[str, float] = {}
|
||||
self._fallback_text: dict[str, str] = {}
|
||||
self._fallback_text_ts: dict[str, float] = {}
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
@@ -229,18 +156,28 @@ class QQOfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter
|
||||
|
||||
content_list = await QQOfficialMessageConverter.yiri2target(message)
|
||||
|
||||
# 确定 target_type 和 target_id
|
||||
target_type = None
|
||||
target_id = None
|
||||
|
||||
# 私聊消息
|
||||
if qq_official_event.t == 'C2C_MESSAGE_CREATE':
|
||||
target_type = 'c2c'
|
||||
target_id = qq_official_event.user_openid
|
||||
elif qq_official_event.t == 'GROUP_AT_MESSAGE_CREATE':
|
||||
target_type = 'group'
|
||||
target_id = qq_official_event.group_openid
|
||||
elif qq_official_event.t == 'AT_MESSAGE_CREATE':
|
||||
# 频道群聊使用频道 API,暂不支持富媒体
|
||||
for content in content_list:
|
||||
if content['type'] == 'text':
|
||||
await self.bot.send_private_text_msg(
|
||||
qq_official_event.user_openid,
|
||||
content['content'],
|
||||
qq_official_event.d_id,
|
||||
)
|
||||
|
||||
# 群聊消息
|
||||
if qq_official_event.t == 'GROUP_AT_MESSAGE_CREATE':
|
||||
for content in content_list:
|
||||
if content['type'] == 'text':
|
||||
await self.bot.send_group_text_msg(
|
||||
qq_official_event.group_openid,
|
||||
content['content'],
|
||||
qq_official_event.d_id,
|
||||
)
|
||||
|
||||
# 频道群聊
|
||||
if qq_official_event.t == 'AT_MESSAGE_CREATE':
|
||||
for content in content_list:
|
||||
if content['type'] == 'text':
|
||||
await self.bot.send_channle_group_text_msg(
|
||||
@@ -248,9 +185,9 @@ class QQOfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter
|
||||
content['content'],
|
||||
qq_official_event.d_id,
|
||||
)
|
||||
return
|
||||
elif qq_official_event.t == 'DIRECT_MESSAGE_CREATE':
|
||||
# 频道私聊使用频道 API,暂不支持富媒体
|
||||
|
||||
# 频道私聊
|
||||
if qq_official_event.t == 'DIRECT_MESSAGE_CREATE':
|
||||
for content in content_list:
|
||||
if content['type'] == 'text':
|
||||
await self.bot.send_channle_private_text_msg(
|
||||
@@ -258,63 +195,6 @@ class QQOfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter
|
||||
content['content'],
|
||||
qq_official_event.d_id,
|
||||
)
|
||||
return
|
||||
|
||||
# C2C 和群聊:支持文字 + 富媒体
|
||||
for content in content_list:
|
||||
content_type = content.get('type', 'text')
|
||||
|
||||
if content_type == 'text':
|
||||
if target_type == 'c2c':
|
||||
await self.bot.send_private_text_msg(
|
||||
target_id,
|
||||
content['content'],
|
||||
qq_official_event.d_id,
|
||||
)
|
||||
elif target_type == 'group':
|
||||
await self.bot.send_group_text_msg(
|
||||
target_id,
|
||||
content['content'],
|
||||
qq_official_event.d_id,
|
||||
)
|
||||
|
||||
elif content_type == 'image':
|
||||
file_url = content.get('url')
|
||||
file_data = content.get('base64')
|
||||
if file_url or file_data:
|
||||
await self.bot.send_image_msg(
|
||||
target_type,
|
||||
target_id,
|
||||
file_url=file_url,
|
||||
file_data=file_data,
|
||||
msg_id=qq_official_event.d_id,
|
||||
)
|
||||
|
||||
elif content_type == 'voice':
|
||||
file_url = content.get('url')
|
||||
file_data = content.get('base64')
|
||||
if file_url or file_data:
|
||||
await self.bot.send_voice_msg(
|
||||
target_type,
|
||||
target_id,
|
||||
file_url=file_url,
|
||||
file_data=file_data,
|
||||
msg_id=qq_official_event.d_id,
|
||||
)
|
||||
|
||||
elif content_type == 'file':
|
||||
file_url = content.get('url')
|
||||
file_data = content.get('base64')
|
||||
file_name = content.get('name', 'file')
|
||||
if file_url or file_data:
|
||||
await self.bot.send_file_msg(
|
||||
target_type,
|
||||
target_id,
|
||||
file_url=file_url,
|
||||
file_data=file_data,
|
||||
file_name=file_name,
|
||||
msg_id=qq_official_event.d_id,
|
||||
)
|
||||
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
pass
|
||||
@@ -358,196 +238,17 @@ class QQOfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter
|
||||
return await self.bot.handle_unified_webhook(request)
|
||||
|
||||
async def run_async(self):
|
||||
if not self.enable_webhook:
|
||||
await self._run_websocket()
|
||||
else:
|
||||
# 统一 webhook 模式下,不启动独立的 Quart 应用
|
||||
async def keep_alive():
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
# 统一 webhook 模式下,不启动独立的 Quart 应用
|
||||
# 保持运行但不启动独立端口
|
||||
|
||||
await keep_alive()
|
||||
async def keep_alive():
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def _run_websocket(self):
|
||||
"""以 WebSocket 模式运行网关连接"""
|
||||
await self.logger.info('QQ Official adapter starting in WebSocket mode')
|
||||
|
||||
async def on_ready():
|
||||
await self.logger.info('QQ Official WebSocket connected and ready')
|
||||
|
||||
async def on_event(event_type: str, event_data: dict):
|
||||
# 只处理消息事件,忽略 READY/RESUMED 等系统事件
|
||||
message_event_types = {
|
||||
'C2C_MESSAGE_CREATE',
|
||||
'DIRECT_MESSAGE_CREATE',
|
||||
'GROUP_AT_MESSAGE_CREATE',
|
||||
'AT_MESSAGE_CREATE',
|
||||
}
|
||||
if event_type not in message_event_types:
|
||||
return
|
||||
if not isinstance(event_data, dict):
|
||||
await self.logger.warning(f'Event data is not dict, skipping: {event_type} -> {type(event_data)}')
|
||||
return
|
||||
await self.logger.info(f'Processing message event: {event_type}')
|
||||
# 构造与 webhook 模式相同的 payload 结构
|
||||
payload = {'t': event_type, 'd': event_data}
|
||||
message_data = await self.bot.get_message(payload)
|
||||
if message_data:
|
||||
event = QQOfficialEvent.from_payload(message_data)
|
||||
await self.bot._handle_message(event)
|
||||
|
||||
async def on_error(error: Exception):
|
||||
await self.logger.error(f'WebSocket error: {error}')
|
||||
await self.logger.error(f'QQ Official WebSocket error: {error}')
|
||||
|
||||
self._ws_task = asyncio.create_task(self.bot.connect_gateway_loop(on_event, on_ready, on_error))
|
||||
try:
|
||||
await self._ws_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
await keep_alive()
|
||||
|
||||
async def kill(self) -> bool:
|
||||
if self._ws_task:
|
||||
self._ws_task.cancel()
|
||||
try:
|
||||
await self._ws_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._ws_task = None
|
||||
return True
|
||||
|
||||
# --------------- 流式输出 ---------------
|
||||
|
||||
_STREAM_CTX_TTL = 300 # seconds
|
||||
|
||||
async def _cleanup_stale_streams(self):
|
||||
"""Remove stream contexts that have not been updated for more than _STREAM_CTX_TTL seconds."""
|
||||
now = time.time()
|
||||
stale_ids = [mid for mid, ts in self._stream_ctx_ts.items() if now - ts > self._STREAM_CTX_TTL]
|
||||
for mid in stale_ids:
|
||||
self._stream_ctx.pop(mid, None)
|
||||
self._stream_ctx_ts.pop(mid, None)
|
||||
stale_fb = [mid for mid, ts in self._fallback_text_ts.items() if now - ts > self._STREAM_CTX_TTL]
|
||||
for mid in stale_fb:
|
||||
self._fallback_text.pop(mid, None)
|
||||
self._fallback_text_ts.pop(mid, None)
|
||||
if stale_ids or stale_fb:
|
||||
await self.logger.debug(f'Cleaned up {len(stale_ids)} stream contexts, {len(stale_fb)} fallback texts')
|
||||
|
||||
async def is_stream_output_supported(self) -> bool:
|
||||
return self.config.get('enable-stream-reply', False)
|
||||
|
||||
async def create_message_card(self, message_id: str, event: platform_events.MessageEvent) -> bool:
|
||||
source = event.source_platform_object
|
||||
# Streaming API only supports C2C private chat
|
||||
if source.t != 'C2C_MESSAGE_CREATE':
|
||||
return False
|
||||
|
||||
ctx = {
|
||||
'user_openid': source.user_openid,
|
||||
'msg_id': source.d_id,
|
||||
'stream_msg_id': None,
|
||||
'msg_seq': 1,
|
||||
'index': 0,
|
||||
'last_update_ts': 0,
|
||||
'accumulated_text': '',
|
||||
'sent_length': 0,
|
||||
'session_started': False,
|
||||
}
|
||||
|
||||
self._stream_ctx[message_id] = ctx
|
||||
self._stream_ctx_ts[message_id] = time.time()
|
||||
return True
|
||||
|
||||
async def reply_message_chunk(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
bot_message: dict,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
is_final: bool = False,
|
||||
):
|
||||
# Periodically clean up stale stream contexts
|
||||
await self._cleanup_stale_streams()
|
||||
# 提取纯文本内容(当前 chunk 的文本)
|
||||
text_parts = []
|
||||
for msg in message:
|
||||
if type(msg) is platform_message.Plain:
|
||||
text_parts.append(msg.text)
|
||||
chunk_text = '\n\n'.join(text_parts)
|
||||
|
||||
message_id = (
|
||||
bot_message.get('resp_message_id')
|
||||
if isinstance(bot_message, dict)
|
||||
else getattr(bot_message, 'resp_message_id', None)
|
||||
)
|
||||
if not message_id or message_id not in self._stream_ctx:
|
||||
# 非流式场景(如群聊不支持流式),累积文本后一次性回复
|
||||
if chunk_text:
|
||||
self._fallback_text[message_id] = self._fallback_text.get(message_id, '') + chunk_text
|
||||
self._fallback_text_ts[message_id] = time.time()
|
||||
if is_final:
|
||||
full_text = self._fallback_text.pop(message_id, '')
|
||||
if full_text:
|
||||
fallback_msg = platform_message.MessageChain([platform_message.Plain(text=full_text)])
|
||||
await self.reply_message(message_source, fallback_msg, quote_origin)
|
||||
return
|
||||
|
||||
ctx = self._stream_ctx[message_id]
|
||||
|
||||
# 累积文本
|
||||
if chunk_text:
|
||||
ctx['accumulated_text'] += chunk_text
|
||||
|
||||
# 未启动会话时,等第一个有内容的 chunk 来建立会话
|
||||
if not ctx['session_started']:
|
||||
if not ctx['accumulated_text']:
|
||||
return
|
||||
# 用第一个 chunk 的文本建立会话(不发 "..." 避免污染前缀)
|
||||
ctx['session_started'] = True
|
||||
|
||||
# 发送内容 = 全量累积文本
|
||||
# QQ API 的 replace 模式不允许修改已下发前缀,所以:
|
||||
# - 首次:发送全部文本,建立会话
|
||||
# - 后续:只能发送新增部分(append 行为)
|
||||
content_to_send = ctx['accumulated_text'][ctx['sent_length'] :]
|
||||
if not content_to_send and not is_final:
|
||||
return
|
||||
|
||||
input_state = 10 if is_final else 1
|
||||
|
||||
# Rate limiting: skip non-final updates if last update was <0.5s ago
|
||||
now = time.time()
|
||||
if not is_final and (now - ctx['last_update_ts']) < 0.5:
|
||||
return
|
||||
ctx['last_update_ts'] = now
|
||||
|
||||
try:
|
||||
resp = await self.bot.send_stream_msg(
|
||||
user_openid=ctx['user_openid'],
|
||||
content=content_to_send,
|
||||
event_id=ctx['msg_id'],
|
||||
msg_id=ctx['msg_id'],
|
||||
msg_seq=ctx['msg_seq'],
|
||||
index=ctx['index'],
|
||||
stream_msg_id=ctx['stream_msg_id'],
|
||||
input_state=input_state,
|
||||
)
|
||||
if resp and isinstance(resp, dict):
|
||||
new_stream_id = resp.get('id')
|
||||
if new_stream_id:
|
||||
ctx['stream_msg_id'] = new_stream_id
|
||||
ctx['sent_length'] = len(ctx['accumulated_text'])
|
||||
ctx['index'] += 1
|
||||
await self.logger.debug(
|
||||
f'[QQ Official] 流式 chunk 已发送, index={ctx["index"]}, '
|
||||
f'sent_len={ctx["sent_length"]}, is_final={is_final}'
|
||||
)
|
||||
except Exception as e:
|
||||
await self.logger.error(f'Failed to send stream message: {e}')
|
||||
|
||||
if is_final:
|
||||
self._stream_ctx.pop(message_id, None)
|
||||
return False
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
|
||||
@@ -7,9 +7,9 @@ metadata:
|
||||
zh_Hans: QQ 官方 API
|
||||
zh_Hant: QQ 官方 API
|
||||
description:
|
||||
en_US: QQ Official API (Webhook / WebSocket)
|
||||
zh_Hans: QQ 官方 API,支持 Webhook 和 WebSocket 两种连接模式
|
||||
zh_Hant: QQ 官方 API,支援 Webhook 和 WebSocket 兩種連線模式
|
||||
en_US: QQ Official API (Webhook)
|
||||
zh_Hans: QQ 官方 API (Webhook),需要公网地址以接收消息推送,请查看文档了解使用方式
|
||||
zh_Hant: QQ 官方 API (Webhook),需要公網地址以接收訊息推送,請查看文件了解使用方式
|
||||
icon: qqofficial.svg
|
||||
spec:
|
||||
categories:
|
||||
@@ -19,6 +19,18 @@ spec:
|
||||
en: https://link.langbot.app/en/platforms/qqofficial
|
||||
ja: https://link.langbot.app/ja/platforms/qqofficial
|
||||
config:
|
||||
- name: webhook_url
|
||||
label:
|
||||
en_US: Webhook Callback URL
|
||||
zh_Hans: Webhook 回调地址
|
||||
zh_Hant: Webhook 回調地址
|
||||
description:
|
||||
en_US: Copy this URL and paste it into your QQ Official API webhook configuration
|
||||
zh_Hans: 复制此地址并粘贴到 QQ 官方 API 的 Webhook 配置中
|
||||
zh_Hant: 複製此地址並貼到 QQ 官方 API 的 Webhook 設定中
|
||||
type: webhook-url
|
||||
required: false
|
||||
default: ""
|
||||
- name: appid
|
||||
label:
|
||||
en_US: App ID
|
||||
@@ -43,46 +55,6 @@ spec:
|
||||
type: string
|
||||
required: true
|
||||
default: ""
|
||||
- name: enable-webhook
|
||||
label:
|
||||
en_US: Enable Webhook Mode
|
||||
zh_Hans: 启用Webhook模式
|
||||
zh_Hant: 啟用 Webhook 模式
|
||||
description:
|
||||
en_US: If enabled, the bot will use webhook mode to receive messages. Otherwise, it will use WebSocket mode
|
||||
zh_Hans: 如果启用,机器人将使用 Webhook 模式接收消息。否则,将使用 WebSocket 模式
|
||||
zh_Hant: 如果啟用,機器人將使用 Webhook 模式接收訊息。否則,將使用 WebSocket 模式
|
||||
type: boolean
|
||||
required: true
|
||||
default: false
|
||||
- name: enable-stream-reply
|
||||
label:
|
||||
en_US: Enable Stream Reply Mode
|
||||
zh_Hans: 启用流式回复模式
|
||||
zh_Hant: 啟用串流回覆模式
|
||||
description:
|
||||
en_US: If enabled, the bot will use streaming mode to reply messages (C2C only)
|
||||
zh_Hans: 如果启用,机器人将使用流式方式回复消息(仅私聊)
|
||||
zh_Hant: 如果啟用,機器人將使用串流方式回覆訊息(僅私聊)
|
||||
type: boolean
|
||||
required: true
|
||||
default: false
|
||||
- name: webhook_url
|
||||
label:
|
||||
en_US: Webhook Callback URL
|
||||
zh_Hans: Webhook 回调地址
|
||||
zh_Hant: Webhook 回調地址
|
||||
description:
|
||||
en_US: Copy this URL and paste it into your QQ Official API webhook configuration
|
||||
zh_Hans: 复制此地址并粘贴到 QQ 官方 API 的 Webhook 配置中
|
||||
zh_Hant: 複製此地址並貼到 QQ 官方 API 的 Webhook 設定中
|
||||
type: webhook-url
|
||||
required: false
|
||||
default: ""
|
||||
show_if:
|
||||
field: enable-webhook
|
||||
operator: eq
|
||||
value: true
|
||||
execution:
|
||||
python:
|
||||
path: ./qqofficial.py
|
||||
|
||||
@@ -1,177 +0,0 @@
|
||||
apiVersion: v1
|
||||
kind: MessagePlatformAdapter
|
||||
metadata:
|
||||
name: web_page_bot
|
||||
label:
|
||||
en_US: "Page Bot"
|
||||
zh_Hans: "页面机器人"
|
||||
zh_Hant: "頁面機器人"
|
||||
ja_JP: "ページボット"
|
||||
th_TH: "บอทหน้าเว็บ"
|
||||
vi_VN: "Bot trang web"
|
||||
es_ES: "Bot de página"
|
||||
description:
|
||||
en_US: "Embed a chat widget on any website with a simple script tag"
|
||||
zh_Hans: "通过一行脚本标签将聊天组件嵌入到任何网站"
|
||||
zh_Hant: "透過一行腳本標籤將聊天元件嵌入到任何網站"
|
||||
ja_JP: "シンプルなスクリプトタグで任意のウェブサイトにチャットウィジェットを埋め込みます"
|
||||
th_TH: "ฝังวิดเจ็ตแชทในเว็บไซต์ใดก็ได้ด้วยแท็กสคริปต์"
|
||||
vi_VN: "Nhúng widget trò chuyện vào bất kỳ trang web nào bằng thẻ script"
|
||||
es_ES: "Incrusta un widget de chat en cualquier sitio web con una etiqueta de script"
|
||||
icon: "webpage.webp"
|
||||
spec:
|
||||
categories:
|
||||
- popular
|
||||
config:
|
||||
- name: title
|
||||
label:
|
||||
en_US: Widget Title
|
||||
zh_Hans: 组件标题
|
||||
zh_Hant: 元件標題
|
||||
ja_JP: ウィジェットタイトル
|
||||
th_TH: ชื่อวิดเจ็ต
|
||||
vi_VN: Tiêu đề widget
|
||||
es_ES: Título del widget
|
||||
description:
|
||||
en_US: The title displayed in the chat widget header
|
||||
zh_Hans: 显示在聊天组件顶部的标题
|
||||
zh_Hant: 顯示在聊天元件頂部的標題
|
||||
ja_JP: チャットウィジェットのヘッダーに表示されるタイトル
|
||||
th_TH: ชื่อที่แสดงในส่วนหัวของวิดเจ็ตแชท
|
||||
vi_VN: Tiêu đề hiển thị trong đầu widget trò chuyện
|
||||
es_ES: El título que se muestra en el encabezado del widget de chat
|
||||
type: string
|
||||
required: false
|
||||
default: "LangBot"
|
||||
- name: bubble_icon
|
||||
label:
|
||||
en_US: Bubble Icon
|
||||
zh_Hans: 气泡图标
|
||||
zh_Hant: 氣泡圖示
|
||||
ja_JP: バブルアイコン
|
||||
th_TH: ไอคอนบับเบิล
|
||||
vi_VN: Biểu tượng bong bóng
|
||||
es_ES: Icono de burbuja
|
||||
ru_RU: Иконка пузырька
|
||||
description:
|
||||
en_US: "Icon displayed on the floating chat bubble"
|
||||
zh_Hans: "浮动聊天气泡上显示的图标"
|
||||
type: select
|
||||
required: false
|
||||
default: "logo"
|
||||
options:
|
||||
- name: "logo"
|
||||
label:
|
||||
en_US: "LangBot Logo"
|
||||
zh_Hans: "LangBot 图标"
|
||||
- name: "chat"
|
||||
label:
|
||||
en_US: "Chat Bubble"
|
||||
zh_Hans: "聊天气泡"
|
||||
- name: "robot"
|
||||
label:
|
||||
en_US: "Robot"
|
||||
zh_Hans: "机器人"
|
||||
- name: "headset"
|
||||
label:
|
||||
en_US: "Headset"
|
||||
zh_Hans: "客服耳机"
|
||||
- name: "sparkle"
|
||||
label:
|
||||
en_US: "Sparkle"
|
||||
zh_Hans: "星光"
|
||||
- name: "message"
|
||||
label:
|
||||
en_US: "Message"
|
||||
zh_Hans: "消息"
|
||||
- name: language
|
||||
label:
|
||||
en_US: Widget Language
|
||||
zh_Hans: 组件语言
|
||||
zh_Hant: 元件語言
|
||||
ja_JP: ウィジェット言語
|
||||
th_TH: ภาษาวิดเจ็ต
|
||||
vi_VN: Ngôn ngữ widget
|
||||
es_ES: Idioma del widget
|
||||
ru_RU: Язык виджета
|
||||
description:
|
||||
en_US: "Display language of the chat widget"
|
||||
zh_Hans: "聊天组件的显示语言"
|
||||
zh_Hant: "聊天元件的顯示語言"
|
||||
ja_JP: "チャットウィジェットの表示言語"
|
||||
th_TH: "ภาษาแสดงผลของวิดเจ็ตแชท"
|
||||
vi_VN: "Ngôn ngữ hiển thị của widget trò chuyện"
|
||||
es_ES: "Idioma de visualización del widget de chat"
|
||||
ru_RU: "Язык отображения виджета чата"
|
||||
type: select
|
||||
required: false
|
||||
default: "en_US"
|
||||
options:
|
||||
- name: "en_US"
|
||||
label:
|
||||
en_US: "English"
|
||||
- name: "zh_Hans"
|
||||
label:
|
||||
en_US: "简体中文"
|
||||
- name: "zh_Hant"
|
||||
label:
|
||||
en_US: "繁體中文"
|
||||
- name: "ja_JP"
|
||||
label:
|
||||
en_US: "日本語"
|
||||
- name: "es_ES"
|
||||
label:
|
||||
en_US: "Español"
|
||||
- name: "ru_RU"
|
||||
label:
|
||||
en_US: "Русский"
|
||||
- name: "th_TH"
|
||||
label:
|
||||
en_US: "ไทย"
|
||||
- name: "vi_VN"
|
||||
label:
|
||||
en_US: "Tiếng Việt"
|
||||
- name: embed_code
|
||||
label:
|
||||
en_US: Embed Code
|
||||
zh_Hans: 嵌入代码
|
||||
zh_Hant: 嵌入代碼
|
||||
ja_JP: 埋め込みコード
|
||||
th_TH: โค้ดฝังตัว
|
||||
vi_VN: Mã nhúng
|
||||
es_ES: Código de incrustación
|
||||
description:
|
||||
en_US: "Copy this code and paste it into your website HTML. The code will be generated after saving."
|
||||
zh_Hans: "复制此代码并粘贴到你的网站 HTML 中。保存后将自动生成。"
|
||||
zh_Hant: "複製此代碼並貼到你的網站 HTML 中。儲存後將自動生成。"
|
||||
ja_JP: "このコードをコピーしてウェブサイトのHTMLに貼り付けてください。保存後に自動生成されます。"
|
||||
th_TH: "คัดลอกโค้ดนี้และวางในHTML ของเว็บไซต์ของคุณ จะสร้างอัตโนมัติหลังจากบันทึก"
|
||||
vi_VN: "Sao chép mã này và dán vào HTML trang web của bạn. Mã sẽ được tạo tự động sau khi lưu."
|
||||
es_ES: "Copia este código y pégalo en el HTML de tu sitio web. El código se generará después de guardar."
|
||||
type: embed-code
|
||||
required: false
|
||||
default: ""
|
||||
- name: turnstile_site_key
|
||||
label:
|
||||
en_US: Turnstile Site Key
|
||||
zh_Hans: Turnstile 站点密钥
|
||||
description:
|
||||
en_US: "Cloudflare Turnstile site key for bot protection. Get it from the Cloudflare dashboard (Turnstile > Add Site). Leave empty to disable."
|
||||
zh_Hans: "Cloudflare Turnstile 站点密钥,用于防止机器人滥用。在 Cloudflare 控制台(Turnstile > 添加站点)中获取。留空则不启用。"
|
||||
type: string
|
||||
required: false
|
||||
default: ""
|
||||
- name: turnstile_secret_key
|
||||
label:
|
||||
en_US: Turnstile Secret Key
|
||||
zh_Hans: Turnstile 服务端密钥
|
||||
description:
|
||||
en_US: "Cloudflare Turnstile secret key for server-side token verification. Found alongside the site key in the Cloudflare dashboard. Required if site key is set."
|
||||
zh_Hans: "Cloudflare Turnstile 服务端密钥,用于服务端验证令牌。与站点密钥一起在 Cloudflare 控制台中获取。设置了站点密钥时必填。"
|
||||
type: string
|
||||
required: false
|
||||
default: ""
|
||||
execution:
|
||||
python:
|
||||
path: "web_page_bot_adapter.py"
|
||||
attr: "WebPageBotAdapter"
|
||||
@@ -1,94 +0,0 @@
|
||||
"""Web Page Bot adapter - lightweight adapter for embeddable chat widget"""
|
||||
|
||||
import typing
|
||||
|
||||
import pydantic
|
||||
|
||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
|
||||
|
||||
|
||||
class WebPageBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
"""Lightweight adapter for the embeddable page bot.
|
||||
|
||||
This adapter does not handle messages itself. The actual WebSocket
|
||||
communication is handled by the singleton websocket_proxy_bot.
|
||||
This adapter stores event listeners so that RuntimeBot can register
|
||||
its handlers, which are then called by the websocket adapter when
|
||||
a message arrives for this bot's pipeline.
|
||||
|
||||
Message sending/replying is delegated to the websocket_proxy_bot's
|
||||
adapter so that replies are actually delivered over the WebSocket
|
||||
connection while the dashboard correctly shows this adapter's name.
|
||||
"""
|
||||
|
||||
listeners: dict = pydantic.Field(default_factory=dict, exclude=True)
|
||||
_ws_adapter: typing.Any = None
|
||||
|
||||
model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
|
||||
super().__init__(config=config, logger=logger, **kwargs)
|
||||
|
||||
def set_ws_adapter(self, ws_adapter) -> None:
|
||||
"""Set the underlying WebSocket adapter used for actual message delivery."""
|
||||
object.__setattr__(self, '_ws_adapter', ws_adapter)
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
message: platform_message.MessageChain,
|
||||
) -> dict:
|
||||
if self._ws_adapter is not None:
|
||||
return await self._ws_adapter.send_message(target_type, target_id, message)
|
||||
return {}
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
) -> dict:
|
||||
if self._ws_adapter is not None:
|
||||
return await self._ws_adapter.reply_message(message_source, message, quote_origin)
|
||||
return {}
|
||||
|
||||
async def reply_message_chunk(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
bot_message,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
is_final: bool = False,
|
||||
) -> dict:
|
||||
if self._ws_adapter is not None:
|
||||
return await self._ws_adapter.reply_message_chunk(
|
||||
message_source, bot_message, message, quote_origin, is_final
|
||||
)
|
||||
return {}
|
||||
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
func: typing.Callable,
|
||||
):
|
||||
self.listeners[event_type] = func
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
func: typing.Callable,
|
||||
):
|
||||
self.listeners.pop(event_type, None)
|
||||
|
||||
async def is_muted(self, group_id: int) -> bool:
|
||||
return False
|
||||
|
||||
async def run_async(self):
|
||||
pass
|
||||
|
||||
async def kill(self):
|
||||
pass
|
||||
|
Before Width: | Height: | Size: 14 KiB |
@@ -312,7 +312,7 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
|
||||
async def _process_image_components(self, message_chain_obj: list):
|
||||
"""
|
||||
处理消息链中的图片和文件组件,将path转换为base64
|
||||
处理消息链中的图片组件,将path转换为base64
|
||||
|
||||
Args:
|
||||
message_chain_obj: 消息链对象列表
|
||||
@@ -322,18 +322,16 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
storage_mgr = self.ap.storage_mgr
|
||||
|
||||
for component in message_chain_obj:
|
||||
comp_type = component.get('type', '')
|
||||
comp_path = component.get('path', '')
|
||||
|
||||
if not comp_path:
|
||||
continue
|
||||
|
||||
if comp_type == 'Image':
|
||||
if component.get('type') == 'Image' and component.get('path'):
|
||||
try:
|
||||
file_content = await storage_mgr.storage_provider.load(comp_path)
|
||||
# 从storage读取文件
|
||||
file_content = await storage_mgr.storage_provider.load(component['path'])
|
||||
|
||||
# 转换为base64
|
||||
base64_str = base64.b64encode(file_content).decode('utf-8')
|
||||
|
||||
file_key = comp_path
|
||||
# 添加data URI前缀(根据文件扩展名判断MIME类型)
|
||||
file_key = component['path']
|
||||
if file_key.lower().endswith(('.jpg', '.jpeg')):
|
||||
mime_type = 'image/jpeg'
|
||||
elif file_key.lower().endswith('.png'):
|
||||
@@ -343,19 +341,19 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
elif file_key.lower().endswith('.webp'):
|
||||
mime_type = 'image/webp'
|
||||
else:
|
||||
mime_type = 'image/png'
|
||||
mime_type = 'image/png' # 默认
|
||||
|
||||
component['base64'] = f'data:{mime_type};base64,{base64_str}'
|
||||
await storage_mgr.storage_provider.delete(comp_path)
|
||||
await storage_mgr.storage_provider.delete(component['path'])
|
||||
component['path'] = ''
|
||||
# 保留path字段用于后端处理,前端使用base64显示
|
||||
except Exception as e:
|
||||
await self.logger.error(f'Failed to load image file {comp_path}: {e}')
|
||||
await self.logger.error(f'加载图片文件失败 {component["path"]}: {e}')
|
||||
|
||||
async def handle_websocket_message(
|
||||
self,
|
||||
connection: WebSocketConnection,
|
||||
message_data: dict,
|
||||
owner_bot=None,
|
||||
):
|
||||
"""
|
||||
处理从WebSocket接收的消息
|
||||
@@ -368,8 +366,6 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
message_data: 消息数据,包含:
|
||||
- message: 消息链
|
||||
- stream: 是否启用流式输出 (可选,默认True)
|
||||
owner_bot: Optional RuntimeBot that owns this pipeline (e.g. a web_page_bot).
|
||||
When provided, its identity is used for logging and session tracking.
|
||||
"""
|
||||
pipeline_uuid = connection.pipeline_uuid
|
||||
session_type = connection.session_type
|
||||
@@ -439,26 +435,12 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
sender=sender, message_chain=message_chain, time=datetime.now().timestamp()
|
||||
)
|
||||
|
||||
# 设置流水线UUID (proxy bot always needs it for reply_message routing)
|
||||
# 设置流水线UUID
|
||||
self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid = pipeline_uuid
|
||||
if owner_bot is not None:
|
||||
owner_bot.bot_entity.use_pipeline_uuid = pipeline_uuid
|
||||
|
||||
# 异步触发事件处理
|
||||
# Use owner_bot's listeners if available, otherwise fall back to proxy bot
|
||||
listeners = (
|
||||
owner_bot.adapter.listeners
|
||||
if (owner_bot and hasattr(owner_bot.adapter, 'listeners') and owner_bot.adapter.listeners)
|
||||
else self.listeners
|
||||
)
|
||||
# Pass owner_bot's adapter so that downstream logging / dashboard
|
||||
# attributes the message to the correct bot adapter name.
|
||||
# Wire the ws adapter into the owner so replies are actually delivered.
|
||||
if owner_bot and hasattr(owner_bot.adapter, 'set_ws_adapter'):
|
||||
owner_bot.adapter.set_ws_adapter(self)
|
||||
callback_adapter = owner_bot.adapter if (owner_bot and hasattr(owner_bot, 'adapter')) else self
|
||||
if event.__class__ in listeners:
|
||||
asyncio.create_task(listeners[event.__class__](event, callback_adapter))
|
||||
# 异步触发事件处理(不等待结果)
|
||||
if event.__class__ in self.listeners:
|
||||
asyncio.create_task(self.listeners[event.__class__](event, self))
|
||||
|
||||
def get_websocket_messages(self, pipeline_uuid: str, session_type: str) -> list[dict]:
|
||||
"""获取消息历史"""
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import typing
|
||||
import asyncio
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import datetime
|
||||
@@ -127,107 +126,6 @@ class WecomBotMessageConverter(abstract_platform_adapter.AbstractMessageConverte
|
||||
if summary:
|
||||
yiri_msg_list.append(platform_message.Plain(text=summary))
|
||||
|
||||
# Handle quoted message (引用消息) - important for group chat file references
|
||||
# Extract files/images/voice from quote and add them as top-level components
|
||||
# so that plugins like FileReader can process them the same way as direct messages
|
||||
quote_info = event.quote or {}
|
||||
if quote_info:
|
||||
# Process quote text content - add as Plain for context
|
||||
if quote_info.get('content'):
|
||||
yiri_msg_list.append(platform_message.Plain(text=f'[引用消息] {quote_info.get("content")}'))
|
||||
|
||||
# Process quote images - add as top-level Image components
|
||||
quote_images = quote_info.get('images', [])
|
||||
if not quote_images and quote_info.get('picurl'):
|
||||
quote_images = [quote_info.get('picurl')]
|
||||
for img_data in quote_images:
|
||||
if img_data:
|
||||
yiri_msg_list.append(platform_message.Image(base64=img_data))
|
||||
|
||||
# Process quote file - add as top-level File component (same as private chat)
|
||||
quote_file = quote_info.get('file') or {}
|
||||
if quote_file:
|
||||
file_url = (
|
||||
quote_file.get('base64')
|
||||
or quote_file.get('download_url')
|
||||
or quote_file.get('url')
|
||||
or quote_file.get('fileurl')
|
||||
)
|
||||
file_name = quote_file.get('filename') or quote_file.get('name')
|
||||
file_size = quote_file.get('filesize') or quote_file.get('size')
|
||||
if file_url or file_name:
|
||||
file_kwargs = {}
|
||||
if file_url:
|
||||
file_kwargs['url'] = file_url
|
||||
if file_name:
|
||||
file_kwargs['name'] = file_name
|
||||
if file_size is not None:
|
||||
file_kwargs['size'] = file_size
|
||||
try:
|
||||
yiri_msg_list.append(platform_message.File(**file_kwargs))
|
||||
except Exception:
|
||||
yiri_msg_list.append(platform_message.Unknown(text='[quoted file unsupported]'))
|
||||
|
||||
# Process quote voice - add as top-level Voice/File component
|
||||
quote_voice = quote_info.get('voice') or {}
|
||||
if quote_voice:
|
||||
voice_payload = quote_voice.get('base64') or quote_voice.get('url')
|
||||
if voice_payload:
|
||||
if quote_voice.get('base64') and not voice_payload.startswith('data:'):
|
||||
voice_payload = f'data:audio/mpeg;base64,{quote_voice.get("base64")}'
|
||||
try:
|
||||
yiri_msg_list.append(platform_message.Voice(base64=voice_payload))
|
||||
except Exception:
|
||||
try:
|
||||
voice_kwargs = {'url': voice_payload}
|
||||
voice_name = quote_voice.get('filename') or quote_voice.get('name')
|
||||
voice_size = quote_voice.get('filesize') or quote_voice.get('size')
|
||||
if voice_name:
|
||||
voice_kwargs['name'] = voice_name
|
||||
if voice_size is not None:
|
||||
voice_kwargs['size'] = voice_size
|
||||
yiri_msg_list.append(platform_message.File(**voice_kwargs))
|
||||
except Exception:
|
||||
yiri_msg_list.append(platform_message.Unknown(text='[quoted voice unsupported]'))
|
||||
|
||||
# Process quote video - add as top-level File component
|
||||
quote_video = quote_info.get('video') or {}
|
||||
if quote_video:
|
||||
video_payload = (
|
||||
quote_video.get('base64')
|
||||
or quote_video.get('url')
|
||||
or quote_video.get('download_url')
|
||||
or quote_video.get('fileurl')
|
||||
)
|
||||
if video_payload:
|
||||
video_kwargs = {'url': video_payload}
|
||||
video_name = quote_video.get('filename') or quote_video.get('name')
|
||||
video_size = quote_video.get('filesize') or quote_video.get('size')
|
||||
if video_name:
|
||||
video_kwargs['name'] = video_name
|
||||
if video_size is not None:
|
||||
video_kwargs['size'] = video_size
|
||||
try:
|
||||
yiri_msg_list.append(platform_message.File(**video_kwargs))
|
||||
except Exception:
|
||||
yiri_msg_list.append(platform_message.Unknown(text='[quoted video unsupported]'))
|
||||
|
||||
# Process quote link - add as Plain text
|
||||
quote_link = quote_info.get('link') or {}
|
||||
if quote_link:
|
||||
link_summary = '\n'.join(
|
||||
filter(
|
||||
None,
|
||||
[
|
||||
quote_link.get('title', ''),
|
||||
quote_link.get('description') or quote_link.get('digest', ''),
|
||||
quote_link.get('url', ''),
|
||||
],
|
||||
)
|
||||
)
|
||||
if link_summary:
|
||||
yiri_msg_list.append(platform_message.Plain(text=f'[引用链接] {link_summary}'))
|
||||
|
||||
has_content_element = any(
|
||||
not isinstance(element, (platform_message.Source, platform_message.At)) for element in yiri_msg_list
|
||||
)
|
||||
@@ -294,8 +192,6 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
_ws_mode: bool = False
|
||||
bot_name: str = ''
|
||||
listeners: dict = {}
|
||||
_stream_to_monitoring_msg: dict = {} # Maps stream_id to (monitoring_message_id, timestamp)
|
||||
_STREAM_MAPPING_TTL = 600 # 10 minutes
|
||||
|
||||
def __init__(self, config: dict, logger: EventLogger):
|
||||
enable_webhook = config.get('enable-webhook', False)
|
||||
@@ -332,9 +228,8 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
bot_account_id=bot_account_id,
|
||||
bot_name=bot_name,
|
||||
event_converter=event_converter,
|
||||
listeners={},
|
||||
_stream_to_monitoring_msg={},
|
||||
)
|
||||
self.listeners = {}
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
@@ -426,23 +321,6 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
"""设置 bot UUID(用于生成 webhook URL)"""
|
||||
self.bot_uuid = bot_uuid
|
||||
|
||||
async def on_monitoring_message_created(self, query, monitoring_message_id: str):
|
||||
"""Called by pipeline after monitoring message is created, to map stream_id to monitoring message ID."""
|
||||
try:
|
||||
stream_id = query.message_event.source_platform_object.stream_id
|
||||
if stream_id:
|
||||
self._stream_to_monitoring_msg[stream_id] = (monitoring_message_id, time.time())
|
||||
self._cleanup_stream_mapping()
|
||||
except Exception as e:
|
||||
await self.logger.debug(f'Failed to map stream_id to monitoring message: {e}')
|
||||
|
||||
def _cleanup_stream_mapping(self):
|
||||
"""Remove entries older than TTL from the stream_id to monitoring message mapping."""
|
||||
now = time.time()
|
||||
expired = [k for k, (_, ts) in self._stream_to_monitoring_msg.items() if now - ts > self._STREAM_MAPPING_TTL]
|
||||
for k in expired:
|
||||
del self._stream_to_monitoring_msg[k]
|
||||
|
||||
async def _on_feedback(self, **kwargs):
|
||||
"""Handle feedback event from WeChat Work AI Bot SDK and dispatch as FeedbackEvent."""
|
||||
try:
|
||||
@@ -450,9 +328,6 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
feedback_type = kwargs.get('feedback_type', 0)
|
||||
feedback_content = kwargs.get('feedback_content', '') or None
|
||||
inaccurate_reasons = kwargs.get('inaccurate_reasons', []) or None
|
||||
# WeChat Work returns integer reason codes, but FeedbackEvent expects strings
|
||||
if inaccurate_reasons:
|
||||
inaccurate_reasons = [str(r) for r in inaccurate_reasons]
|
||||
session = kwargs.get('session')
|
||||
|
||||
session_id = None
|
||||
@@ -468,11 +343,6 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
message_id = session.msg_id
|
||||
stream_id = session.stream_id
|
||||
|
||||
# Resolve stream_id to LangBot monitoring message ID if available
|
||||
monitoring_msg_id = None
|
||||
if stream_id and stream_id in self._stream_to_monitoring_msg:
|
||||
monitoring_msg_id = self._stream_to_monitoring_msg[stream_id][0]
|
||||
|
||||
await self.logger.info(
|
||||
f'Feedback event: feedback_id={feedback_id}, type={feedback_type}, '
|
||||
f'session_id={session_id}, user_id={user_id}, message_id={message_id}'
|
||||
@@ -486,7 +356,7 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
message_id=message_id,
|
||||
stream_id=monitoring_msg_id or stream_id,
|
||||
stream_id=stream_id,
|
||||
source_platform_object=session,
|
||||
)
|
||||
|
||||
|
||||
@@ -19,18 +19,6 @@ spec:
|
||||
en: https://link.langbot.app/en/platforms/wecombot
|
||||
ja: https://link.langbot.app/ja/platforms/wecombot
|
||||
config:
|
||||
- name: one-click-create
|
||||
label:
|
||||
en_US: One-Click Create Bot
|
||||
zh_Hans: 一键创建机器人
|
||||
zh_Hant: 一鍵建立機器人
|
||||
description:
|
||||
en_US: "Scan QR code with WeCom to automatically create a bot and fill in BotId and Secret. Note: Robot Name needs to be filled in manually."
|
||||
zh_Hans: "使用企业微信扫码自动创建机器人并填写 BotId 和 Secret。注意:机器人名称需手动填写。"
|
||||
zh_Hant: "使用企業微信掃碼自動建立機器人並填寫 BotId 和 Secret。注意:機器人名稱需手動填寫。"
|
||||
type: qr-code-login
|
||||
login_platform: wecombot
|
||||
required: false
|
||||
- name: BotId
|
||||
label:
|
||||
en_US: BotId
|
||||
|
||||
@@ -11,7 +11,6 @@ import os
|
||||
import sys
|
||||
import httpx
|
||||
import sqlalchemy
|
||||
import yaml
|
||||
from async_lru import alru_cache
|
||||
from langbot_plugin.api.entities.builtin.pipeline.query import provider_session
|
||||
|
||||
@@ -35,10 +34,6 @@ from ..core import taskmgr
|
||||
from ..entity.persistence import plugin as persistence_plugin
|
||||
|
||||
|
||||
class PluginRuntimeNotConnectedError(RuntimeError):
|
||||
"""Raised when plugin runtime operations are requested before connection."""
|
||||
|
||||
|
||||
class PluginRuntimeConnector:
|
||||
"""Plugin runtime connector"""
|
||||
|
||||
@@ -196,114 +191,44 @@ class PluginRuntimeConnector:
|
||||
|
||||
async def ping_plugin_runtime(self):
|
||||
if not hasattr(self, 'handler'):
|
||||
raise PluginRuntimeNotConnectedError('Plugin runtime is not connected')
|
||||
raise Exception('Plugin runtime is not connected')
|
||||
|
||||
return await self.handler.ping()
|
||||
|
||||
def _inspect_plugin_package(
|
||||
def _extract_deps_metadata(
|
||||
self,
|
||||
file_bytes: bytes,
|
||||
task_context: taskmgr.TaskContext | None,
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""Extract plugin identity and dependency metadata from a plugin package."""
|
||||
plugin_author = None
|
||||
plugin_name = None
|
||||
|
||||
):
|
||||
"""Extract dependency count from requirements.txt inside plugin zip."""
|
||||
if task_context is None:
|
||||
return
|
||||
try:
|
||||
with zipfile.ZipFile(io.BytesIO(file_bytes)) as zf:
|
||||
try:
|
||||
manifest = yaml.safe_load(zf.read('manifest.yaml').decode('utf-8', errors='ignore')) or {}
|
||||
metadata = manifest.get('metadata', {})
|
||||
plugin_author = metadata.get('author')
|
||||
plugin_name = metadata.get('name')
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if task_context is not None:
|
||||
for name in zf.namelist():
|
||||
if name.endswith('requirements.txt'):
|
||||
content = zf.read(name).decode('utf-8', errors='ignore')
|
||||
deps = [
|
||||
line.strip()
|
||||
for line in content.splitlines()
|
||||
if line.strip() and not line.strip().startswith('#')
|
||||
]
|
||||
task_context.metadata['deps_total'] = len(deps)
|
||||
task_context.metadata['deps_list'] = deps
|
||||
break
|
||||
for name in zf.namelist():
|
||||
if name.endswith('requirements.txt'):
|
||||
content = zf.read(name).decode('utf-8', errors='ignore')
|
||||
deps = [
|
||||
line.strip()
|
||||
for line in content.splitlines()
|
||||
if line.strip() and not line.strip().startswith('#')
|
||||
]
|
||||
task_context.metadata['deps_total'] = len(deps)
|
||||
task_context.metadata['deps_list'] = deps
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return plugin_author, plugin_name
|
||||
|
||||
def _build_plugin_startup_failure_message(
|
||||
self,
|
||||
plugin_author: str,
|
||||
plugin_name: str,
|
||||
task_context: taskmgr.TaskContext | None,
|
||||
) -> str:
|
||||
dep_hint = ''
|
||||
if task_context is not None:
|
||||
current_dep = task_context.metadata.get('current_dep')
|
||||
if current_dep:
|
||||
dep_hint = f' Last dependency: {current_dep}.'
|
||||
|
||||
return (
|
||||
f'Plugin {plugin_author}/{plugin_name} failed to start after installation. '
|
||||
f'Dependency installation or plugin initialization may have failed.{dep_hint} '
|
||||
f'Please check the plugin requirements and runtime logs.'
|
||||
)
|
||||
|
||||
async def _wait_for_installed_plugin_ready(
|
||||
self,
|
||||
plugin_author: str | None,
|
||||
plugin_name: str | None,
|
||||
task_context: taskmgr.TaskContext | None,
|
||||
timeout: float = 30,
|
||||
):
|
||||
"""Wait until the installed plugin is registered by the runtime.
|
||||
|
||||
The plugin runtime launches plugins asynchronously. If dependency installation
|
||||
fails, the plugin process exits before registration; without this check the
|
||||
install task can incorrectly finish successfully.
|
||||
"""
|
||||
if not plugin_author or not plugin_name:
|
||||
return
|
||||
|
||||
deadline = time.time() + timeout
|
||||
last_error: Exception | None = None
|
||||
while time.time() < deadline:
|
||||
try:
|
||||
plugin = await self.get_plugin_info(plugin_author, plugin_name)
|
||||
if plugin is not None:
|
||||
status = plugin.get('status')
|
||||
if status == 'initialized':
|
||||
return
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
message = self._build_plugin_startup_failure_message(plugin_author, plugin_name, task_context)
|
||||
if last_error is not None:
|
||||
message = f'{message} Last runtime error: {last_error}'
|
||||
raise RuntimeError(message)
|
||||
|
||||
async def install_plugin(
|
||||
self,
|
||||
install_source: PluginInstallSource,
|
||||
install_info: dict[str, Any],
|
||||
task_context: taskmgr.TaskContext | None = None,
|
||||
):
|
||||
plugin_author = install_info.get('plugin_author')
|
||||
plugin_name = install_info.get('plugin_name')
|
||||
|
||||
if install_source == PluginInstallSource.LOCAL:
|
||||
# transfer file before install
|
||||
file_bytes = install_info['plugin_file']
|
||||
plugin_author, plugin_name = self._inspect_plugin_package(file_bytes, task_context)
|
||||
if task_context is not None and plugin_author and plugin_name:
|
||||
task_context.metadata['plugin_name'] = f'{plugin_author}/{plugin_name}'
|
||||
self._extract_deps_metadata(file_bytes, task_context)
|
||||
file_key = await self.handler.send_file(file_bytes, 'lbpkg')
|
||||
install_info['plugin_file_key'] = file_key
|
||||
del install_info['plugin_file']
|
||||
@@ -340,9 +265,7 @@ class PluginRuntimeConnector:
|
||||
task_context.metadata['download_speed'] = downloaded / elapsed if elapsed > 0 else 0
|
||||
|
||||
file_bytes = b''.join(chunks)
|
||||
plugin_author, plugin_name = self._inspect_plugin_package(file_bytes, task_context)
|
||||
if task_context is not None and plugin_author and plugin_name:
|
||||
task_context.metadata['plugin_name'] = f'{plugin_author}/{plugin_name}'
|
||||
self._extract_deps_metadata(file_bytes, task_context)
|
||||
file_key = await self.handler.send_file(file_bytes, 'lbpkg')
|
||||
install_info['plugin_file_key'] = file_key
|
||||
self.ap.logger.info(f'Transfered file {file_key} to plugin runtime')
|
||||
@@ -366,8 +289,6 @@ class PluginRuntimeConnector:
|
||||
if metadata is not None and task_context is not None:
|
||||
task_context.metadata.update(metadata)
|
||||
|
||||
await self._wait_for_installed_plugin_ready(plugin_author, plugin_name, task_context)
|
||||
|
||||
async def upgrade_plugin(
|
||||
self,
|
||||
plugin_author: str,
|
||||
@@ -510,17 +431,6 @@ class PluginRuntimeConnector:
|
||||
async def get_plugin_assets(self, plugin_author: str, plugin_name: str, filepath: str) -> dict[str, Any]:
|
||||
return await self.handler.get_plugin_assets(plugin_author, plugin_name, filepath)
|
||||
|
||||
async def handle_page_api(
|
||||
self,
|
||||
plugin_author: str,
|
||||
plugin_name: str,
|
||||
page_id: str,
|
||||
endpoint: str,
|
||||
method: str,
|
||||
body: Any = None,
|
||||
) -> dict[str, Any]:
|
||||
return await self.handler.handle_page_api(plugin_author, plugin_name, page_id, endpoint, method, body)
|
||||
|
||||
async def get_debug_info(self) -> dict[str, Any]:
|
||||
"""Get debug information including debug key and WS URL"""
|
||||
if not self.is_enable_plugin:
|
||||
|
||||
@@ -367,22 +367,6 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
owner_type = data['owner_type']
|
||||
owner = data['owner']
|
||||
value = base64.b64decode(data['value_base64'])
|
||||
max_value_bytes = (
|
||||
self.ap.instance_config.data.get('plugin', {})
|
||||
.get('binary_storage', {})
|
||||
.get(
|
||||
'max_value_bytes',
|
||||
10 * 1024 * 1024,
|
||||
)
|
||||
)
|
||||
try:
|
||||
max_value_bytes = int(max_value_bytes)
|
||||
except (TypeError, ValueError):
|
||||
max_value_bytes = 10 * 1024 * 1024
|
||||
if max_value_bytes >= 0 and len(value) > max_value_bytes:
|
||||
return handler.ActionResponse.error(
|
||||
message=f'Binary storage value exceeds limit ({len(value)} > {max_value_bytes} bytes)',
|
||||
)
|
||||
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_bstorage.BinaryStorage)
|
||||
@@ -955,11 +939,6 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
timeout=20,
|
||||
)
|
||||
asset_file_key = result['file_file_key']
|
||||
if not asset_file_key:
|
||||
return {
|
||||
'asset_base64': '',
|
||||
'mime_type': '',
|
||||
}
|
||||
mime_type = result['mime_type']
|
||||
asset_bytes = await self.read_local_file(asset_file_key)
|
||||
await self.delete_local_file(asset_file_key)
|
||||
@@ -968,30 +947,6 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
'mime_type': mime_type,
|
||||
}
|
||||
|
||||
async def handle_page_api(
|
||||
self,
|
||||
plugin_author: str,
|
||||
plugin_name: str,
|
||||
page_id: str,
|
||||
endpoint: str,
|
||||
method: str,
|
||||
body: Any = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Forward a page API call to the plugin via runtime."""
|
||||
result = await self.call_action(
|
||||
LangBotToRuntimeAction.PAGE_API,
|
||||
{
|
||||
'plugin_author': plugin_author,
|
||||
'plugin_name': plugin_name,
|
||||
'page_id': page_id,
|
||||
'endpoint': endpoint,
|
||||
'method': method,
|
||||
'body': body,
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
return result
|
||||
|
||||
async def cleanup_plugin_data(self, plugin_author: str, plugin_name: str) -> None:
|
||||
"""Cleanup plugin settings and binary storage"""
|
||||
# Delete plugin settings
|
||||
|
||||
@@ -9,6 +9,7 @@ from ...discover import engine
|
||||
from . import token
|
||||
from ...entity.persistence import model as persistence_model
|
||||
from ...entity.errors import provider as provider_errors
|
||||
from async_lru import alru_cache
|
||||
|
||||
|
||||
class ModelManager:
|
||||
@@ -23,8 +24,6 @@ class ModelManager:
|
||||
|
||||
embedding_models: list[requester.RuntimeEmbeddingModel]
|
||||
|
||||
rerank_models: list[requester.RuntimeRerankModel]
|
||||
|
||||
requester_components: list[engine.Component]
|
||||
|
||||
requester_dict: dict[str, type[requester.ProviderAPIRequester]]
|
||||
@@ -33,7 +32,6 @@ class ModelManager:
|
||||
self.ap = ap
|
||||
self.llm_models = []
|
||||
self.embedding_models = []
|
||||
self.rerank_models = []
|
||||
self.requester_components = []
|
||||
self.requester_dict = {}
|
||||
|
||||
@@ -66,7 +64,8 @@ class ModelManager:
|
||||
|
||||
self.llm_models = []
|
||||
self.embedding_models = []
|
||||
self.rerank_models = []
|
||||
|
||||
# Load all providers first
|
||||
self.provider_dict = {}
|
||||
providers_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.ModelProvider)
|
||||
@@ -111,22 +110,6 @@ class ModelManager:
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Failed to load model {embedding_model.uuid}: {e}\n{traceback.format_exc()}')
|
||||
|
||||
# Load rerank models
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.RerankModel))
|
||||
rerank_models = result.all()
|
||||
for rerank_model in rerank_models:
|
||||
try:
|
||||
provider = self.provider_dict.get(rerank_model.provider_uuid)
|
||||
if provider is None:
|
||||
self.ap.logger.warning(
|
||||
f'Provider {rerank_model.provider_uuid} not found for model {rerank_model.uuid}'
|
||||
)
|
||||
continue
|
||||
runtime_rerank_model = await self.load_rerank_model_with_provider(rerank_model, provider)
|
||||
self.rerank_models.append(runtime_rerank_model)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Failed to load model {rerank_model.uuid}: {e}\n{traceback.format_exc()}')
|
||||
|
||||
async def sync_new_models_from_space(self):
|
||||
"""Sync models from Space"""
|
||||
space_model_provider = await self.ap.persistence_mgr.execute_async(
|
||||
@@ -229,26 +212,6 @@ class ModelManager:
|
||||
|
||||
return runtime_embedding_model
|
||||
|
||||
async def init_temporary_runtime_rerank_model(
|
||||
self,
|
||||
model_info: dict,
|
||||
) -> requester.RuntimeRerankModel:
|
||||
"""Initialize runtime rerank model from dict (for testing)"""
|
||||
provider_info = model_info.get('provider', {})
|
||||
runtime_provider = await self.load_provider(provider_info)
|
||||
|
||||
runtime_rerank_model = requester.RuntimeRerankModel(
|
||||
model_entity=persistence_model.RerankModel(
|
||||
uuid=model_info.get('uuid', ''),
|
||||
name=model_info.get('name', ''),
|
||||
provider_uuid='',
|
||||
extra_args=model_info.get('extra_args', {}),
|
||||
),
|
||||
provider=runtime_provider,
|
||||
)
|
||||
|
||||
return runtime_rerank_model
|
||||
|
||||
async def load_provider(
|
||||
self, provider_info: persistence_model.ModelProvider | sqlalchemy.Row | dict
|
||||
) -> requester.RuntimeProvider:
|
||||
@@ -264,8 +227,7 @@ class ModelManager:
|
||||
raise provider_errors.RequesterNotFoundError(provider_entity.requester)
|
||||
|
||||
requester_inst = self.requester_dict[provider_entity.requester](
|
||||
ap=self.ap,
|
||||
config={'base_url': provider_entity.base_url},
|
||||
ap=self.ap, config={'base_url': provider_entity.base_url}
|
||||
)
|
||||
await requester_inst.initialize()
|
||||
|
||||
@@ -306,9 +268,6 @@ class ModelManager:
|
||||
for model in self.embedding_models:
|
||||
if model.provider.provider_entity.uuid == provider_uuid:
|
||||
model.provider = new_runtime_provider
|
||||
for model in self.rerank_models:
|
||||
if model.provider.provider_entity.uuid == provider_uuid:
|
||||
model.provider = new_runtime_provider
|
||||
|
||||
# update ref in provider dict
|
||||
self.provider_dict[provider_uuid] = new_runtime_provider
|
||||
@@ -345,22 +304,6 @@ class ModelManager:
|
||||
|
||||
return runtime_embedding_model
|
||||
|
||||
async def load_rerank_model_with_provider(
|
||||
self,
|
||||
model_info: persistence_model.RerankModel | sqlalchemy.Row,
|
||||
provider: requester.RuntimeProvider,
|
||||
) -> requester.RuntimeRerankModel:
|
||||
"""Load rerank model with provider info"""
|
||||
if isinstance(model_info, sqlalchemy.Row):
|
||||
model_info = persistence_model.RerankModel(**model_info._mapping)
|
||||
|
||||
runtime_rerank_model = requester.RuntimeRerankModel(
|
||||
model_entity=model_info,
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
return runtime_rerank_model
|
||||
|
||||
async def load_llm_model(self, model_info: dict):
|
||||
"""Load LLM model from dict (with provider info)"""
|
||||
provider_info = model_info.get('provider', {})
|
||||
@@ -408,6 +351,7 @@ class ModelManager:
|
||||
|
||||
await self.load_embedding_model_with_provider(model_entity, provider_entity)
|
||||
|
||||
@alru_cache(ttl=60 * 5)
|
||||
async def get_model_by_uuid(self, uuid: str) -> requester.RuntimeLLMModel:
|
||||
"""Get LLM model by uuid"""
|
||||
for model in self.llm_models:
|
||||
@@ -415,6 +359,7 @@ class ModelManager:
|
||||
return model
|
||||
raise ValueError(f'LLM model {uuid} not found')
|
||||
|
||||
@alru_cache(ttl=60 * 5)
|
||||
async def get_embedding_model_by_uuid(self, uuid: str) -> requester.RuntimeEmbeddingModel:
|
||||
"""Get embedding model by uuid"""
|
||||
for model in self.embedding_models:
|
||||
@@ -422,13 +367,6 @@ class ModelManager:
|
||||
return model
|
||||
raise ValueError(f'Embedding model {uuid} not found')
|
||||
|
||||
async def get_rerank_model_by_uuid(self, uuid: str) -> requester.RuntimeRerankModel:
|
||||
"""Get rerank model by uuid"""
|
||||
for model in self.rerank_models:
|
||||
if model.model_entity.uuid == uuid:
|
||||
return model
|
||||
raise ValueError(f'Rerank model {uuid} not found')
|
||||
|
||||
async def remove_llm_model(self, model_uuid: str):
|
||||
"""Remove LLM model"""
|
||||
for model in self.llm_models:
|
||||
@@ -443,13 +381,6 @@ class ModelManager:
|
||||
self.embedding_models.remove(model)
|
||||
return
|
||||
|
||||
async def remove_rerank_model(self, model_uuid: str):
|
||||
"""Remove rerank model"""
|
||||
for model in self.rerank_models:
|
||||
if model.model_entity.uuid == model_uuid:
|
||||
self.rerank_models.remove(model)
|
||||
return
|
||||
|
||||
def get_available_requesters_info(self, model_type: str) -> list[dict]:
|
||||
"""Get all available requesters"""
|
||||
if model_type != '':
|
||||
|
||||
@@ -247,40 +247,6 @@ class RuntimeProvider:
|
||||
except Exception as monitor_err:
|
||||
self.requester.ap.logger.error(f'[Monitoring] Failed to record embedding call: {monitor_err}')
|
||||
|
||||
async def invoke_rerank(
|
||||
self,
|
||||
model: RuntimeRerankModel,
|
||||
query: str,
|
||||
documents: typing.List[str],
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
) -> typing.List[dict]:
|
||||
"""Bridge method for invoking rerank with monitoring"""
|
||||
start_time = time.time()
|
||||
status = 'success'
|
||||
|
||||
try:
|
||||
result = await self.requester.invoke_rerank(
|
||||
model=model,
|
||||
query=query,
|
||||
documents=documents,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return result
|
||||
|
||||
except Exception:
|
||||
status = 'error'
|
||||
raise
|
||||
finally:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
try:
|
||||
self.requester.ap.logger.debug(
|
||||
f'[Rerank] model={model.model_entity.name} docs={len(documents)} '
|
||||
f'duration={duration_ms}ms status={status}'
|
||||
)
|
||||
except Exception as monitor_err:
|
||||
self.requester.ap.logger.error(f'[Monitoring] Failed to record rerank call: {monitor_err}')
|
||||
|
||||
|
||||
class RuntimeLLMModel:
|
||||
"""运行时模型"""
|
||||
@@ -318,29 +284,10 @@ class RuntimeEmbeddingModel:
|
||||
self.provider = provider
|
||||
|
||||
|
||||
class RuntimeRerankModel:
|
||||
"""运行时 Rerank 模型"""
|
||||
|
||||
model_entity: persistence_model.RerankModel
|
||||
"""模型数据"""
|
||||
|
||||
provider: RuntimeProvider
|
||||
"""提供商实例"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_entity: persistence_model.RerankModel,
|
||||
provider: RuntimeProvider,
|
||||
):
|
||||
self.model_entity = model_entity
|
||||
self.provider = provider
|
||||
|
||||
|
||||
class ProviderAPIRequester(metaclass=abc.ABCMeta):
|
||||
"""Provider API请求器"""
|
||||
|
||||
name: str = None
|
||||
init_api_key: str = 'langbot-init-placeholder'
|
||||
|
||||
ap: app.Application
|
||||
|
||||
@@ -356,14 +303,6 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta):
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any] | list[dict[str, typing.Any]]:
|
||||
"""Scan models supported by the provider.
|
||||
|
||||
The default implementation does not support scanning. Requesters that
|
||||
can enumerate remote models should override this method.
|
||||
"""
|
||||
raise NotImplementedError('This provider does not support model scanning')
|
||||
|
||||
@abc.abstractmethod
|
||||
async def invoke_llm(
|
||||
self,
|
||||
@@ -429,23 +368,3 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta):
|
||||
或者 tuple[typing.List[typing.List[float]], dict]: 返回 (embedding 向量, usage_info)
|
||||
"""
|
||||
pass
|
||||
|
||||
async def invoke_rerank(
|
||||
self,
|
||||
model: RuntimeRerankModel,
|
||||
query: str,
|
||||
documents: typing.List[str],
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
) -> typing.List[dict]:
|
||||
"""调用 Rerank API
|
||||
|
||||
Args:
|
||||
model (RuntimeRerankModel): 使用的模型信息
|
||||
query (str): 查询文本
|
||||
documents (typing.List[str]): 待重排序的文档列表
|
||||
extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}.
|
||||
|
||||
Returns:
|
||||
typing.List[dict]: [{"index": int, "relevance_score": float}, ...]
|
||||
"""
|
||||
raise NotImplementedError('This requester does not support rerank')
|
||||
|
||||
@@ -25,7 +25,6 @@ spec:
|
||||
support_type:
|
||||
- llm
|
||||
- text-embedding
|
||||
- rerank
|
||||
provider_category: maas
|
||||
execution:
|
||||
python:
|
||||
|
||||
@@ -24,7 +24,6 @@ spec:
|
||||
default: 120
|
||||
support_type:
|
||||
- llm
|
||||
- rerank
|
||||
provider_category: maas
|
||||
execution:
|
||||
python:
|
||||
|
||||
@@ -25,198 +25,12 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
|
||||
|
||||
async def initialize(self):
|
||||
self.client = openai.AsyncClient(
|
||||
api_key=self.init_api_key,
|
||||
api_key='',
|
||||
base_url=self.requester_cfg['base_url'].replace(' ', ''),
|
||||
timeout=self.requester_cfg['timeout'],
|
||||
http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']),
|
||||
)
|
||||
|
||||
def _mask_api_key(self, api_key: str | None) -> str:
|
||||
if not api_key:
|
||||
return ''
|
||||
if len(api_key) <= 8:
|
||||
return '****'
|
||||
return f'{api_key[:4]}...{api_key[-4:]}'
|
||||
|
||||
def _infer_model_type(self, model_id: str) -> str:
|
||||
normalized_model_id = (model_id or '').lower()
|
||||
embedding_keywords = (
|
||||
'embedding',
|
||||
'embed',
|
||||
'bge-',
|
||||
'e5-',
|
||||
'm3e',
|
||||
'gte-',
|
||||
'multilingual-e5',
|
||||
'text-embedding',
|
||||
)
|
||||
return 'embedding' if any(keyword in normalized_model_id for keyword in embedding_keywords) else 'llm'
|
||||
|
||||
def _infer_model_abilities(self, item: dict[str, typing.Any], model_id: str) -> list[str]:
|
||||
normalized_model_id = (model_id or '').lower()
|
||||
abilities: set[str] = set()
|
||||
|
||||
def _flatten(value: typing.Any) -> list[str]:
|
||||
if value is None:
|
||||
return []
|
||||
if isinstance(value, str):
|
||||
return [value.lower()]
|
||||
if isinstance(value, dict):
|
||||
flattened: list[str] = []
|
||||
for nested_value in value.values():
|
||||
flattened.extend(_flatten(nested_value))
|
||||
return flattened
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
flattened: list[str] = []
|
||||
for nested_value in value:
|
||||
flattened.extend(_flatten(nested_value))
|
||||
return flattened
|
||||
return [str(value).lower()]
|
||||
|
||||
capability_tokens = _flatten(item.get('capabilities'))
|
||||
capability_tokens.extend(_flatten(item.get('modalities')))
|
||||
capability_tokens.extend(_flatten(item.get('input_modalities')))
|
||||
capability_tokens.extend(_flatten(item.get('output_modalities')))
|
||||
capability_tokens.extend(_flatten(item.get('supported_generation_methods')))
|
||||
capability_tokens.extend(_flatten(item.get('supported_parameters')))
|
||||
capability_tokens.extend(_flatten(item.get('architecture')))
|
||||
|
||||
combined_tokens = capability_tokens + [normalized_model_id]
|
||||
|
||||
vision_keywords = (
|
||||
'vision',
|
||||
'image',
|
||||
'file',
|
||||
'video',
|
||||
'multimodal',
|
||||
'vl',
|
||||
'ocr',
|
||||
'omni',
|
||||
)
|
||||
function_call_keywords = (
|
||||
'function',
|
||||
'tool',
|
||||
'tools',
|
||||
'tool_choice',
|
||||
'tool_call',
|
||||
'tool-use',
|
||||
'tool_use',
|
||||
)
|
||||
|
||||
if any(any(keyword in token for keyword in vision_keywords) for token in combined_tokens):
|
||||
abilities.add('vision')
|
||||
|
||||
if any(any(keyword in token for keyword in function_call_keywords) for token in combined_tokens):
|
||||
abilities.add('func_call')
|
||||
|
||||
return sorted(abilities)
|
||||
|
||||
def _normalize_modalities(self, value: typing.Any) -> list[str]:
|
||||
normalized: list[str] = []
|
||||
|
||||
def _collect(item: typing.Any):
|
||||
if item is None:
|
||||
return
|
||||
if isinstance(item, str):
|
||||
for part in item.replace('->', ',').replace('+', ',').split(','):
|
||||
token = part.strip().lower()
|
||||
if token and token not in normalized:
|
||||
normalized.append(token)
|
||||
return
|
||||
if isinstance(item, dict):
|
||||
for nested in item.values():
|
||||
_collect(nested)
|
||||
return
|
||||
if isinstance(item, (list, tuple, set)):
|
||||
for nested in item:
|
||||
_collect(nested)
|
||||
return
|
||||
|
||||
_collect(value)
|
||||
return normalized
|
||||
|
||||
def _extract_scan_metadata(self, item: dict[str, typing.Any], model_id: str) -> dict[str, typing.Any]:
|
||||
display_name = item.get('name')
|
||||
if not isinstance(display_name, str) or not display_name.strip() or display_name == model_id:
|
||||
display_name = ''
|
||||
|
||||
description = item.get('description')
|
||||
if not isinstance(description, str) or not description.strip():
|
||||
description = ''
|
||||
|
||||
context_length = item.get('context_length')
|
||||
if context_length is None and isinstance(item.get('top_provider'), dict):
|
||||
context_length = item['top_provider'].get('context_length')
|
||||
|
||||
if not isinstance(context_length, int):
|
||||
try:
|
||||
context_length = int(context_length) if context_length is not None else None
|
||||
except (TypeError, ValueError):
|
||||
context_length = None
|
||||
|
||||
input_modalities = self._normalize_modalities(item.get('input_modalities'))
|
||||
output_modalities = self._normalize_modalities(item.get('output_modalities'))
|
||||
|
||||
if isinstance(item.get('architecture'), dict):
|
||||
if not input_modalities:
|
||||
input_modalities = self._normalize_modalities(item['architecture'].get('input_modalities'))
|
||||
if not output_modalities:
|
||||
output_modalities = self._normalize_modalities(item['architecture'].get('output_modalities'))
|
||||
|
||||
owned_by = item.get('owned_by')
|
||||
if not isinstance(owned_by, str) or not owned_by.strip():
|
||||
owned_by = ''
|
||||
|
||||
return {
|
||||
'display_name': display_name or None,
|
||||
'description': description or None,
|
||||
'context_length': context_length,
|
||||
'owned_by': owned_by or None,
|
||||
'input_modalities': input_modalities,
|
||||
'output_modalities': output_modalities,
|
||||
}
|
||||
|
||||
async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]:
|
||||
headers = {}
|
||||
if api_key:
|
||||
headers['Authorization'] = f'Bearer {api_key}'
|
||||
|
||||
models_url = f'{self.requester_cfg["base_url"].rstrip("/")}/models'
|
||||
async with httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']) as client:
|
||||
response = await client.get(models_url, headers=headers)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
|
||||
models = []
|
||||
for item in payload.get('data', []):
|
||||
model_id = item.get('id')
|
||||
if not model_id:
|
||||
continue
|
||||
models.append(
|
||||
{
|
||||
'id': model_id,
|
||||
'name': model_id,
|
||||
'type': self._infer_model_type(model_id),
|
||||
'abilities': self._infer_model_abilities(item, model_id),
|
||||
**self._extract_scan_metadata(item, model_id),
|
||||
}
|
||||
)
|
||||
|
||||
models.sort(key=lambda item: (item['type'] != 'llm', item['name'].lower()))
|
||||
return {
|
||||
'models': models,
|
||||
'debug': {
|
||||
'request': {
|
||||
'method': 'GET',
|
||||
'url': models_url,
|
||||
'headers': {
|
||||
'Authorization': f'Bearer {self._mask_api_key(api_key)}' if api_key else '',
|
||||
},
|
||||
},
|
||||
'response': payload,
|
||||
},
|
||||
}
|
||||
|
||||
async def _req(
|
||||
self,
|
||||
args: dict,
|
||||
@@ -615,88 +429,3 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
|
||||
raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
|
||||
except openai.APIError as e:
|
||||
raise errors.RequesterError(f'请求错误: {e.message}')
|
||||
|
||||
async def invoke_rerank(
|
||||
self,
|
||||
model: requester.RuntimeRerankModel,
|
||||
query: str,
|
||||
documents: typing.List[str],
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
) -> typing.List[dict]:
|
||||
"""Standard /rerank endpoint (Jina/Cohere/SiliconFlow/Voyage/DashScope compatible)
|
||||
|
||||
Supports extra_args from model.extra_args:
|
||||
- rerank_url: full URL override (e.g. "https://dashscope.aliyuncs.com/compatible-api/v1/reranks")
|
||||
- rerank_path: path override appended to base_url (e.g. "reranks" instead of default "rerank")
|
||||
- Any other fields are merged into the request payload.
|
||||
"""
|
||||
api_key = model.provider.token_mgr.get_token()
|
||||
base_url = self.requester_cfg.get('base_url', '').rstrip('/')
|
||||
timeout = self.requester_cfg.get('timeout', 120)
|
||||
|
||||
merged_args = {}
|
||||
if model.model_entity.extra_args:
|
||||
merged_args.update(model.model_entity.extra_args)
|
||||
if extra_args:
|
||||
merged_args.update(extra_args)
|
||||
|
||||
rerank_url = merged_args.pop('rerank_url', None)
|
||||
rerank_path = merged_args.pop('rerank_path', 'rerank')
|
||||
if not rerank_url:
|
||||
rerank_url = f'{base_url}/{rerank_path}'
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {api_key}',
|
||||
}
|
||||
|
||||
payload = {
|
||||
'model': model.model_entity.name,
|
||||
'query': query,
|
||||
'documents': documents[:64],
|
||||
'top_n': min(len(documents), 64),
|
||||
}
|
||||
|
||||
if merged_args:
|
||||
payload.update(merged_args)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(trust_env=True, timeout=timeout) as client:
|
||||
resp = await client.post(rerank_url, headers=headers, json=payload)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
results = self._parse_rerank_response(data)
|
||||
|
||||
if results:
|
||||
scores = [r.get('relevance_score', 0.0) for r in results]
|
||||
min_score = min(scores)
|
||||
max_score = max(scores)
|
||||
if max_score - min_score > 1e-6:
|
||||
for r in results:
|
||||
r['relevance_score'] = (r['relevance_score'] - min_score) / (max_score - min_score)
|
||||
|
||||
return results
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise errors.RequesterError(f'Rerank request failed: {e.response.status_code} - {e.response.text}')
|
||||
except httpx.TimeoutException:
|
||||
raise errors.RequesterError('Rerank request timed out')
|
||||
except Exception as e:
|
||||
raise errors.RequesterError(f'Rerank request error: {str(e)}')
|
||||
|
||||
@staticmethod
|
||||
def _parse_rerank_response(data: dict) -> typing.List[dict]:
|
||||
"""Parse rerank response from various providers.
|
||||
|
||||
Handles:
|
||||
- Jina/Cohere/SiliconFlow: {"results": [{"index", "relevance_score"}]}
|
||||
- Voyage AI: {"data": [{"index", "relevance_score"}]}
|
||||
- DashScope: {"output": {"results": [{"index", "relevance_score"}]}}
|
||||
"""
|
||||
if 'results' in data:
|
||||
return data['results']
|
||||
if 'data' in data:
|
||||
return data['data']
|
||||
if 'output' in data and isinstance(data['output'], dict):
|
||||
return data['output'].get('results', [])
|
||||
return []
|
||||
|
||||
@@ -25,7 +25,6 @@ spec:
|
||||
support_type:
|
||||
- llm
|
||||
- text-embedding
|
||||
- rerank
|
||||
provider_category: manufacturer
|
||||
execution:
|
||||
python:
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 128 128" id="Chroma--Streamline-Svg-Logos" height="128" width="128">
|
||||
<desc>
|
||||
Chroma Streamline Icon: https://streamlinehq.com
|
||||
</desc>
|
||||
<path fill="#ffde2d" d="M84.88839999999999 104.10666666666665c23.0732 0 41.77773333333333 -17.956266666666664 41.77773333333333 -40.10653333333333 0 -22.150266666666667 -18.70453333333333 -40.10653333333333 -41.77773333333333 -40.10653333333333 -23.0732 0 -41.77773333333333 17.956266666666664 -41.77773333333333 40.10653333333333 0 22.150266666666667 18.70453333333333 40.10653333333333 41.77773333333333 40.10653333333333Z" stroke-width="1.3333"></path>
|
||||
<path fill="#327eff" d="M43.111066666666666 104.10666666666665c23.0732 0 41.77773333333333 -17.956266666666664 41.77773333333333 -40.10653333333333 0 -22.150266666666667 -18.70453333333333 -40.10653333333333 -41.77773333333333 -40.10653333333333C20.037866666666666 23.8936 1.3333333333333333 41.849866666666664 1.3333333333333333 64.00013333333334 1.3333333333333333 86.15039999999999 20.037866666666666 104.10666666666665 43.111066666666666 104.10666666666665Z" stroke-width="1.3333"></path>
|
||||
<path fill="#ff6446" d="M84.88866666666667 64.00013333333334c0 22.150399999999998 -18.704666666666665 40.10626666666666 -41.778 40.10626666666666V64.00013333333334h41.778Zm-41.778 0c0 -22.150266666666667 18.70453333333333 -40.10653333333333 41.778 -40.10653333333333v40.10653333333333H43.11066666666666Z" stroke-width="1.3333"></path>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 1.5 KiB |
@@ -1,61 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import requester
|
||||
|
||||
REQUESTER_NAME: str = 'chroma-embedding'
|
||||
|
||||
|
||||
class ChromaEmbedding(requester.ProviderAPIRequester):
|
||||
"""Chroma built-in embedding requester.
|
||||
|
||||
Uses chromadb's DefaultEmbeddingFunction (all-MiniLM-L6-v2).
|
||||
The embedding function runs locally using ONNX Runtime.
|
||||
"""
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base_url': '',
|
||||
}
|
||||
|
||||
_embedding_function = None
|
||||
|
||||
async def initialize(self):
|
||||
try:
|
||||
from chromadb.utils import embedding_functions
|
||||
except ImportError:
|
||||
raise ImportError('chromadb is not installed. Install it with: pip install chromadb')
|
||||
|
||||
self._embedding_function = embedding_functions.DefaultEmbeddingFunction()
|
||||
|
||||
async def invoke_llm(
|
||||
self,
|
||||
query,
|
||||
model: requester.RuntimeLLMModel,
|
||||
messages: typing.List,
|
||||
funcs: typing.List = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
):
|
||||
raise NotImplementedError('Chroma embedding does not support LLM inference')
|
||||
|
||||
async def invoke_embedding(
|
||||
self,
|
||||
model: requester.RuntimeEmbeddingModel,
|
||||
input_text: typing.List[str],
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
) -> typing.List[typing.List[float]]:
|
||||
"""Generate embeddings using Chroma's DefaultEmbeddingFunction."""
|
||||
if self._embedding_function is None:
|
||||
await self.initialize()
|
||||
|
||||
try:
|
||||
result = self._embedding_function(input_text)
|
||||
# DefaultEmbeddingFunction returns list of ndarray, convert for JSON
|
||||
if isinstance(result, list):
|
||||
return [item.tolist() if hasattr(item, 'tolist') else item for item in result]
|
||||
return result.tolist() if hasattr(result, 'tolist') else result
|
||||
except Exception as e:
|
||||
from .. import errors
|
||||
|
||||
raise errors.RequesterError(f'Chroma embedding failed: {str(e)}')
|
||||
@@ -1,21 +0,0 @@
|
||||
apiVersion: v1
|
||||
kind: LLMAPIRequester
|
||||
metadata:
|
||||
name: chroma-embedding
|
||||
label:
|
||||
en_US: Chroma Embedding
|
||||
zh_Hans: Chroma 嵌入
|
||||
description:
|
||||
en_US: Chroma built-in embedding model (all-MiniLM-L6-v2), runs locally using ONNX Runtime. First-time use will download model files automatically.
|
||||
zh_Hans: 使用 Chroma 内置嵌入模型 (all-MiniLM-L6-v2),基于 ONNX Runtime 本地运行。首次使用时将自动下载模型文件。
|
||||
ja_JP: Chroma 組み込み埋め込みモデル (all-MiniLM-L6-v2) を使用します。ONNX Runtime でローカル実行。初回使用時にモデルファイルが自動ダウンロードされます。
|
||||
icon: chroma.svg
|
||||
spec:
|
||||
config: []
|
||||
support_type:
|
||||
- text-embedding
|
||||
provider_category: builtin
|
||||
execution:
|
||||
python:
|
||||
path: ./chromaembed.py
|
||||
attr: ChromaEmbedding
|
||||
@@ -1 +0,0 @@
|
||||
<svg height="1em" style="flex:none;line-height:1" viewBox="0 0 24 24" width="1em" xmlns="http://www.w3.org/2000/svg"><title>Cohere</title><path clip-rule="evenodd" d="M8.128 14.099c.592 0 1.77-.033 3.398-.703 1.897-.781 5.672-2.2 8.395-3.656 1.905-1.018 2.74-2.366 2.74-4.18A4.56 4.56 0 0018.1 1H7.549A6.55 6.55 0 001 7.55c0 3.617 2.745 6.549 7.128 6.549z" fill="#39594D" fill-rule="evenodd"></path><path clip-rule="evenodd" d="M9.912 18.61a4.387 4.387 0 012.705-4.052l3.323-1.38c3.361-1.394 7.06 1.076 7.06 4.715a5.104 5.104 0 01-5.105 5.104l-3.597-.001a4.386 4.386 0 01-4.386-4.387z" fill="#D18EE2" fill-rule="evenodd"></path><path d="M4.776 14.962A3.775 3.775 0 001 18.738v.489a3.776 3.776 0 007.551 0v-.49a3.775 3.775 0 00-3.775-3.775z" fill="#FF7759"></path></svg>
|
||||
|
Before Width: | Height: | Size: 769 B |
@@ -1,31 +0,0 @@
|
||||
apiVersion: v1
|
||||
kind: LLMAPIRequester
|
||||
metadata:
|
||||
name: cohere-rerank
|
||||
label:
|
||||
en_US: Cohere
|
||||
zh_Hans: Cohere
|
||||
icon: cohere.svg
|
||||
spec:
|
||||
config:
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_Hans: 基础 URL
|
||||
type: string
|
||||
required: true
|
||||
default: https://api.cohere.com/v2
|
||||
- name: timeout
|
||||
label:
|
||||
en_US: Timeout
|
||||
zh_Hans: 超时时间
|
||||
type: integer
|
||||
required: true
|
||||
default: 120
|
||||
support_type:
|
||||
- rerank
|
||||
provider_category: manufacturer
|
||||
execution:
|
||||
python:
|
||||
path: ./chatcmpl.py
|
||||
attr: OpenAIChatCompletions
|
||||
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import httpx
|
||||
|
||||
from . import chatcmpl
|
||||
|
||||
@@ -21,68 +20,6 @@ class GeminiChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]:
|
||||
models_url = 'https://generativelanguage.googleapis.com/v1beta/models'
|
||||
params = {'key': api_key} if api_key else {}
|
||||
|
||||
all_models: list[dict[str, typing.Any]] = []
|
||||
next_page_token = ''
|
||||
last_payload: dict[str, typing.Any] = {}
|
||||
|
||||
async with httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']) as client:
|
||||
while True:
|
||||
request_params = dict(params)
|
||||
if next_page_token:
|
||||
request_params['pageToken'] = next_page_token
|
||||
|
||||
response = await client.get(models_url, params=request_params)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
last_payload = payload
|
||||
|
||||
for item in payload.get('models', []):
|
||||
model_name = item.get('name', '')
|
||||
model_id = model_name.replace('models/', '', 1)
|
||||
if not model_id:
|
||||
continue
|
||||
|
||||
supported_methods = item.get('supportedGenerationMethods', []) or []
|
||||
if 'embedContent' in supported_methods and 'generateContent' not in supported_methods:
|
||||
model_type = 'embedding'
|
||||
else:
|
||||
model_type = 'llm'
|
||||
|
||||
all_models.append(
|
||||
{
|
||||
'id': model_id,
|
||||
'name': model_id,
|
||||
'type': model_type,
|
||||
'abilities': self._infer_model_abilities(item, model_id),
|
||||
'display_name': item.get('displayName') or None,
|
||||
'description': item.get('description') or None,
|
||||
'context_length': item.get('inputTokenLimit'),
|
||||
'input_modalities': self._normalize_modalities(item.get('inputModalities')),
|
||||
'output_modalities': self._normalize_modalities(item.get('outputModalities')),
|
||||
}
|
||||
)
|
||||
|
||||
next_page_token = payload.get('nextPageToken', '')
|
||||
if not next_page_token:
|
||||
break
|
||||
|
||||
all_models.sort(key=lambda item: (item['type'] != 'llm', item['name'].lower()))
|
||||
return {
|
||||
'models': all_models,
|
||||
'debug': {
|
||||
'request': {
|
||||
'method': 'GET',
|
||||
'url': models_url,
|
||||
'query': {'key': self._mask_api_key(api_key)} if api_key else {},
|
||||
},
|
||||
'response': last_payload,
|
||||
},
|
||||
}
|
||||
|
||||
async def _closure_stream(
|
||||
self,
|
||||
query: pipeline_query.Query,
|
||||
|
||||
@@ -25,7 +25,6 @@ spec:
|
||||
support_type:
|
||||
- llm
|
||||
- text-embedding
|
||||
- rerank
|
||||
provider_category: maas
|
||||
execution:
|
||||
python:
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
<svg fill="currentColor" fill-rule="evenodd" height="1em" style="flex:none;line-height:1" viewBox="0 0 24 24" width="1em" xmlns="http://www.w3.org/2000/svg"><title>Jina</title><path d="M6.608 21.416a4.608 4.608 0 100-9.217 4.608 4.608 0 000 9.217zM20.894 2.015c.614 0 1.106.492 1.106 1.106v9.002c0 5.13-4.148 9.309-9.217 9.37v-9.355l-.03-9.032c0-.614.491-1.106 1.106-1.106h7.158l-.123.015z"></path></svg>
|
||||
|
Before Width: | Height: | Size: 404 B |
@@ -1,31 +0,0 @@
|
||||
apiVersion: v1
|
||||
kind: LLMAPIRequester
|
||||
metadata:
|
||||
name: jina-rerank
|
||||
label:
|
||||
en_US: Jina
|
||||
zh_Hans: Jina
|
||||
icon: jina.svg
|
||||
spec:
|
||||
config:
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_Hans: 基础 URL
|
||||
type: string
|
||||
required: true
|
||||
default: https://api.jina.ai/v1
|
||||
- name: timeout
|
||||
label:
|
||||
en_US: Timeout
|
||||
zh_Hans: 超时时间
|
||||
type: integer
|
||||
required: true
|
||||
default: 120
|
||||
support_type:
|
||||
- rerank
|
||||
provider_category: manufacturer
|
||||
execution:
|
||||
python:
|
||||
path: ./chatcmpl.py
|
||||
attr: OpenAIChatCompletions
|
||||
@@ -25,181 +25,12 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester):
|
||||
|
||||
async def initialize(self):
|
||||
self.client = openai.AsyncClient(
|
||||
api_key=self.init_api_key,
|
||||
api_key='',
|
||||
base_url=self.requester_cfg['base_url'],
|
||||
timeout=self.requester_cfg['timeout'],
|
||||
http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']),
|
||||
)
|
||||
|
||||
def _mask_api_key(self, api_key: str | None) -> str:
|
||||
if not api_key:
|
||||
return ''
|
||||
if len(api_key) <= 8:
|
||||
return '****'
|
||||
return f'{api_key[:4]}...{api_key[-4:]}'
|
||||
|
||||
def _infer_model_type(self, model_id: str) -> str:
|
||||
normalized_model_id = (model_id or '').lower()
|
||||
embedding_keywords = (
|
||||
'embedding',
|
||||
'embed',
|
||||
'bge-',
|
||||
'e5-',
|
||||
'm3e',
|
||||
'gte-',
|
||||
'multilingual-e5',
|
||||
'text-embedding',
|
||||
)
|
||||
return 'embedding' if any(keyword in normalized_model_id for keyword in embedding_keywords) else 'llm'
|
||||
|
||||
def _infer_model_abilities(self, item: dict[str, typing.Any], model_id: str) -> list[str]:
|
||||
normalized_model_id = (model_id or '').lower()
|
||||
abilities: set[str] = set()
|
||||
|
||||
def _flatten(value: typing.Any) -> list[str]:
|
||||
if value is None:
|
||||
return []
|
||||
if isinstance(value, str):
|
||||
return [value.lower()]
|
||||
if isinstance(value, dict):
|
||||
flattened: list[str] = []
|
||||
for nested_value in value.values():
|
||||
flattened.extend(_flatten(nested_value))
|
||||
return flattened
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
flattened: list[str] = []
|
||||
for nested_value in value:
|
||||
flattened.extend(_flatten(nested_value))
|
||||
return flattened
|
||||
return [str(value).lower()]
|
||||
|
||||
capability_tokens = _flatten(item.get('capabilities'))
|
||||
capability_tokens.extend(_flatten(item.get('modalities')))
|
||||
capability_tokens.extend(_flatten(item.get('input_modalities')))
|
||||
capability_tokens.extend(_flatten(item.get('output_modalities')))
|
||||
capability_tokens.extend(_flatten(item.get('supported_generation_methods')))
|
||||
capability_tokens.extend(_flatten(item.get('supported_parameters')))
|
||||
capability_tokens.extend(_flatten(item.get('architecture')))
|
||||
|
||||
combined_tokens = capability_tokens + [normalized_model_id]
|
||||
|
||||
vision_keywords = ('vision', 'image', 'file', 'video', 'multimodal', 'vl', 'ocr', 'omni')
|
||||
function_call_keywords = ('function', 'tool', 'tools', 'tool_choice', 'tool_call', 'tool-use', 'tool_use')
|
||||
|
||||
if any(any(keyword in token for keyword in vision_keywords) for token in combined_tokens):
|
||||
abilities.add('vision')
|
||||
|
||||
if any(any(keyword in token for keyword in function_call_keywords) for token in combined_tokens):
|
||||
abilities.add('func_call')
|
||||
|
||||
return sorted(abilities)
|
||||
|
||||
def _normalize_modalities(self, value: typing.Any) -> list[str]:
|
||||
normalized: list[str] = []
|
||||
|
||||
def _collect(item: typing.Any):
|
||||
if item is None:
|
||||
return
|
||||
if isinstance(item, str):
|
||||
for part in item.replace('->', ',').replace('+', ',').split(','):
|
||||
token = part.strip().lower()
|
||||
if token and token not in normalized:
|
||||
normalized.append(token)
|
||||
return
|
||||
if isinstance(item, dict):
|
||||
for nested in item.values():
|
||||
_collect(nested)
|
||||
return
|
||||
if isinstance(item, (list, tuple, set)):
|
||||
for nested in item:
|
||||
_collect(nested)
|
||||
return
|
||||
|
||||
_collect(value)
|
||||
return normalized
|
||||
|
||||
def _extract_scan_metadata(self, item: dict[str, typing.Any], model_id: str) -> dict[str, typing.Any]:
|
||||
display_name = item.get('name')
|
||||
if not isinstance(display_name, str) or not display_name.strip() or display_name == model_id:
|
||||
display_name = ''
|
||||
|
||||
description = item.get('description')
|
||||
if not isinstance(description, str) or not description.strip():
|
||||
description = ''
|
||||
|
||||
context_length = item.get('context_length')
|
||||
if context_length is None and isinstance(item.get('top_provider'), dict):
|
||||
context_length = item['top_provider'].get('context_length')
|
||||
|
||||
if not isinstance(context_length, int):
|
||||
try:
|
||||
context_length = int(context_length) if context_length is not None else None
|
||||
except (TypeError, ValueError):
|
||||
context_length = None
|
||||
|
||||
input_modalities = self._normalize_modalities(item.get('input_modalities'))
|
||||
output_modalities = self._normalize_modalities(item.get('output_modalities'))
|
||||
|
||||
if isinstance(item.get('architecture'), dict):
|
||||
if not input_modalities:
|
||||
input_modalities = self._normalize_modalities(item['architecture'].get('input_modalities'))
|
||||
if not output_modalities:
|
||||
output_modalities = self._normalize_modalities(item['architecture'].get('output_modalities'))
|
||||
|
||||
owned_by = item.get('owned_by')
|
||||
if not isinstance(owned_by, str) or not owned_by.strip():
|
||||
owned_by = ''
|
||||
|
||||
return {
|
||||
'display_name': display_name or None,
|
||||
'description': description or None,
|
||||
'context_length': context_length,
|
||||
'owned_by': owned_by or None,
|
||||
'input_modalities': input_modalities,
|
||||
'output_modalities': output_modalities,
|
||||
}
|
||||
|
||||
async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]:
|
||||
headers = {}
|
||||
if api_key:
|
||||
headers['Authorization'] = f'Bearer {api_key}'
|
||||
|
||||
models_url = f'{self.requester_cfg["base_url"].rstrip("/")}/models'
|
||||
async with httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']) as client:
|
||||
response = await client.get(models_url, headers=headers)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
|
||||
models = []
|
||||
for item in payload.get('data', []):
|
||||
model_id = item.get('id')
|
||||
if not model_id:
|
||||
continue
|
||||
models.append(
|
||||
{
|
||||
'id': model_id,
|
||||
'name': model_id,
|
||||
'type': self._infer_model_type(model_id),
|
||||
'abilities': self._infer_model_abilities(item, model_id),
|
||||
**self._extract_scan_metadata(item, model_id),
|
||||
}
|
||||
)
|
||||
|
||||
models.sort(key=lambda item: (item['type'] != 'llm', item['name'].lower()))
|
||||
return {
|
||||
'models': models,
|
||||
'debug': {
|
||||
'request': {
|
||||
'method': 'GET',
|
||||
'url': models_url,
|
||||
'headers': {
|
||||
'Authorization': f'Bearer {self._mask_api_key(api_key)}' if api_key else '',
|
||||
},
|
||||
},
|
||||
'response': payload,
|
||||
},
|
||||
}
|
||||
|
||||
async def _req(
|
||||
self,
|
||||
query: pipeline_query.Query,
|
||||
|
||||
@@ -8,7 +8,6 @@ import uuid
|
||||
import json
|
||||
|
||||
import ollama
|
||||
import httpx
|
||||
|
||||
from .. import errors, requester
|
||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||
@@ -32,60 +31,6 @@ class OllamaChatCompletions(requester.ProviderAPIRequester):
|
||||
os.environ['OLLAMA_HOST'] = self.requester_cfg['base_url']
|
||||
self.client = ollama.AsyncClient(timeout=self.requester_cfg['timeout'])
|
||||
|
||||
def _infer_model_type(self, model_id: str) -> str:
|
||||
normalized_model_id = (model_id or '').lower()
|
||||
embedding_keywords = ('embedding', 'embed', 'bge-', 'e5-', 'm3e', 'gte-', 'text-embedding')
|
||||
return 'embedding' if any(keyword in normalized_model_id for keyword in embedding_keywords) else 'llm'
|
||||
|
||||
def _infer_model_abilities(self, item: dict[str, typing.Any], model_id: str) -> list[str]:
|
||||
normalized_model_id = (model_id or '').lower()
|
||||
abilities: set[str] = set()
|
||||
details = item.get('details', {}) or {}
|
||||
families = details.get('families', []) or []
|
||||
tokens = [normalized_model_id, str(details.get('family', '')).lower()]
|
||||
tokens.extend(str(family).lower() for family in families)
|
||||
|
||||
if any(keyword in token for token in tokens for keyword in ('vision', 'vl', 'omni', 'llava', 'ocr')):
|
||||
abilities.add('vision')
|
||||
if any(keyword in token for token in tokens for keyword in ('tool', 'function')):
|
||||
abilities.add('func_call')
|
||||
return sorted(abilities)
|
||||
|
||||
async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]:
|
||||
del api_key
|
||||
models_url = f'{self.requester_cfg["base_url"].rstrip("/")}/api/tags'
|
||||
|
||||
async with httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']) as client:
|
||||
response = await client.get(models_url)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
|
||||
models: list[dict[str, typing.Any]] = []
|
||||
for item in payload.get('models', []):
|
||||
model_id = item.get('model') or item.get('name')
|
||||
if not model_id:
|
||||
continue
|
||||
models.append(
|
||||
{
|
||||
'id': model_id,
|
||||
'name': item.get('name', model_id),
|
||||
'type': self._infer_model_type(model_id),
|
||||
'abilities': self._infer_model_abilities(item, model_id),
|
||||
}
|
||||
)
|
||||
|
||||
models.sort(key=lambda item: (item['type'] != 'llm', item['name'].lower()))
|
||||
return {
|
||||
'models': models,
|
||||
'debug': {
|
||||
'request': {
|
||||
'method': 'GET',
|
||||
'url': models_url,
|
||||
},
|
||||
'response': payload,
|
||||
},
|
||||
}
|
||||
|
||||
async def _req(
|
||||
self,
|
||||
args: dict,
|
||||
@@ -159,21 +104,6 @@ class OllamaChatCompletions(requester.ProviderAPIRequester):
|
||||
|
||||
return ret_msg
|
||||
|
||||
async def _prepare_messages(
|
||||
self,
|
||||
messages: typing.List[provider_message.Message],
|
||||
) -> list[dict]:
|
||||
"""Prepare messages for Ollama API request."""
|
||||
req_messages: list = []
|
||||
for m in messages:
|
||||
msg_dict: dict = m.dict(exclude_none=True)
|
||||
content: Any = msg_dict.get('content')
|
||||
if isinstance(content, list):
|
||||
if all(isinstance(part, dict) and part.get('type') == 'text' for part in content):
|
||||
msg_dict['content'] = '\n'.join(part['text'] for part in content)
|
||||
req_messages.append(msg_dict)
|
||||
return req_messages
|
||||
|
||||
async def invoke_llm(
|
||||
self,
|
||||
query: pipeline_query.Query,
|
||||
@@ -183,7 +113,14 @@ class OllamaChatCompletions(requester.ProviderAPIRequester):
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> provider_message.Message:
|
||||
req_messages = await self._prepare_messages(messages)
|
||||
req_messages: list = []
|
||||
for m in messages:
|
||||
msg_dict: dict = m.dict(exclude_none=True)
|
||||
content: Any = msg_dict.get('content')
|
||||
if isinstance(content, list):
|
||||
if all(isinstance(part, dict) and part.get('type') == 'text' for part in content):
|
||||
msg_dict['content'] = '\n'.join(part['text'] for part in content)
|
||||
req_messages.append(msg_dict)
|
||||
try:
|
||||
return await self._closure(
|
||||
query=query,
|
||||
@@ -196,109 +133,6 @@ class OllamaChatCompletions(requester.ProviderAPIRequester):
|
||||
except asyncio.TimeoutError:
|
||||
raise errors.RequesterError('请求超时')
|
||||
|
||||
async def invoke_llm_stream(
|
||||
self,
|
||||
query: pipeline_query.Query,
|
||||
model: requester.RuntimeLLMModel,
|
||||
messages: typing.List[provider_message.Message],
|
||||
funcs: typing.List[resource_tool.LLMTool] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> provider_message.MessageChunk:
|
||||
req_messages = await self._prepare_messages(messages)
|
||||
|
||||
try:
|
||||
args = extra_args.copy()
|
||||
args['model'] = model.model_entity.name
|
||||
|
||||
# Process messages for Ollama format
|
||||
msgs: list[dict] = req_messages.copy()
|
||||
for msg in msgs:
|
||||
if 'content' in msg and isinstance(msg['content'], list):
|
||||
text_content: list = []
|
||||
image_urls: list = []
|
||||
for me in msg['content']:
|
||||
if me['type'] == 'text':
|
||||
text_content.append(me['text'])
|
||||
elif me['type'] == 'image_base64':
|
||||
image_urls.append(me['image_base64'])
|
||||
msg['content'] = '\n'.join(text_content)
|
||||
msg['images'] = [url.split(',')[1] for url in image_urls]
|
||||
if 'tool_calls' in msg:
|
||||
for tool_call in msg['tool_calls']:
|
||||
tool_call['function']['arguments'] = json.loads(tool_call['function']['arguments'])
|
||||
args['messages'] = msgs
|
||||
|
||||
args['tools'] = []
|
||||
if funcs:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs)
|
||||
if tools:
|
||||
args['tools'] = tools
|
||||
|
||||
args['stream'] = True
|
||||
|
||||
chunk_idx = 0
|
||||
thinking_started = False
|
||||
thinking_ended = False
|
||||
role = 'assistant'
|
||||
|
||||
async for chunk in await self.client.chat(**args):
|
||||
message: ollama.Message = chunk.message
|
||||
done = chunk.done
|
||||
|
||||
delta_content = message.content or ''
|
||||
reasoning_content = getattr(message, 'thinking', '') or ''
|
||||
|
||||
# Handle reasoning/thinking content
|
||||
if reasoning_content:
|
||||
if remove_think:
|
||||
chunk_idx += 1
|
||||
continue
|
||||
|
||||
if not thinking_started:
|
||||
thinking_started = True
|
||||
delta_content = '<think>\n' + reasoning_content
|
||||
else:
|
||||
delta_content = reasoning_content
|
||||
elif thinking_started and not thinking_ended and delta_content:
|
||||
thinking_ended = True
|
||||
delta_content = '\n</think>\n' + delta_content
|
||||
|
||||
# Handle tool calls
|
||||
tool_calls_data = None
|
||||
if message.tool_calls:
|
||||
tool_calls_data = []
|
||||
for tc in message.tool_calls:
|
||||
tool_calls_data.append(
|
||||
{
|
||||
'id': uuid.uuid4().hex,
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': tc.function.name,
|
||||
'arguments': json.dumps(tc.function.arguments),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Skip empty first chunk
|
||||
if chunk_idx == 0 and not delta_content and not reasoning_content and not tool_calls_data:
|
||||
chunk_idx += 1
|
||||
continue
|
||||
|
||||
chunk_data = {
|
||||
'role': role,
|
||||
'content': delta_content if delta_content else None,
|
||||
'tool_calls': tool_calls_data,
|
||||
'is_final': bool(done),
|
||||
}
|
||||
chunk_data = {k: v for k, v in chunk_data.items() if v is not None}
|
||||
|
||||
yield provider_message.MessageChunk(**chunk_data)
|
||||
chunk_idx += 1
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise errors.RequesterError('请求超时')
|
||||
|
||||
async def invoke_embedding(
|
||||
self,
|
||||
model: requester.RuntimeEmbeddingModel,
|
||||
|
||||
@@ -15,11 +15,3 @@ class OpenRouterChatCompletions(modelscopechatcmpl.ModelScopeChatCompletions):
|
||||
'base_url': 'https://openrouter.ai/api/v1',
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]:
|
||||
original_base_url = self.requester_cfg.get('base_url', '')
|
||||
self.requester_cfg['base_url'] = 'https://openrouter.ai/api/v1'
|
||||
try:
|
||||
return await super().scan_models(api_key)
|
||||
finally:
|
||||
self.requester_cfg['base_url'] = original_base_url
|
||||
|
||||
@@ -25,7 +25,6 @@ spec:
|
||||
support_type:
|
||||
- llm
|
||||
- text-embedding
|
||||
- rerank
|
||||
provider_category: maas
|
||||
execution:
|
||||
python:
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
<svg height="1em" style="flex:none;line-height:1" viewBox="0 0 24 24" width="1em" xmlns="http://www.w3.org/2000/svg"><title>Qiniu</title><path d="M23.111 4.6a.914.914 0 00-.861.161A13.443 13.443 0 017.947 8.897L7.38 6.831a1.076 1.076 0 00-1.211-.698l.27 2.18c-1.816-.827-2.313-.946-3.587-2.45C2.674 5.729 1.263 4.472.89 4.6a11.906 11.906 0 005.892 6.497l.738 5.97s.33 2.286 2.473 2.286h4.586c2.144 0 2.474-2.286 2.474-2.286l.518-4.28c-1.393-.11-2.268.857-2.546 1.814-.465 1.614-.465 1.716-.557 1.998-.188.575-.806.644-.806.644h-2.753s-.617-.07-.806-.644c-.12-.371-.727-2.54-1.335-4.74A11.877 11.877 0 0023.11 4.599V4.6z" fill="#06AEEF"></path></svg>
|
||||
|
Before Width: | Height: | Size: 649 B |
@@ -1,45 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
import openai
|
||||
|
||||
from . import chatcmpl
|
||||
|
||||
|
||||
class QiniuChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
"""七牛云 ChatCompletion API 请求器"""
|
||||
|
||||
client: openai.AsyncClient
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base_url': 'https://api.qnaigc.com/v1',
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]:
|
||||
try:
|
||||
result = await super().scan_models(api_key)
|
||||
except Exception:
|
||||
return self._qiniu_fallback_scan_result()
|
||||
models = result.get('models') or []
|
||||
if not models:
|
||||
return self._qiniu_fallback_scan_result()
|
||||
return result
|
||||
|
||||
def _qiniu_fallback_scan_result(self) -> dict[str, typing.Any]:
|
||||
mid = 'deepseek-v3'
|
||||
return {
|
||||
'models': [
|
||||
{
|
||||
'id': mid,
|
||||
'name': mid,
|
||||
'type': 'llm',
|
||||
'abilities': [],
|
||||
}
|
||||
],
|
||||
'debug': {
|
||||
'request': {'method': 'GET', 'url': '', 'headers': {}},
|
||||
'response': {},
|
||||
},
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
apiVersion: v1
|
||||
kind: LLMAPIRequester
|
||||
metadata:
|
||||
name: qiniu-chat-completions
|
||||
label:
|
||||
en_US: Qiniu
|
||||
zh_Hans: 七牛云
|
||||
icon: qiniu.svg
|
||||
spec:
|
||||
config:
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_Hans: 基础 URL
|
||||
type: string
|
||||
required: true
|
||||
default: https://api.qnaigc.com/v1
|
||||
- name: timeout
|
||||
label:
|
||||
en_US: Timeout
|
||||
zh_Hans: 超时时间
|
||||
type: integer
|
||||
required: true
|
||||
default: 120
|
||||
support_type:
|
||||
- llm
|
||||
provider_category: maas
|
||||
execution:
|
||||
python:
|
||||
path: ./qiniuchatcmpl.py
|
||||
attr: QiniuChatCompletions
|
||||
@@ -1,17 +1,8 @@
|
||||
<svg id="_图层_1" data-name="图层 1" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 334.84 76.22">
|
||||
<defs>
|
||||
<style>
|
||||
.cls-1 {
|
||||
fill: currentColor;
|
||||
}
|
||||
</style>
|
||||
</defs>
|
||||
<path class="cls-1" d="M308.56,23.63c-5.04,0-9.73,1.43-13.73,3.88V1.08l-12.56,4.61v70h12.56v-3.35c4,2.46,8.71,3.88,13.73,3.88,14.49,0,26.29-11.79,26.29-26.29s-11.79-26.29-26.29-26.29h0ZM308.56,63.88c-6.87,0-12.57-4.98-13.73-11.51v-4.91c1.16-6.54,6.88-11.51,13.73-11.51,7.7,0,13.96,6.26,13.96,13.96s-6.26,13.96-13.96,13.96Z"></path>
|
||||
<path class="cls-1" d="M255.54,5.69v21.83c-4-2.46-8.71-3.88-13.73-3.88-14.49,0-26.29,11.79-26.29,26.29s11.79,26.29,26.29,26.29c5.04,0,9.73-1.43,13.73-3.88v3.35h12.56V1.08l-12.56,4.61ZM241.81,63.88c-7.7,0-13.96-6.26-13.96-13.96s6.26-13.96,13.96-13.96c6.87,0,12.57,4.98,13.73,11.51v4.91c-1.16,6.54-6.88,11.51-13.73,11.51Z"></path>
|
||||
<polygon class="cls-1" points="195.35 52.2 186.65 61.17 200.64 75.62 209.32 75.62 218.01 75.62 195.35 52.2"></polygon>
|
||||
<path class="cls-1" d="M167.14,4.59c.65,3.99.68,8.04.03,12.15-.03.17.16.3.31.21,3.82-2.21,7.82-3.69,12.01-4.33.12-.02.19-.13.17-.23-.68-4.13-.61-8.18-.03-12.16.02-.17-.16-.3-.31-.2-4.01,2.31-8.01,3.81-12.01,4.34-.12.01-.19.12-.17.23h0Z"></path>
|
||||
<path class="cls-1" d="M198.75,24.09l-19.07,19.72v-25.57c-4.49.67-8.7,2.11-12.56,4.57v52.83h12.56v-13.87l3.78-3.9.02.02,8.68-8.97-.02-.02,23.98-24.8h-17.37Z"></path>
|
||||
<path class="cls-1" d="M145.03,57.86c-2.56,4.45-7.17,7.2-12.13,7.2-5.96,0-11.3-3.96-13.32-9.85h38.87l.08-.42c.29-1.5.42-3.06.42-4.65,0-14.37-11.69-26.06-26.06-26.06s-26.06,11.69-26.06,26.06,11.69,26.06,26.06,26.06c9.63,0,18.43-5.28,22.98-13.77l.26-.49-11.1-4.08h-.01ZM132.88,35.19h.03c5.96,0,11.3,3.96,13.32,9.85h-26.67c2.02-5.89,7.36-9.85,13.32-9.85Z"></path>
|
||||
<path class="cls-1" d="M75.92,65.07c-5.96,0-11.29-3.96-13.32-9.85h38.87l.08-.42c.29-1.5.42-3.06.42-4.65,0-14.37-11.69-26.06-26.06-26.06s-26.06,11.69-26.06,26.06,11.69,26.06,26.06,26.06c9.63,0,18.43-5.28,22.98-13.77l.26-.49h0l-11.1-4.08c-2.56,4.45-7.17,7.2-12.13,7.2h-.01ZM75.92,35.19h.03c5.96,0,11.29,3.96,13.32,9.85h-26.67c2.03-5.89,7.36-9.85,13.32-9.85Z"></path>
|
||||
<path class="cls-1" d="M30.43,45.58l-10.2-1.91c-3.03-.56-4.98-2.25-4.98-4.33,0-1.5,1.61-4.35,7.68-4.35,5.53,0,9.36,3.5,10.25,6.26l10.9-4-.14-.42c-1.17-3.54-3.5-6.58-6.94-9.04-3.49-2.49-8.04-3.69-13.88-3.69s-10.98,1.5-14.78,4.34c-3.88,2.91-5.84,6.76-5.84,11.46,0,7.98,4.72,12.77,14.42,14.64l9.9,1.81c3.05.61,4.94,2.27,4.94,4.33,0,2.61-3.58,4.44-8.7,4.44-5.79,0-9.9-3.72-11.85-7.14L0,62.1l.14.39c1.3,3.8,3.89,7.07,7.7,9.71,3.78,2.6,8.65,3.95,14.51,3.98l.25.03c6.87,0,12.55-1.57,16.43-4.53,3.98-3.05,6-6.99,6-11.74,0-3.73-1.14-6.7-3.6-9.33-2.27-2.42-5.98-4.11-10.98-5.02h-.02Z"></path>
|
||||
</svg>
|
||||
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<rect width="24" height="24" rx="5" fill="#1E3A5F"/>
|
||||
<path d="M6 12C6 8.68629 8.68629 6 12 6C15.3137 6 18 8.68629 18 12" stroke="#4FC3F7" stroke-width="2" stroke-linecap="round"/>
|
||||
<path d="M18 12C18 15.3137 15.3137 18 12 18C8.68629 18 6 15.3137 6 12" stroke="#81D4FA" stroke-width="2" stroke-linecap="round"/>
|
||||
<circle cx="12" cy="12" r="2" fill="#4FC3F7"/>
|
||||
<circle cx="6" cy="12" r="1.5" fill="#81D4FA"/>
|
||||
<circle cx="18" cy="12" r="1.5" fill="#4FC3F7"/>
|
||||
</svg>
|
||||
|
||||
|
Before Width: | Height: | Size: 2.7 KiB After Width: | Height: | Size: 569 B |
@@ -46,15 +46,14 @@ class SeekDBEmbedding(requester.ProviderAPIRequester):
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
) -> typing.List[typing.List[float]]:
|
||||
"""Generate embeddings using SeekDB's built-in embedding function."""
|
||||
if self._embedding_function is None:
|
||||
await self.initialize()
|
||||
|
||||
try:
|
||||
result = self._embedding_function(input_text)
|
||||
# Ensure JSON serialization compatibility
|
||||
if isinstance(result, list):
|
||||
return [item.tolist() if hasattr(item, 'tolist') else item for item in result]
|
||||
return result.tolist() if hasattr(result, 'tolist') else result
|
||||
if self._embedding_function is None:
|
||||
await self.initialize()
|
||||
|
||||
if self._embedding_function is None:
|
||||
raise RuntimeError('SeekDB embedding function initialization failed')
|
||||
|
||||
return self._embedding_function(input_text)
|
||||
except Exception as e:
|
||||
from .. import errors
|
||||
|
||||
|
||||