mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-14 09:46:03 +00:00
Compare commits
29 Commits
copilot/ad
...
v4.6.3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b6c0345b3e | ||
|
|
6421a6f5cb | ||
|
|
daf56e5dc2 | ||
|
|
cb7c9af25c | ||
|
|
45e61befac | ||
|
|
ea50ba10e6 | ||
|
|
5c4a727e74 | ||
|
|
867f05c4ad | ||
|
|
b06b32306f | ||
|
|
dbfcb70f8d | ||
|
|
e64d56c4ac | ||
|
|
8f0da7943c | ||
|
|
e62ff7e520 | ||
|
|
86e951916e | ||
|
|
6bf08466de | ||
|
|
5e36dd480d | ||
|
|
0e2cd8c018 | ||
|
|
b4f92eba38 | ||
|
|
905e48c8ed | ||
|
|
10ec79312e | ||
|
|
24f779ff95 | ||
|
|
08c0677de9 | ||
|
|
cc5d32cf8a | ||
|
|
01a5133396 | ||
|
|
0aa5188b29 | ||
|
|
e49a161d0a | ||
|
|
0ddc3d60e7 | ||
|
|
51794176af | ||
|
|
b634aa48dc |
@@ -84,7 +84,7 @@ docker compose up -d
|
||||
## ✨ 特性
|
||||
|
||||
- 💬 大模型对话、Agent:支持多种大模型,适配群聊和私聊;具有多轮对话、工具调用、多模态、流式输出能力,自带 RAG(知识库)实现,并深度适配 [Dify](https://dify.ai)、[Coze](https://coze.com)、[n8n](https://n8n.io)等 LLMOps 平台。
|
||||
- 🤖 多平台支持:目前支持 QQ、QQ频道、企业微信、个人微信、飞书、Discord、Telegram 等平台。
|
||||
- 🤖 多平台支持:目前支持 QQ、QQ频道、企业微信、个人微信、飞书、Discord、Telegram、KOOK、Slack、LINE 等平台。
|
||||
- 🛠️ 高稳定性、功能完备:原生支持访问控制、限速、敏感词过滤等机制;配置简单,支持多种部署方式。支持多流水线配置,不同机器人用于不同应用场景。
|
||||
- 🧩 插件扩展、活跃社区:高稳定性、高安全性的生产级插件系统,支持事件驱动、组件扩展等插件机制;适配 Anthropic [MCP 协议](https://modelcontextprotocol.io/);目前已有数百个插件。
|
||||
- 😻 Web 管理面板:支持通过浏览器管理 LangBot 实例,不再需要手动编写配置文件。
|
||||
@@ -108,6 +108,7 @@ docker compose up -d
|
||||
| 微信公众号 | ✅ | |
|
||||
| 飞书 | ✅ | |
|
||||
| 钉钉 | ✅ | |
|
||||
| KOOK | ✅ | |
|
||||
| Discord | ✅ | |
|
||||
| Telegram | ✅ | |
|
||||
| Slack | ✅ | |
|
||||
|
||||
@@ -80,7 +80,7 @@ Click the Star and Watch button in the upper right corner of the repository to g
|
||||
## ✨ Features
|
||||
|
||||
- 💬 Chat with LLM / Agent: Supports multiple LLMs, adapt to group chats and private chats; Supports multi-round conversations, tool calls, multi-modal, and streaming output capabilities. Built-in RAG (knowledge base) implementation, and deeply integrates with [Dify](https://dify.ai), [Coze](https://coze.com), [n8n](https://n8n.io) etc. LLMOps platforms.
|
||||
- 🤖 Multi-platform Support: Currently supports QQ, QQ Channel, WeCom, personal WeChat, Lark, DingTalk, Discord, Telegram, etc.
|
||||
- 🤖 Multi-platform Support: Currently supports QQ, QQ Channel, WeCom, personal WeChat, Lark, DingTalk, Discord, Telegram, KOOK, Slack, LINE, etc.
|
||||
- 🛠️ High Stability, Feature-rich: Native access control, rate limiting, sensitive word filtering, etc. mechanisms; Easy to use, supports multiple deployment methods. Supports multiple pipeline configurations, different bots can be used for different scenarios.
|
||||
- 🧩 Plugin Extension, Active Community: High stability, high security production-level plugin system; Support event-driven, component extension, etc. plugin mechanisms; Integrate Anthropic [MCP protocol](https://modelcontextprotocol.io/); Currently has hundreds of plugins.
|
||||
- 😻 Web UI: Support management LangBot instance through the browser. No need to manually write configuration files.
|
||||
@@ -107,6 +107,7 @@ Or visit the demo environment: https://demo.langbot.dev/
|
||||
| Personal WeChat | ✅ | |
|
||||
| Lark | ✅ | |
|
||||
| DingTalk | ✅ | |
|
||||
| KOOK | ✅ | |
|
||||
|
||||
### LLMs
|
||||
|
||||
|
||||
@@ -80,7 +80,7 @@ Haga clic en los botones Star y Watch en la esquina superior derecha del reposit
|
||||
## ✨ Características
|
||||
|
||||
- 💬 Chat con LLM / Agent: Compatible con múltiples LLMs, adaptado para chats grupales y privados; Admite conversaciones de múltiples rondas, llamadas a herramientas, capacidades multimodales y de salida en streaming. Implementación RAG (base de conocimientos) incorporada, e integración profunda con [Dify](https://dify.ai), [Coze](https://coze.com), [n8n](https://n8n.io) etc. LLMOps platforms.
|
||||
- 🤖 Soporte Multiplataforma: Actualmente compatible con QQ, QQ Channel, WeCom, WeChat personal, Lark, DingTalk, Discord, Telegram, etc.
|
||||
- 🤖 Soporte Multiplataforma: Actualmente compatible con QQ, QQ Channel, WeCom, WeChat personal, Lark, DingTalk, Discord, Telegram, KOOK, Slack, LINE, etc.
|
||||
- 🛠️ Alta Estabilidad, Rico en Funciones: Control de acceso nativo, limitación de velocidad, filtrado de palabras sensibles, etc.; Fácil de usar, admite múltiples métodos de despliegue. Compatible con múltiples configuraciones de pipeline, diferentes bots para diferentes escenarios.
|
||||
- 🧩 Extensión de Plugin, Comunidad Activa: Sistema de plugin de alta estabilidad, alta seguridad de nivel de producción; Compatible con mecanismos de plugin impulsados por eventos, extensión de componentes, etc.; Integración del protocolo [MCP](https://modelcontextprotocol.io/) de Anthropic; Actualmente cuenta con cientos de plugins.
|
||||
- 😻 Interfaz Web: Admite la gestión de instancias de LangBot a través del navegador. No es necesario escribir archivos de configuración manualmente.
|
||||
@@ -107,6 +107,7 @@ O visite el entorno de demostración: https://demo.langbot.dev/
|
||||
| WeChat Personal | ✅ | |
|
||||
| Lark | ✅ | |
|
||||
| DingTalk | ✅ | |
|
||||
| KOOK | ✅ | |
|
||||
|
||||
### LLMs
|
||||
|
||||
|
||||
@@ -80,7 +80,7 @@ Cliquez sur les boutons Star et Watch dans le coin supérieur droit du dépôt p
|
||||
## ✨ Fonctionnalités
|
||||
|
||||
- 💬 Chat avec LLM / Agent : Prend en charge plusieurs LLM, adapté aux chats de groupe et privés ; Prend en charge les conversations multi-tours, les appels d'outils, les capacités multimodales et de sortie en streaming. Implémentation RAG (base de connaissances) intégrée, et intégration profonde avec [Dify](https://dify.ai), [Coze](https://coze.com), [n8n](https://n8n.io) etc. LLMOps platforms.
|
||||
- 🤖 Support Multi-plateforme : Actuellement compatible avec QQ, QQ Channel, WeCom, WeChat personnel, Lark, DingTalk, Discord, Telegram, etc.
|
||||
- 🤖 Support Multi-plateforme : Actuellement compatible avec QQ, QQ Channel, WeCom, WeChat personnel, Lark, DingTalk, Discord, Telegram, KOOK, Slack, LINE, etc.
|
||||
- 🛠️ Haute Stabilité, Riche en Fonctionnalités : Contrôle d'accès natif, limitation de débit, filtrage de mots sensibles, etc. ; Facile à utiliser, prend en charge plusieurs méthodes de déploiement. Prend en charge plusieurs configurations de pipeline, différents bots pour différents scénarios.
|
||||
- 🧩 Extension de Plugin, Communauté Active : Système de plugin de haute stabilité, haute sécurité de niveau production; Prend en charge les mécanismes de plugin pilotés par événements, l'extension de composants, etc. ; Intégration du protocole [MCP](https://modelcontextprotocol.io/) d'Anthropic ; Dispose actuellement de centaines de plugins.
|
||||
- 😻 Interface Web : Prend en charge la gestion des instances LangBot via le navigateur. Pas besoin d'écrire manuellement les fichiers de configuration.
|
||||
@@ -107,6 +107,7 @@ Ou visitez l'environnement de démonstration : https://demo.langbot.dev/
|
||||
| WeChat Personnel | ✅ | |
|
||||
| Lark | ✅ | |
|
||||
| DingTalk | ✅ | |
|
||||
| KOOK | ✅ | |
|
||||
|
||||
### LLMs
|
||||
|
||||
|
||||
@@ -80,7 +80,7 @@ LangBotはBTPanelにリストされています。BTPanelをインストール
|
||||
## ✨ 機能
|
||||
|
||||
- 💬 LLM / エージェントとのチャット: 複数のLLMをサポートし、グループチャットとプライベートチャットに対応。マルチラウンドの会話、ツールの呼び出し、マルチモーダル、ストリーミング出力機能をサポート、RAG(知識ベース)を組み込み、[Dify](https://dify.ai)、[Coze](https://coze.com)、[n8n](https://n8n.io) などの LLMOps プラットフォームと深く統合。
|
||||
- 🤖 多プラットフォーム対応: 現在、QQ、QQ チャンネル、WeChat、個人 WeChat、Lark、DingTalk、Discord、Telegram など、複数のプラットフォームをサポートしています。
|
||||
- 🤖 多プラットフォーム対応: 現在、QQ、QQ チャンネル、WeChat、個人 WeChat、Lark、DingTalk、Discord、Telegram、KOOK、Slack、LINE など、複数のプラットフォームをサポートしています。
|
||||
- 🛠️ 高い安定性、豊富な機能: ネイティブのアクセス制御、レート制限、敏感な単語のフィルタリングなどのメカニズムをサポート。使いやすく、複数のデプロイ方法をサポート。複数のパイプライン設定をサポートし、異なるボットを異なる用途に使用できます。
|
||||
- 🧩 プラグイン拡張、活発なコミュニティ: 高い安定性、高いセキュリティの生産レベルのプラグインシステム;イベント駆動、コンポーネント拡張などのプラグインメカニズムをサポート。適配 Anthropic [MCP プロトコル](https://modelcontextprotocol.io/);豊富なエコシステム、現在数百のプラグインが存在。
|
||||
- 😻 Web UI: ブラウザを通じてLangBotインスタンスを管理することをサポート。
|
||||
@@ -107,6 +107,7 @@ LangBotはBTPanelにリストされています。BTPanelをインストール
|
||||
| 個人WeChat | ✅ | |
|
||||
| Lark | ✅ | |
|
||||
| DingTalk | ✅ | |
|
||||
| KOOK | ✅ | |
|
||||
|
||||
### LLMs
|
||||
|
||||
|
||||
@@ -80,7 +80,7 @@ LangBot은 BTPanel에 등록되어 있습니다. BTPanel을 설치한 경우 [
|
||||
## ✨ 기능
|
||||
|
||||
- 💬 LLM / Agent와 채팅: 여러 LLM을 지원하며 그룹 채팅 및 개인 채팅에 적응; 멀티 라운드 대화, 도구 호출, 멀티모달, 스트리밍 출력 기능을 지원합니다. 내장된 RAG(지식 베이스) 구현 및 [Dify](https://dify.ai)、[Coze](https://coze.com)、[n8n](https://n8n.io) 등의 LLMOps 플랫폼과 깊이 통합됩니다.
|
||||
- 🤖 다중 플랫폼 지원: 현재 QQ, QQ Channel, WeCom, 개인 WeChat, Lark, DingTalk, Discord, Telegram 등을 지원합니다.
|
||||
- 🤖 다중 플랫폼 지원: 현재 QQ, QQ Channel, WeCom, 개인 WeChat, Lark, DingTalk, Discord, Telegram, KOOK, Slack, LINE 등을 지원합니다.
|
||||
- 🛠️ 높은 안정성, 풍부한 기능: 네이티브 액세스 제어, 속도 제한, 민감한 단어 필터링 등의 메커니즘; 사용하기 쉽고 여러 배포 방법을 지원합니다. 여러 파이프라인 구성을 지원하며 다양한 시나리오에 대해 다른 봇을 사용할 수 있습니다.
|
||||
- 🧩 플러그인 확장, 활발한 커뮤니티: 고안정성, 고보안 생산 수준의 플러그인 시스템; 이벤트 기반, 컴포넌트 확장 등의 플러그인 메커니즘을 지원; Anthropic [MCP 프로토콜](https://modelcontextprotocol.io/) 통합; 현재 수백 개의 플러그인이 있습니다.
|
||||
- 😻 웹 UI: 브라우저를 통해 LangBot 인스턴스 관리를 지원합니다. 구성 파일을 수동으로 작성할 필요가 없습니다.
|
||||
@@ -105,6 +105,7 @@ LangBot은 BTPanel에 등록되어 있습니다. BTPanel을 설치한 경우 [
|
||||
| WeComCS | ✅ | |
|
||||
| WeCom AI Bot | ✅ | |
|
||||
| 개인 WeChat | ✅ | |
|
||||
| KOOK | ✅ | |
|
||||
| Lark | ✅ | |
|
||||
| DingTalk | ✅ | |
|
||||
|
||||
|
||||
@@ -80,7 +80,7 @@ LangBot добавлен в BTPanel. Если у вас установлен BTP
|
||||
## ✨ Функции
|
||||
|
||||
- 💬 Чат с LLM / Agent: Поддержка нескольких LLM, адаптация к групповым и личным чатам; Поддержка многораундовых разговоров, вызовов инструментов, мультимодальных возможностей и потоковой передачи. Встроенная реализация RAG (база знаний) и глубокая интеграция с [Dify](https://dify.ai), [Coze](https://coze.com), [n8n](https://n8n.io) 등의 LLMOps 플랫포트폼과 깊이 통합됩니다.
|
||||
- 🤖 Многоплатформенная поддержка: В настоящее время поддерживает QQ, QQ Channel, WeCom, личный WeChat, Lark, DingTalk, Discord, Telegram и т.д.
|
||||
- 🤖 Многоплатформенная поддержка: В настоящее время поддерживает QQ, QQ Channel, WeCom, личный WeChat, Lark, DingTalk, Discord, Telegram, KOOK, Slack, LINE и т.д.
|
||||
- 🛠️ Высокая стабильность, богатство функций: Нативный контроль доступа, ограничение скорости, фильтрация чувствительных слов и т.д.; Простота в использовании, поддержка нескольких методов развертывания. Поддержка нескольких конфигураций конвейера, разные боты для разных сценариев.
|
||||
- 🧩 Расширение плагинов, активное сообщество: Высокая стабильность, высокая безопасность уровня производства; Поддержка механизмов плагинов, управляемых событиями, расширения компонентов и т.д.; Интеграция протокола [MCP](https://modelcontextprotocol.io/) от Anthropic; В настоящее время сотни плагинов.
|
||||
- 😻 Веб-интерфейс: Поддержка управления экземплярами LangBot через браузер. Нет необходимости вручную писать конфигурационные файлы.
|
||||
@@ -105,6 +105,7 @@ LangBot добавлен в BTPanel. Если у вас установлен BTP
|
||||
| WeComCS | ✅ | |
|
||||
| WeCom AI Bot | ✅ | |
|
||||
| Личный WeChat | ✅ | |
|
||||
| KOOK | ✅ | |
|
||||
| Lark | ✅ | |
|
||||
| DingTalk | ✅ | |
|
||||
|
||||
|
||||
@@ -80,7 +80,7 @@ docker compose up -d
|
||||
## ✨ 特性
|
||||
|
||||
- 💬 大模型對話、Agent:支援多種大模型,適配群聊和私聊;具有多輪對話、工具調用、多模態、流式輸出能力,自帶 RAG(知識庫)實現,並深度適配 [Dify](https://dify.ai)、[Coze](https://coze.com)、[n8n](https://n8n.io) 等 LLMOps 平台。
|
||||
- 🤖 多平台支援:目前支援 QQ、QQ頻道、企業微信、個人微信、飛書、Discord、Telegram 等平台。
|
||||
- 🤖 多平台支援:目前支援 QQ、QQ頻道、企業微信、個人微信、飛書、Discord、Telegram、KOOK、Slack、LINE 等平台。
|
||||
- 🛠️ 高穩定性、功能完備:原生支援訪問控制、限速、敏感詞過濾等機制;配置簡單,支援多種部署方式。支援多流水線配置,不同機器人用於不同應用場景。
|
||||
- 🧩 外掛擴展、活躍社群:高穩定性、高安全性的生產級外掛系統;支援事件驅動、組件擴展等外掛機制;適配 Anthropic [MCP 協議](https://modelcontextprotocol.io/);目前已有數百個外掛。
|
||||
- 😻 Web 管理面板:支援通過瀏覽器管理 LangBot 實例,不再需要手動編寫配置文件。
|
||||
@@ -105,6 +105,7 @@ docker compose up -d
|
||||
| 企微對外客服 | ✅ | |
|
||||
| 企微智能機器人 | ✅ | |
|
||||
| 微信公眾號 | ✅ | |
|
||||
| KOOK | ✅ | |
|
||||
| Lark | ✅ | |
|
||||
| DingTalk | ✅ | |
|
||||
|
||||
|
||||
@@ -80,7 +80,7 @@ Nhấp vào các nút Star và Watch ở góc trên bên phải của kho lưu t
|
||||
## ✨ Tính năng
|
||||
|
||||
- 💬 Chat với LLM / Agent: Hỗ trợ nhiều LLM, thích ứng với chat nhóm và chat riêng tư; Hỗ trợ các cuộc trò chuyện nhiều vòng, gọi công cụ, khả năng đa phương thức và đầu ra streaming. Triển khai RAG (cơ sở kiến thức) tích hợp sẵn và tích hợp sâu với [Dify](https://dify.ai), [Coze](https://coze.com), [n8n](https://n8n.io) v.v. LLMOps platforms.
|
||||
- 🤖 Hỗ trợ Đa nền tảng: Hiện hỗ trợ QQ, QQ Channel, WeCom, WeChat cá nhân, Lark, DingTalk, Discord, Telegram, v.v.
|
||||
- 🤖 Hỗ trợ Đa nền tảng: Hiện hỗ trợ QQ, QQ Channel, WeCom, WeChat cá nhân, Lark, DingTalk, Discord, Telegram, KOOK, Slack, LINE, v.v.
|
||||
- 🛠️ Độ ổn định Cao, Tính năng Phong phú: Kiểm soát truy cập gốc, giới hạn tốc độ, lọc từ nhạy cảm, v.v.; Dễ sử dụng, hỗ trợ nhiều phương pháp triển khai. Hỗ trợ nhiều cấu hình pipeline, các bot khác nhau cho các kịch bản khác nhau.
|
||||
- 🧩 Mở rộng Plugin, Cộng đồng Hoạt động: Hỗ trợ các cơ chế plugin hướng sự kiện, mở rộng thành phần, v.v.; Tích hợp giao thức [MCP](https://modelcontextprotocol.io/) của Anthropic; Hiện có hàng trăng plugin.
|
||||
- 😻 Giao diện Web: Hỗ trợ quản lý các phiên bản LangBot thông qua trình duyệt. Không cần viết tệp cấu hình thủ công.
|
||||
@@ -105,6 +105,7 @@ Hoặc truy cập môi trường demo: https://demo.langbot.dev/
|
||||
| WeComCS | ✅ | |
|
||||
| WeCom AI Bot | ✅ | |
|
||||
| WeChat Cá nhân | ✅ | |
|
||||
| KOOK | ✅ | |
|
||||
| Lark | ✅ | |
|
||||
| DingTalk | ✅ | |
|
||||
|
||||
|
||||
@@ -25,13 +25,12 @@ services:
|
||||
platform: linux/amd64 # For Apple Silicon compatibility
|
||||
volumes:
|
||||
- ./data:/app/data
|
||||
- ./plugins:/app/plugins
|
||||
restart: on-failure
|
||||
environment:
|
||||
- TZ=Asia/Shanghai
|
||||
ports:
|
||||
- 5300:5300 # For web ui
|
||||
- 2280-2290:2280-2290 # For platform webhook
|
||||
- 5300:5300 # For web ui and webhook callback
|
||||
- 2280-2285:2280-2285 # For platform reverse connection
|
||||
networks:
|
||||
- langbot_network
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "langbot"
|
||||
version = "4.5.4"
|
||||
version = "4.6.3"
|
||||
description = "Easy-to-use global IM bot platform designed for LLM era"
|
||||
readme = "README.md"
|
||||
license-files = ["LICENSE"]
|
||||
@@ -63,11 +63,13 @@ dependencies = [
|
||||
"langchain-text-splitters>=0.0.1",
|
||||
"chromadb>=0.4.24",
|
||||
"qdrant-client (>=1.15.1,<2.0.0)",
|
||||
"langbot-plugin==0.2.0b1",
|
||||
"langbot-plugin==0.2.1",
|
||||
"asyncpg>=0.30.0",
|
||||
"line-bot-sdk>=3.19.0",
|
||||
"tboxsdk>=0.0.10",
|
||||
"boto3>=1.35.0",
|
||||
"pymilvus>=2.6.4",
|
||||
"pgvector>=0.4.1",
|
||||
]
|
||||
keywords = [
|
||||
"bot",
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
"""LangBot - Easy-to-use global IM bot platform designed for LLM era"""
|
||||
|
||||
__version__ = '4.5.4'
|
||||
__version__ = '4.6.3'
|
||||
|
||||
@@ -23,20 +23,25 @@ xml_template = """
|
||||
|
||||
|
||||
class OAClient:
|
||||
def __init__(self, token: str, EncodingAESKey: str, AppID: str, Appsecret: str, logger: None):
|
||||
def __init__(self, token: str, EncodingAESKey: str, AppID: str, Appsecret: str, logger: None, unified_mode: bool = False):
|
||||
self.token = token
|
||||
self.aes = EncodingAESKey
|
||||
self.appid = AppID
|
||||
self.appsecret = Appsecret
|
||||
self.base_url = 'https://api.weixin.qq.com'
|
||||
self.access_token = ''
|
||||
self.unified_mode = unified_mode
|
||||
self.app = Quart(__name__)
|
||||
self.app.add_url_rule(
|
||||
'/callback/command',
|
||||
'handle_callback',
|
||||
self.handle_callback_request,
|
||||
methods=['GET', 'POST'],
|
||||
)
|
||||
|
||||
# 只有在非统一模式下才注册独立路由
|
||||
if not self.unified_mode:
|
||||
self.app.add_url_rule(
|
||||
'/callback/command',
|
||||
'handle_callback',
|
||||
self.handle_callback_request,
|
||||
methods=['GET', 'POST'],
|
||||
)
|
||||
|
||||
self._message_handlers = {
|
||||
'example': [],
|
||||
}
|
||||
@@ -46,19 +51,39 @@ class OAClient:
|
||||
self.logger = logger
|
||||
|
||||
async def handle_callback_request(self):
|
||||
"""处理回调请求(独立端口模式,使用全局 request)。"""
|
||||
return await self._handle_callback_internal(request)
|
||||
|
||||
async def handle_unified_webhook(self, req):
|
||||
"""处理回调请求(统一 webhook 模式,显式传递 request)。
|
||||
|
||||
Args:
|
||||
req: Quart Request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
return await self._handle_callback_internal(req)
|
||||
|
||||
async def _handle_callback_internal(self, req):
|
||||
"""处理回调请求的内部实现,包括 GET 验证和 POST 消息接收。
|
||||
|
||||
Args:
|
||||
req: Quart Request 对象
|
||||
"""
|
||||
try:
|
||||
# 每隔100毫秒查询是否生成ai回答
|
||||
start_time = time.time()
|
||||
signature = request.args.get('signature', '')
|
||||
timestamp = request.args.get('timestamp', '')
|
||||
nonce = request.args.get('nonce', '')
|
||||
echostr = request.args.get('echostr', '')
|
||||
msg_signature = request.args.get('msg_signature', '')
|
||||
signature = req.args.get('signature', '')
|
||||
timestamp = req.args.get('timestamp', '')
|
||||
nonce = req.args.get('nonce', '')
|
||||
echostr = req.args.get('echostr', '')
|
||||
msg_signature = req.args.get('msg_signature', '')
|
||||
if msg_signature is None:
|
||||
await self.logger.error('msg_signature不在请求体中')
|
||||
raise Exception('msg_signature不在请求体中')
|
||||
|
||||
if request.method == 'GET':
|
||||
if req.method == 'GET':
|
||||
# 校验签名
|
||||
check_str = ''.join(sorted([self.token, timestamp, nonce]))
|
||||
check_signature = hashlib.sha1(check_str.encode('utf-8')).hexdigest()
|
||||
@@ -68,8 +93,8 @@ class OAClient:
|
||||
else:
|
||||
await self.logger.error('拒绝请求')
|
||||
raise Exception('拒绝请求')
|
||||
elif request.method == 'POST':
|
||||
encryt_msg = await request.data
|
||||
elif req.method == 'POST':
|
||||
encryt_msg = await req.data
|
||||
wxcpt = WXBizMsgCrypt(self.token, self.aes, self.appid)
|
||||
ret, xml_msg = wxcpt.DecryptMsg(encryt_msg, msg_signature, timestamp, nonce)
|
||||
xml_msg = xml_msg.decode('utf-8')
|
||||
@@ -182,6 +207,7 @@ class OAClientForLongerResponse:
|
||||
Appsecret: str,
|
||||
LoadingMessage: str,
|
||||
logger: None,
|
||||
unified_mode: bool = False,
|
||||
):
|
||||
self.token = token
|
||||
self.aes = EncodingAESKey
|
||||
@@ -189,13 +215,18 @@ class OAClientForLongerResponse:
|
||||
self.appsecret = Appsecret
|
||||
self.base_url = 'https://api.weixin.qq.com'
|
||||
self.access_token = ''
|
||||
self.unified_mode = unified_mode
|
||||
self.app = Quart(__name__)
|
||||
self.app.add_url_rule(
|
||||
'/callback/command',
|
||||
'handle_callback',
|
||||
self.handle_callback_request,
|
||||
methods=['GET', 'POST'],
|
||||
)
|
||||
|
||||
# 只有在非统一模式下才注册独立路由
|
||||
if not self.unified_mode:
|
||||
self.app.add_url_rule(
|
||||
'/callback/command',
|
||||
'handle_callback',
|
||||
self.handle_callback_request,
|
||||
methods=['GET', 'POST'],
|
||||
)
|
||||
|
||||
self._message_handlers = {
|
||||
'example': [],
|
||||
}
|
||||
@@ -206,24 +237,44 @@ class OAClientForLongerResponse:
|
||||
self.logger = logger
|
||||
|
||||
async def handle_callback_request(self):
|
||||
"""处理回调请求(独立端口模式,使用全局 request)。"""
|
||||
return await self._handle_callback_internal(request)
|
||||
|
||||
async def handle_unified_webhook(self, req):
|
||||
"""处理回调请求(统一 webhook 模式,显式传递 request)。
|
||||
|
||||
Args:
|
||||
req: Quart Request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
return await self._handle_callback_internal(req)
|
||||
|
||||
async def _handle_callback_internal(self, req):
|
||||
"""处理回调请求的内部实现,包括 GET 验证和 POST 消息接收。
|
||||
|
||||
Args:
|
||||
req: Quart Request 对象
|
||||
"""
|
||||
try:
|
||||
signature = request.args.get('signature', '')
|
||||
timestamp = request.args.get('timestamp', '')
|
||||
nonce = request.args.get('nonce', '')
|
||||
echostr = request.args.get('echostr', '')
|
||||
msg_signature = request.args.get('msg_signature', '')
|
||||
signature = req.args.get('signature', '')
|
||||
timestamp = req.args.get('timestamp', '')
|
||||
nonce = req.args.get('nonce', '')
|
||||
echostr = req.args.get('echostr', '')
|
||||
msg_signature = req.args.get('msg_signature', '')
|
||||
|
||||
if msg_signature is None:
|
||||
await self.logger.error('msg_signature不在请求体中')
|
||||
raise Exception('msg_signature不在请求体中')
|
||||
|
||||
if request.method == 'GET':
|
||||
if req.method == 'GET':
|
||||
check_str = ''.join(sorted([self.token, timestamp, nonce]))
|
||||
check_signature = hashlib.sha1(check_str.encode('utf-8')).hexdigest()
|
||||
return echostr if check_signature == signature else '拒绝请求'
|
||||
|
||||
elif request.method == 'POST':
|
||||
encryt_msg = await request.data
|
||||
elif req.method == 'POST':
|
||||
encryt_msg = await req.data
|
||||
wxcpt = WXBizMsgCrypt(self.token, self.aes, self.appid)
|
||||
ret, xml_msg = wxcpt.DecryptMsg(encryt_msg, msg_signature, timestamp, nonce)
|
||||
xml_msg = xml_msg.decode('utf-8')
|
||||
|
||||
@@ -10,38 +10,20 @@ import traceback
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
|
||||
|
||||
def handle_validation(body: dict, bot_secret: str):
|
||||
# bot正确的secert是32位的,此处仅为了适配演示demo
|
||||
while len(bot_secret) < 32:
|
||||
bot_secret = bot_secret * 2
|
||||
bot_secret = bot_secret[:32]
|
||||
# 实际使用场景中以上三行内容可清除
|
||||
|
||||
seed_bytes = bot_secret.encode()
|
||||
|
||||
signing_key = ed25519.Ed25519PrivateKey.from_private_bytes(seed_bytes)
|
||||
|
||||
msg = body['d']['event_ts'] + body['d']['plain_token']
|
||||
msg_bytes = msg.encode()
|
||||
|
||||
signature = signing_key.sign(msg_bytes)
|
||||
|
||||
signature_hex = signature.hex()
|
||||
|
||||
response = {'plain_token': body['d']['plain_token'], 'signature': signature_hex}
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class QQOfficialClient:
|
||||
def __init__(self, secret: str, token: str, app_id: str, logger: None):
|
||||
def __init__(self, secret: str, token: str, app_id: str, logger: None, unified_mode: bool = False):
|
||||
self.unified_mode = unified_mode
|
||||
self.app = Quart(__name__)
|
||||
self.app.add_url_rule(
|
||||
'/callback/command',
|
||||
'handle_callback',
|
||||
self.handle_callback_request,
|
||||
methods=['GET', 'POST'],
|
||||
)
|
||||
|
||||
# 只有在非统一模式下才注册独立路由
|
||||
if not self.unified_mode:
|
||||
self.app.add_url_rule(
|
||||
'/callback/command',
|
||||
'handle_callback',
|
||||
self.handle_callback_request,
|
||||
methods=['GET', 'POST'],
|
||||
)
|
||||
|
||||
self.secret = secret
|
||||
self.token = token
|
||||
self.app_id = app_id
|
||||
@@ -82,18 +64,45 @@ class QQOfficialClient:
|
||||
raise Exception(f'获取access_token失败: {e}')
|
||||
|
||||
async def handle_callback_request(self):
|
||||
"""处理回调请求"""
|
||||
"""处理回调请求(独立端口模式,使用全局 request)"""
|
||||
return await self._handle_callback_internal(request)
|
||||
|
||||
async def handle_unified_webhook(self, req):
|
||||
"""处理回调请求(统一 webhook 模式,显式传递 request)。
|
||||
|
||||
Args:
|
||||
req: Quart Request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
return await self._handle_callback_internal(req)
|
||||
|
||||
async def _handle_callback_internal(self, req):
|
||||
"""处理回调请求的内部实现。
|
||||
|
||||
Args:
|
||||
req: Quart Request 对象
|
||||
"""
|
||||
try:
|
||||
# 读取请求数据
|
||||
body = await request.get_data()
|
||||
|
||||
body = await req.get_data()
|
||||
|
||||
print(f'[QQ Official] Received request, body length: {len(body)}')
|
||||
|
||||
if not body or len(body) == 0:
|
||||
print('[QQ Official] Received empty body, might be health check or GET request')
|
||||
return {'code': 0, 'message': 'ok'}, 200
|
||||
|
||||
payload = json.loads(body)
|
||||
|
||||
# 验证是否为回调验证请求
|
||||
|
||||
if payload.get('op') == 13:
|
||||
# 生成签名
|
||||
response = handle_validation(payload, self.secret)
|
||||
|
||||
return response
|
||||
validation_data = payload.get('d')
|
||||
if not validation_data:
|
||||
return {'error': "missing 'd' field"}, 400
|
||||
response = await self.verify(validation_data)
|
||||
return response, 200
|
||||
|
||||
if payload.get('op') == 0:
|
||||
message_data = await self.get_message(payload)
|
||||
@@ -104,6 +113,7 @@ class QQOfficialClient:
|
||||
return {'code': 0, 'message': 'success'}
|
||||
|
||||
except Exception as e:
|
||||
print(f'[QQ Official] ERROR: {traceback.format_exc()}')
|
||||
await self.logger.error(f'Error in handle_callback_request: {traceback.format_exc()}')
|
||||
return {'error': str(e)}, 400
|
||||
|
||||
@@ -261,3 +271,26 @@ class QQOfficialClient:
|
||||
if self.access_token_expiry_time is None:
|
||||
return True
|
||||
return time.time() > self.access_token_expiry_time
|
||||
|
||||
async def repeat_seed(self, bot_secret: str, target_size: int = 32) -> bytes:
|
||||
seed = bot_secret
|
||||
while len(seed) < target_size:
|
||||
seed *= 2
|
||||
return seed[:target_size].encode("utf-8")
|
||||
|
||||
async def verify(self, validation_payload: dict):
|
||||
seed = await self.repeat_seed(self.secret)
|
||||
private_key = ed25519.Ed25519PrivateKey.from_private_bytes(seed)
|
||||
|
||||
event_ts = validation_payload.get("event_ts", "")
|
||||
plain_token = validation_payload.get("plain_token", "")
|
||||
msg = event_ts + plain_token
|
||||
|
||||
# sign
|
||||
signature = private_key.sign(msg.encode()).hex()
|
||||
|
||||
response = {
|
||||
"plain_token": plain_token,
|
||||
"signature": signature,
|
||||
}
|
||||
return response
|
||||
|
||||
@@ -8,14 +8,19 @@ import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
|
||||
|
||||
class SlackClient:
|
||||
def __init__(self, bot_token: str, signing_secret: str, logger: None):
|
||||
def __init__(self, bot_token: str, signing_secret: str, logger: None, unified_mode: bool = False):
|
||||
self.bot_token = bot_token
|
||||
self.signing_secret = signing_secret
|
||||
self.unified_mode = unified_mode
|
||||
self.app = Quart(__name__)
|
||||
self.client = AsyncWebClient(self.bot_token)
|
||||
self.app.add_url_rule(
|
||||
'/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST']
|
||||
)
|
||||
|
||||
# 只有在非统一模式下才注册独立路由
|
||||
if not self.unified_mode:
|
||||
self.app.add_url_rule(
|
||||
'/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST']
|
||||
)
|
||||
|
||||
self._message_handlers = {
|
||||
'example': [],
|
||||
}
|
||||
@@ -23,8 +28,28 @@ class SlackClient:
|
||||
self.logger = logger
|
||||
|
||||
async def handle_callback_request(self):
|
||||
"""处理回调请求(独立端口模式,使用全局 request)"""
|
||||
return await self._handle_callback_internal(request)
|
||||
|
||||
async def handle_unified_webhook(self, req):
|
||||
"""处理回调请求(统一 webhook 模式,显式传递 request)。
|
||||
|
||||
Args:
|
||||
req: Quart Request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
return await self._handle_callback_internal(req)
|
||||
|
||||
async def _handle_callback_internal(self, req):
|
||||
"""处理回调请求的内部实现。
|
||||
|
||||
Args:
|
||||
req: Quart Request 对象
|
||||
"""
|
||||
try:
|
||||
body = await request.get_data()
|
||||
body = await req.get_data()
|
||||
data = json.loads(body)
|
||||
if 'type' in data:
|
||||
if data['type'] == 'url_verification':
|
||||
|
||||
@@ -200,7 +200,7 @@ class StreamSessionManager:
|
||||
|
||||
|
||||
class WecomBotClient:
|
||||
def __init__(self, Token: str, EnCodingAESKey: str, Corpid: str, logger: EventLogger):
|
||||
def __init__(self, Token: str, EnCodingAESKey: str, Corpid: str, logger: EventLogger, unified_mode: bool = False):
|
||||
"""企业微信智能机器人客户端。
|
||||
|
||||
Args:
|
||||
@@ -208,6 +208,7 @@ class WecomBotClient:
|
||||
EnCodingAESKey: 企业微信消息加解密密钥。
|
||||
Corpid: 企业 ID。
|
||||
logger: 日志记录器。
|
||||
unified_mode: 是否使用统一 webhook 模式(默认 False)。
|
||||
|
||||
Example:
|
||||
>>> client = WecomBotClient(Token='token', EnCodingAESKey='aeskey', Corpid='corp', logger=logger)
|
||||
@@ -217,10 +218,15 @@ class WecomBotClient:
|
||||
self.EnCodingAESKey = EnCodingAESKey
|
||||
self.Corpid = Corpid
|
||||
self.ReceiveId = ''
|
||||
self.unified_mode = unified_mode
|
||||
self.app = Quart(__name__)
|
||||
self.app.add_url_rule(
|
||||
'/callback/command', 'handle_callback', self.handle_callback_request, methods=['POST', 'GET']
|
||||
)
|
||||
|
||||
# 只有在非统一模式下才注册独立路由
|
||||
if not self.unified_mode:
|
||||
self.app.add_url_rule(
|
||||
'/callback/command', 'handle_callback', self.handle_callback_request, methods=['POST', 'GET']
|
||||
)
|
||||
|
||||
self._message_handlers = {
|
||||
'example': [],
|
||||
}
|
||||
@@ -359,7 +365,7 @@ class WecomBotClient:
|
||||
return await self._encrypt_and_reply(payload, nonce)
|
||||
|
||||
async def handle_callback_request(self):
|
||||
"""企业微信回调入口。
|
||||
"""企业微信回调入口(独立端口模式,使用全局 request)。
|
||||
|
||||
Returns:
|
||||
Quart Response: 根据请求类型返回验证、首包或刷新结果。
|
||||
@@ -367,15 +373,33 @@ class WecomBotClient:
|
||||
Example:
|
||||
作为 Quart 路由处理函数直接注册并使用。
|
||||
"""
|
||||
return await self._handle_callback_internal(request)
|
||||
|
||||
async def handle_unified_webhook(self, req):
|
||||
"""处理回调请求(统一 webhook 模式,显式传递 request)。
|
||||
|
||||
Args:
|
||||
req: Quart Request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
return await self._handle_callback_internal(req)
|
||||
|
||||
async def _handle_callback_internal(self, req):
|
||||
"""处理回调请求的内部实现,包括 GET 验证和 POST 消息接收。
|
||||
|
||||
Args:
|
||||
req: Quart Request 对象
|
||||
"""
|
||||
try:
|
||||
self.wxcpt = WXBizMsgCrypt(self.Token, self.EnCodingAESKey, '')
|
||||
await self.logger.info(f'{request.method} {request.url} {str(request.args)}')
|
||||
|
||||
if request.method == 'GET':
|
||||
return await self._handle_get_callback()
|
||||
if req.method == 'GET':
|
||||
return await self._handle_get_callback(req)
|
||||
|
||||
if request.method == 'POST':
|
||||
return await self._handle_post_callback()
|
||||
if req.method == 'POST':
|
||||
return await self._handle_post_callback(req)
|
||||
|
||||
return Response('', status=405)
|
||||
|
||||
@@ -383,13 +407,13 @@ class WecomBotClient:
|
||||
await self.logger.error(traceback.format_exc())
|
||||
return Response('Internal Server Error', status=500)
|
||||
|
||||
async def _handle_get_callback(self) -> tuple[Response, int] | Response:
|
||||
async def _handle_get_callback(self, req) -> tuple[Response, int] | Response:
|
||||
"""处理企业微信的 GET 验证请求。"""
|
||||
|
||||
msg_signature = unquote(request.args.get('msg_signature', ''))
|
||||
timestamp = unquote(request.args.get('timestamp', ''))
|
||||
nonce = unquote(request.args.get('nonce', ''))
|
||||
echostr = unquote(request.args.get('echostr', ''))
|
||||
msg_signature = unquote(req.args.get('msg_signature', ''))
|
||||
timestamp = unquote(req.args.get('timestamp', ''))
|
||||
nonce = unquote(req.args.get('nonce', ''))
|
||||
echostr = unquote(req.args.get('echostr', ''))
|
||||
|
||||
if not all([msg_signature, timestamp, nonce, echostr]):
|
||||
await self.logger.error('请求参数缺失')
|
||||
@@ -402,16 +426,16 @@ class WecomBotClient:
|
||||
|
||||
return Response(decrypted_str, mimetype='text/plain')
|
||||
|
||||
async def _handle_post_callback(self) -> tuple[Response, int] | Response:
|
||||
async def _handle_post_callback(self, req) -> tuple[Response, int] | Response:
|
||||
"""处理企业微信的 POST 回调请求。"""
|
||||
|
||||
self.stream_sessions.cleanup()
|
||||
|
||||
msg_signature = unquote(request.args.get('msg_signature', ''))
|
||||
timestamp = unquote(request.args.get('timestamp', ''))
|
||||
nonce = unquote(request.args.get('nonce', ''))
|
||||
msg_signature = unquote(req.args.get('msg_signature', ''))
|
||||
timestamp = unquote(req.args.get('timestamp', ''))
|
||||
nonce = unquote(req.args.get('nonce', ''))
|
||||
|
||||
encrypted_json = await request.get_json()
|
||||
encrypted_json = await req.get_json()
|
||||
encrypted_msg = (encrypted_json or {}).get('encrypt', '')
|
||||
if not encrypted_msg:
|
||||
await self.logger.error("请求体中缺少 'encrypt' 字段")
|
||||
@@ -433,32 +457,174 @@ class WecomBotClient:
|
||||
async def get_message(self, msg_json):
|
||||
message_data = {}
|
||||
|
||||
msg_type = msg_json.get('msgtype', '')
|
||||
if msg_type:
|
||||
message_data['msgtype'] = msg_type
|
||||
|
||||
if msg_json.get('chattype', '') == 'single':
|
||||
message_data['type'] = 'single'
|
||||
elif msg_json.get('chattype', '') == 'group':
|
||||
message_data['type'] = 'group'
|
||||
|
||||
if msg_json.get('msgtype') == 'text':
|
||||
max_inline_file_size = 5 * 1024 * 1024 # avoid decoding very large payloads by default
|
||||
|
||||
async def _safe_download(url: str):
|
||||
if not url:
|
||||
return None
|
||||
return await self.download_url_to_base64(url, self.EnCodingAESKey)
|
||||
|
||||
if msg_type == 'text':
|
||||
message_data['content'] = msg_json.get('text', {}).get('content')
|
||||
elif msg_json.get('msgtype') == 'image':
|
||||
elif msg_type == 'markdown':
|
||||
message_data['content'] = msg_json.get('markdown', {}).get('content') or msg_json.get('text', {}).get(
|
||||
'content', ''
|
||||
)
|
||||
elif msg_type == 'image':
|
||||
picurl = msg_json.get('image', {}).get('url', '')
|
||||
base64 = await self.download_url_to_base64(picurl, self.EnCodingAESKey)
|
||||
message_data['picurl'] = base64
|
||||
elif msg_json.get('msgtype') == 'mixed':
|
||||
base64_data = await _safe_download(picurl)
|
||||
if base64_data:
|
||||
message_data['picurl'] = base64_data
|
||||
message_data['images'] = [base64_data]
|
||||
elif msg_type == 'voice':
|
||||
voice_info = msg_json.get('voice', {}) or {}
|
||||
download_url = voice_info.get('url')
|
||||
message_data['voice'] = {
|
||||
'url': download_url,
|
||||
'md5sum': voice_info.get('md5sum') or voice_info.get('md5'),
|
||||
'filesize': voice_info.get('filesize') or voice_info.get('size'),
|
||||
'sdkfileid': voice_info.get('sdkfileid') or voice_info.get('fileid'),
|
||||
}
|
||||
# 企业微信智能转写文本(如果已有)直接复用,避免重复转写
|
||||
if voice_info.get('content'):
|
||||
message_data['content'] = voice_info.get('content')
|
||||
if (message_data['voice'].get('filesize') or 0) <= max_inline_file_size:
|
||||
voice_base64 = await _safe_download(download_url)
|
||||
if voice_base64:
|
||||
message_data['voice']['base64'] = voice_base64
|
||||
elif msg_type == 'video':
|
||||
video_info = msg_json.get('video', {}) or {}
|
||||
download_url = video_info.get('url')
|
||||
video_data = {
|
||||
'url': download_url,
|
||||
'filesize': video_info.get('filesize') or video_info.get('size'),
|
||||
'sdkfileid': video_info.get('sdkfileid') or video_info.get('fileid'),
|
||||
'md5sum': video_info.get('md5sum') or video_info.get('md5'),
|
||||
'filename': video_info.get('filename') or video_info.get('name'),
|
||||
}
|
||||
if (video_data.get('filesize') or 0) <= max_inline_file_size:
|
||||
video_base64 = await _safe_download(download_url)
|
||||
if video_base64:
|
||||
video_data['base64'] = video_base64
|
||||
message_data['video'] = video_data
|
||||
elif msg_type == 'file':
|
||||
file_info = msg_json.get('file', {}) or {}
|
||||
download_url = file_info.get('url') or file_info.get('fileurl')
|
||||
file_data = {
|
||||
'filename': file_info.get('filename') or file_info.get('name'),
|
||||
'filesize': file_info.get('filesize') or file_info.get('size'),
|
||||
'md5sum': file_info.get('md5sum') or file_info.get('md5'),
|
||||
'sdkfileid': file_info.get('sdkfileid') or file_info.get('fileid'),
|
||||
'download_url': download_url,
|
||||
'extra': file_info,
|
||||
}
|
||||
if (file_data.get('filesize') or 0) <= max_inline_file_size:
|
||||
file_base64 = await _safe_download(download_url)
|
||||
if file_base64:
|
||||
file_data['base64'] = file_base64
|
||||
message_data['file'] = file_data
|
||||
elif msg_type == 'link':
|
||||
message_data['link'] = msg_json.get('link', {})
|
||||
if not message_data.get('content'):
|
||||
title = message_data['link'].get('title', '')
|
||||
desc = message_data['link'].get('description') or message_data['link'].get('digest', '')
|
||||
message_data['content'] = '\n'.join(filter(None, [title, desc]))
|
||||
elif msg_type == 'mixed':
|
||||
items = msg_json.get('mixed', {}).get('msg_item', [])
|
||||
texts = []
|
||||
picurl = None
|
||||
images = []
|
||||
files = []
|
||||
voices = []
|
||||
videos = []
|
||||
links = []
|
||||
for item in items:
|
||||
if item.get('msgtype') == 'text':
|
||||
item_type = item.get('msgtype')
|
||||
if item_type == 'text':
|
||||
texts.append(item.get('text', {}).get('content', ''))
|
||||
elif item.get('msgtype') == 'image' and picurl is None:
|
||||
picurl = item.get('image', {}).get('url')
|
||||
elif item_type == 'image':
|
||||
img_url = item.get('image', {}).get('url')
|
||||
base64_data = await _safe_download(img_url)
|
||||
if base64_data:
|
||||
images.append(base64_data)
|
||||
elif item_type == 'file':
|
||||
file_info = item.get('file', {}) or {}
|
||||
download_url = file_info.get('url') or file_info.get('fileurl')
|
||||
file_data = {
|
||||
'filename': file_info.get('filename') or file_info.get('name'),
|
||||
'filesize': file_info.get('filesize') or file_info.get('size'),
|
||||
'md5sum': file_info.get('md5sum') or file_info.get('md5'),
|
||||
'sdkfileid': file_info.get('sdkfileid') or file_info.get('fileid'),
|
||||
'download_url': download_url,
|
||||
'extra': file_info,
|
||||
}
|
||||
if (file_data.get('filesize') or 0) <= max_inline_file_size:
|
||||
file_base64 = await _safe_download(download_url)
|
||||
if file_base64:
|
||||
file_data['base64'] = file_base64
|
||||
files.append(file_data)
|
||||
elif item_type == 'voice':
|
||||
voice_info = item.get('voice', {}) or {}
|
||||
download_url = voice_info.get('url')
|
||||
voice_data = {
|
||||
'url': download_url,
|
||||
'md5sum': voice_info.get('md5sum') or voice_info.get('md5'),
|
||||
'filesize': voice_info.get('filesize') or voice_info.get('size'),
|
||||
'sdkfileid': voice_info.get('sdkfileid') or voice_info.get('fileid'),
|
||||
}
|
||||
if voice_info.get('content'):
|
||||
texts.append(voice_info.get('content'))
|
||||
if (voice_data.get('filesize') or 0) <= max_inline_file_size:
|
||||
voice_base64 = await _safe_download(download_url)
|
||||
if voice_base64:
|
||||
voice_data['base64'] = voice_base64
|
||||
voices.append(voice_data)
|
||||
elif item_type == 'video':
|
||||
video_info = item.get('video', {}) or {}
|
||||
download_url = video_info.get('url')
|
||||
video_data = {
|
||||
'url': download_url,
|
||||
'filesize': video_info.get('filesize') or video_info.get('size'),
|
||||
'sdkfileid': video_info.get('sdkfileid') or video_info.get('fileid'),
|
||||
'md5sum': video_info.get('md5sum') or video_info.get('md5'),
|
||||
'filename': video_info.get('filename') or video_info.get('name'),
|
||||
}
|
||||
if (video_data.get('filesize') or 0) <= max_inline_file_size:
|
||||
video_base64 = await _safe_download(download_url)
|
||||
if video_base64:
|
||||
video_data['base64'] = video_base64
|
||||
videos.append(video_data)
|
||||
elif item_type == 'link':
|
||||
links.append(item.get('link', {}))
|
||||
|
||||
if texts:
|
||||
message_data['content'] = ''.join(texts) # 拼接所有 text
|
||||
if picurl:
|
||||
base64 = await self.download_url_to_base64(picurl, self.EnCodingAESKey)
|
||||
message_data['picurl'] = base64 # 只保留第一个 image
|
||||
message_data['content'] = ' '.join(texts) # 拼接所有 text
|
||||
if images:
|
||||
message_data['images'] = images
|
||||
message_data['picurl'] = images[0] # 只保留第一个 image
|
||||
if files:
|
||||
message_data['files'] = files
|
||||
message_data['file'] = files[0]
|
||||
if voices:
|
||||
message_data['voices'] = voices
|
||||
message_data['voice'] = voices[0]
|
||||
if videos:
|
||||
message_data['videos'] = videos
|
||||
message_data['video'] = videos[0]
|
||||
if links:
|
||||
message_data['link'] = links[0]
|
||||
if items:
|
||||
message_data['attachments'] = items
|
||||
else:
|
||||
message_data['raw_msg'] = msg_json
|
||||
|
||||
# Extract user information
|
||||
from_info = msg_json.get('from', {})
|
||||
|
||||
@@ -17,6 +17,13 @@ class WecomBotEvent(dict):
|
||||
"""
|
||||
return self.get('type', '')
|
||||
|
||||
@property
|
||||
def msgtype(self) -> str:
|
||||
"""
|
||||
消息 msgtype
|
||||
"""
|
||||
return self.get('msgtype', '')
|
||||
|
||||
@property
|
||||
def userid(self) -> str:
|
||||
"""
|
||||
@@ -29,12 +36,7 @@ class WecomBotEvent(dict):
|
||||
"""
|
||||
用户名称
|
||||
"""
|
||||
return (
|
||||
self.get('username', '')
|
||||
or self.get('from', {}).get('alias', '')
|
||||
or self.get('from', {}).get('name', '')
|
||||
or self.userid
|
||||
)
|
||||
return self.get('username', '') or self.get('from', {}).get('alias', '') or self.get('from', {}).get('name', '') or self.userid
|
||||
|
||||
@property
|
||||
def chatname(self) -> str:
|
||||
@@ -57,6 +59,55 @@ class WecomBotEvent(dict):
|
||||
"""
|
||||
return self.get('picurl', '')
|
||||
|
||||
@property
|
||||
def images(self):
|
||||
"""
|
||||
图片列表(兼容 mixed)
|
||||
"""
|
||||
return self.get('images', [])
|
||||
|
||||
@property
|
||||
def file(self):
|
||||
"""
|
||||
文件信息
|
||||
"""
|
||||
return self.get('file', {})
|
||||
|
||||
@property
|
||||
def voice(self):
|
||||
"""
|
||||
语音信息
|
||||
"""
|
||||
return self.get('voice', {})
|
||||
|
||||
@property
|
||||
def video(self):
|
||||
"""
|
||||
视频信息
|
||||
"""
|
||||
return self.get('video', {})
|
||||
|
||||
@property
|
||||
def link(self):
|
||||
"""
|
||||
链接消息信息
|
||||
"""
|
||||
return self.get('link', {})
|
||||
|
||||
@property
|
||||
def location(self):
|
||||
"""
|
||||
位置信息
|
||||
"""
|
||||
return self.get('location', {})
|
||||
|
||||
@property
|
||||
def attachments(self):
|
||||
"""
|
||||
原始 mixed 中的附件项
|
||||
"""
|
||||
return self.get('attachments', [])
|
||||
|
||||
@property
|
||||
def chatid(self) -> str:
|
||||
"""
|
||||
@@ -70,7 +121,7 @@ class WecomBotEvent(dict):
|
||||
消息id
|
||||
"""
|
||||
return self.get('msgid', '')
|
||||
|
||||
|
||||
@property
|
||||
def ai_bot_id(self) -> str:
|
||||
"""
|
||||
|
||||
@@ -21,6 +21,7 @@ class WecomClient:
|
||||
EncodingAESKey: str,
|
||||
contacts_secret: str,
|
||||
logger: None,
|
||||
unified_mode: bool = False,
|
||||
):
|
||||
self.corpid = corpid
|
||||
self.secret = secret
|
||||
@@ -31,13 +32,18 @@ class WecomClient:
|
||||
self.access_token = ''
|
||||
self.secret_for_contacts = contacts_secret
|
||||
self.logger = logger
|
||||
self.unified_mode = unified_mode
|
||||
self.app = Quart(__name__)
|
||||
self.app.add_url_rule(
|
||||
'/callback/command',
|
||||
'handle_callback',
|
||||
self.handle_callback_request,
|
||||
methods=['GET', 'POST'],
|
||||
)
|
||||
|
||||
# 只有在非统一模式下才注册独立路由
|
||||
if not self.unified_mode:
|
||||
self.app.add_url_rule(
|
||||
'/callback/command',
|
||||
'handle_callback',
|
||||
self.handle_callback_request,
|
||||
methods=['GET', 'POST'],
|
||||
)
|
||||
|
||||
self._message_handlers = {
|
||||
'example': [],
|
||||
}
|
||||
@@ -133,6 +139,58 @@ class WecomClient:
|
||||
await self.logger.error(f'发送图片失败:{data}')
|
||||
raise Exception('Failed to send image: ' + str(data))
|
||||
|
||||
async def send_voice(self, user_id: str, agent_id: int, media_id: str):
|
||||
if not await self.check_access_token():
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
url = self.base_url + '/message/send?access_token=' + self.access_token
|
||||
async with httpx.AsyncClient() as client:
|
||||
params = {
|
||||
'touser': user_id,
|
||||
'msgtype': 'voice',
|
||||
'agentid': agent_id,
|
||||
'voice': {
|
||||
'media_id': media_id,
|
||||
},
|
||||
'safe': 0,
|
||||
'enable_id_trans': 0,
|
||||
'enable_duplicate_check': 0,
|
||||
'duplicate_check_interval': 1800,
|
||||
}
|
||||
response = await client.post(url, json=params)
|
||||
data = response.json()
|
||||
if data['errcode'] == 40014 or data['errcode'] == 42001:
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
return await self.send_voice(user_id, agent_id, media_id)
|
||||
if data['errcode'] != 0:
|
||||
await self.logger.error(f'发送语音失败:{data}')
|
||||
raise Exception('Failed to send voice: ' + str(data))
|
||||
|
||||
async def send_file(self, user_id: str, agent_id: int, media_id: str):
|
||||
if not await self.check_access_token():
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
url = self.base_url + '/message/send?access_token=' + self.access_token
|
||||
async with httpx.AsyncClient() as client:
|
||||
params = {
|
||||
'touser': user_id,
|
||||
'msgtype': 'file',
|
||||
'agentid': agent_id,
|
||||
'file': {
|
||||
'media_id': media_id,
|
||||
},
|
||||
'safe': 0,
|
||||
'enable_id_trans': 0,
|
||||
'enable_duplicate_check': 0,
|
||||
'duplicate_check_interval': 1800,
|
||||
}
|
||||
response = await client.post(url, json=params)
|
||||
data = response.json()
|
||||
if data['errcode'] == 40014 or data['errcode'] == 42001:
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
return await self.send_file(user_id, agent_id, media_id)
|
||||
if data['errcode'] != 0:
|
||||
await self.logger.error(f'发送文件失败:{data}')
|
||||
raise Exception('Failed to send file: ' + str(data))
|
||||
|
||||
async def send_private_msg(self, user_id: str, agent_id: int, content: str):
|
||||
if not await self.check_access_token():
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
@@ -161,25 +219,43 @@ class WecomClient:
|
||||
raise Exception('Failed to send message: ' + str(data))
|
||||
|
||||
async def handle_callback_request(self):
|
||||
"""处理回调请求(独立端口模式,使用全局 request)。"""
|
||||
return await self._handle_callback_internal(request)
|
||||
|
||||
async def handle_unified_webhook(self, req):
|
||||
"""处理回调请求(统一 webhook 模式,显式传递 request)。
|
||||
|
||||
Args:
|
||||
req: Quart Request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
处理回调请求,包括 GET 验证和 POST 消息接收。
|
||||
return await self._handle_callback_internal(req)
|
||||
|
||||
async def _handle_callback_internal(self, req):
|
||||
"""
|
||||
处理回调请求的内部实现,包括 GET 验证和 POST 消息接收。
|
||||
|
||||
Args:
|
||||
req: Quart Request 对象
|
||||
"""
|
||||
try:
|
||||
msg_signature = request.args.get('msg_signature')
|
||||
timestamp = request.args.get('timestamp')
|
||||
nonce = request.args.get('nonce')
|
||||
msg_signature = req.args.get('msg_signature')
|
||||
timestamp = req.args.get('timestamp')
|
||||
nonce = req.args.get('nonce')
|
||||
|
||||
wxcpt = WXBizMsgCrypt(self.token, self.aes, self.corpid)
|
||||
if request.method == 'GET':
|
||||
echostr = request.args.get('echostr')
|
||||
if req.method == 'GET':
|
||||
echostr = req.args.get('echostr')
|
||||
ret, reply_echo_str = wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
|
||||
if ret != 0:
|
||||
await self.logger.error('验证失败')
|
||||
raise Exception(f'验证失败,错误码: {ret}')
|
||||
return reply_echo_str
|
||||
|
||||
elif request.method == 'POST':
|
||||
encrypt_msg = await request.data
|
||||
elif req.method == 'POST':
|
||||
encrypt_msg = await req.data
|
||||
ret, xml_msg = wxcpt.DecryptMsg(encrypt_msg, msg_signature, timestamp, nonce)
|
||||
if ret != 0:
|
||||
await self.logger.error('消息解密失败')
|
||||
@@ -263,7 +339,7 @@ class WecomClient:
|
||||
return ext
|
||||
return 'jpg' # 默认返回jpg
|
||||
|
||||
async def upload_to_work(self, image: platform_message.Image):
|
||||
async def upload_image_to_work(self, image: platform_message.Image):
|
||||
"""
|
||||
获取 media_id
|
||||
"""
|
||||
@@ -280,7 +356,7 @@ class WecomClient:
|
||||
file_bytes = await f.read()
|
||||
file_name = image.path.split('/')[-1]
|
||||
elif image.url:
|
||||
file_bytes = await self.download_image_to_bytes(image.url)
|
||||
file_bytes = await self.download_media_to_bytes(image.url)
|
||||
file_name = image.url.split('/')[-1]
|
||||
elif image.base64:
|
||||
try:
|
||||
@@ -315,7 +391,7 @@ class WecomClient:
|
||||
data = response.json()
|
||||
if data['errcode'] == 40014 or data['errcode'] == 42001:
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
media_id = await self.upload_to_work(image)
|
||||
media_id = await self.upload_image_to_work(image)
|
||||
if data.get('errcode', 0) != 0:
|
||||
await self.logger.error(f'上传图片失败:{data}')
|
||||
raise Exception('failed to upload file')
|
||||
@@ -323,13 +399,128 @@ class WecomClient:
|
||||
media_id = data.get('media_id')
|
||||
return media_id
|
||||
|
||||
async def download_image_to_bytes(self, url: str) -> bytes:
|
||||
async def upload_voice_to_work(self, voice: platform_message.Voice):
|
||||
"""
|
||||
上传语音文件到企业微信
|
||||
"""
|
||||
if not await self.check_access_token():
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
url = self.base_url + '/media/upload?access_token=' + self.access_token + '&type=file'
|
||||
file_bytes = None
|
||||
file_name = 'voice.mp3'
|
||||
|
||||
if voice.path:
|
||||
async with aiofiles.open(voice.path, 'rb') as f:
|
||||
file_bytes = await f.read()
|
||||
file_name = voice.path.split('/')[-1]
|
||||
elif voice.url:
|
||||
file_bytes = await self.download_media_to_bytes(voice.url)
|
||||
file_name = voice.url.split('/')[-1]
|
||||
elif voice.base64:
|
||||
try:
|
||||
base64_data = voice.base64
|
||||
if ',' in base64_data:
|
||||
base64_data = base64_data.split(',', 1)[1]
|
||||
padding = 4 - (len(base64_data) % 4) if len(base64_data) % 4 else 0
|
||||
padded_base64 = base64_data + '=' * padding
|
||||
file_bytes = base64.b64decode(padded_base64)
|
||||
except binascii.Error as e:
|
||||
raise ValueError(f'Invalid base64 string: {str(e)}')
|
||||
else:
|
||||
await self.logger.error('Voice对象出错')
|
||||
raise ValueError('voice对象出错')
|
||||
|
||||
boundary = '-------------------------acebdf13572468'
|
||||
headers = {'Content-Type': f'multipart/form-data; boundary={boundary}'}
|
||||
body = (
|
||||
(
|
||||
f'--{boundary}\r\n'
|
||||
f'Content-Disposition: form-data; name="media"; filename="{file_name}"; filelength={len(file_bytes)}\r\n'
|
||||
f'Content-Type: application/octet-stream\r\n\r\n'
|
||||
).encode('utf-8')
|
||||
+ file_bytes
|
||||
+ f'\r\n--{boundary}--\r\n'.encode('utf-8')
|
||||
)
|
||||
|
||||
# print(body)
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(url, headers=headers, content=body)
|
||||
data = response.json()
|
||||
if data['errcode'] == 40014 or data['errcode'] == 42001:
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
media_id = await self.upload_voice_to_work(voice)
|
||||
if data.get('errcode', 0) != 0:
|
||||
await self.logger.error(f'上传语音文件失败:{data}')
|
||||
raise Exception('failed to upload file')
|
||||
media_id = data.get('media_id')
|
||||
return media_id
|
||||
|
||||
async def upload_file_to_work(self, file: platform_message.File):
|
||||
"""
|
||||
上传文件到企业微信
|
||||
"""
|
||||
if not await self.check_access_token():
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
url = self.base_url + '/media/upload?access_token=' + self.access_token + '&type=file'
|
||||
file_bytes = None
|
||||
file_name = 'file.txt'
|
||||
if file.path:
|
||||
async with aiofiles.open(file.path, 'rb') as f:
|
||||
file_bytes = await f.read()
|
||||
file_name = file.path.split('/')[-1]
|
||||
elif file.url:
|
||||
file_bytes = await self.download_media_to_bytes(file.url)
|
||||
file_name = file.url.split('/')[-1]
|
||||
elif file.base64:
|
||||
try:
|
||||
base64_data = file.base64
|
||||
if ',' in base64_data:
|
||||
base64_data = base64_data.split(',', 1)[1]
|
||||
padding = 4 - (len(base64_data) % 4) if len(base64_data) % 4 else 0
|
||||
padded_base64 = base64_data + '=' * padding
|
||||
file_bytes = base64.b64decode(padded_base64)
|
||||
except binascii.Error as e:
|
||||
raise ValueError(f'Invalid base64 string: {str(e)}')
|
||||
else:
|
||||
await self.logger.error('File对象出错')
|
||||
raise ValueError('file对象出错')
|
||||
boundary = '-------------------------acebdf13572468'
|
||||
headers = {'Content-Type': f'multipart/form-data; boundary={boundary}'}
|
||||
body = (
|
||||
(
|
||||
f'--{boundary}\r\n'
|
||||
f'Content-Disposition: form-data; name="media"; filename="{file_name}"; filelength={len(file_bytes)}\r\n'
|
||||
f'Content-Type: application/octet-stream\r\n\r\n'
|
||||
).encode('utf-8')
|
||||
+ file_bytes
|
||||
+ f'\r\n--{boundary}--\r\n'.encode('utf-8')
|
||||
)
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(url, headers=headers, content=body)
|
||||
data = response.json()
|
||||
if data['errcode'] == 40014 or data['errcode'] == 42001:
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
media_id = await self.upload_file_to_work(file)
|
||||
if data.get('errcode', 0) != 0:
|
||||
await self.logger.error(f'上传文件失败:{data}')
|
||||
raise Exception('failed to upload file')
|
||||
media_id = data.get('media_id')
|
||||
return media_id
|
||||
|
||||
async def download_media_to_bytes(self, url: str) -> bytes:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
|
||||
# 进行media_id的获取
|
||||
async def get_media_id(self, image: platform_message.Image):
|
||||
media_id = await self.upload_to_work(image=image)
|
||||
async def get_media_id(self, media: platform_message.Image | platform_message.Voice | platform_message.File):
|
||||
if isinstance(media, platform_message.Image):
|
||||
media_id = await self.upload_image_to_work(image=media)
|
||||
elif isinstance(media, platform_message.Voice):
|
||||
media_id = await self.upload_voice_to_work(voice=media)
|
||||
elif isinstance(media, platform_message.File):
|
||||
media_id = await self.upload_file_to_work(file=media)
|
||||
else:
|
||||
raise ValueError('Unsupported media type')
|
||||
return media_id
|
||||
|
||||
@@ -13,7 +13,7 @@ import aiofiles
|
||||
|
||||
|
||||
class WecomCSClient:
|
||||
def __init__(self, corpid: str, secret: str, token: str, EncodingAESKey: str, logger: None):
|
||||
def __init__(self, corpid: str, secret: str, token: str, EncodingAESKey: str, logger: None, unified_mode: bool = False):
|
||||
self.corpid = corpid
|
||||
self.secret = secret
|
||||
self.access_token_for_contacts = ''
|
||||
@@ -22,10 +22,15 @@ class WecomCSClient:
|
||||
self.base_url = 'https://qyapi.weixin.qq.com/cgi-bin'
|
||||
self.access_token = ''
|
||||
self.logger = logger
|
||||
self.unified_mode = unified_mode
|
||||
self.app = Quart(__name__)
|
||||
self.app.add_url_rule(
|
||||
'/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST']
|
||||
)
|
||||
|
||||
# 只有在非统一模式下才注册独立路由
|
||||
if not self.unified_mode:
|
||||
self.app.add_url_rule(
|
||||
'/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST']
|
||||
)
|
||||
|
||||
self._message_handlers = {
|
||||
'example': [],
|
||||
}
|
||||
@@ -192,27 +197,45 @@ class WecomCSClient:
|
||||
return data
|
||||
|
||||
async def handle_callback_request(self):
|
||||
"""处理回调请求(独立端口模式,使用全局 request)。"""
|
||||
return await self._handle_callback_internal(request)
|
||||
|
||||
async def handle_unified_webhook(self, req):
|
||||
"""处理回调请求(统一 webhook 模式,显式传递 request)。
|
||||
|
||||
Args:
|
||||
req: Quart Request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
处理回调请求,包括 GET 验证和 POST 消息接收。
|
||||
return await self._handle_callback_internal(req)
|
||||
|
||||
async def _handle_callback_internal(self, req):
|
||||
"""
|
||||
处理回调请求的内部实现,包括 GET 验证和 POST 消息接收。
|
||||
|
||||
Args:
|
||||
req: Quart Request 对象
|
||||
"""
|
||||
try:
|
||||
msg_signature = request.args.get('msg_signature')
|
||||
timestamp = request.args.get('timestamp')
|
||||
nonce = request.args.get('nonce')
|
||||
msg_signature = req.args.get('msg_signature')
|
||||
timestamp = req.args.get('timestamp')
|
||||
nonce = req.args.get('nonce')
|
||||
try:
|
||||
wxcpt = WXBizMsgCrypt(self.token, self.aes, self.corpid)
|
||||
except Exception as e:
|
||||
raise Exception(f'初始化失败,错误码: {e}')
|
||||
|
||||
if request.method == 'GET':
|
||||
echostr = request.args.get('echostr')
|
||||
if req.method == 'GET':
|
||||
echostr = req.args.get('echostr')
|
||||
ret, reply_echo_str = wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
|
||||
if ret != 0:
|
||||
raise Exception(f'验证失败,错误码: {ret}')
|
||||
return reply_echo_str
|
||||
|
||||
elif request.method == 'POST':
|
||||
encrypt_msg = await request.data
|
||||
elif req.method == 'POST':
|
||||
encrypt_msg = await req.data
|
||||
ret, xml_msg = wxcpt.DecryptMsg(encrypt_msg, msg_signature, timestamp, nonce)
|
||||
if ret != 0:
|
||||
raise Exception(f'消息解密失败,错误码: {ret}')
|
||||
|
||||
@@ -18,7 +18,8 @@ class BotsRouterGroup(group.RouterGroup):
|
||||
@self.route('/<bot_uuid>', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _(bot_uuid: str) -> str:
|
||||
if quart.request.method == 'GET':
|
||||
bot = await self.ap.bot_service.get_bot(bot_uuid)
|
||||
# 返回运行时信息,包括webhook地址等
|
||||
bot = await self.ap.bot_service.get_runtime_bot_info(bot_uuid)
|
||||
if bot is None:
|
||||
return self.http_status(404, -1, 'bot not found')
|
||||
return self.success(data={'bot': bot})
|
||||
|
||||
@@ -21,6 +21,22 @@ class PluginsRouterGroup(group.RouterGroup):
|
||||
|
||||
return self.success(data={'plugins': plugins})
|
||||
|
||||
@self.route('/debug-info', methods=['GET'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _() -> str:
|
||||
"""Get plugin debug information including debug URL and key"""
|
||||
debug_info = await self.ap.plugin_connector.get_debug_info()
|
||||
|
||||
# Get debug URL from config
|
||||
plugin_config = self.ap.instance_config.data.get('plugin', {})
|
||||
debug_url = plugin_config.get('display_plugin_debug_url', 'http://localhost:5401')
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'debug_url': debug_url,
|
||||
'plugin_debug_key': debug_info.get('plugin_debug_key', ''),
|
||||
}
|
||||
)
|
||||
|
||||
@self.route(
|
||||
'/<author>/<plugin_name>/upgrade',
|
||||
methods=['POST'],
|
||||
|
||||
49
src/langbot/pkg/api/http/controller/groups/webhook_mgmt.py
Normal file
49
src/langbot/pkg/api/http/controller/groups/webhook_mgmt.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import quart
|
||||
|
||||
from .. import group
|
||||
|
||||
|
||||
@group.group_class('webhook_mgmt', '/api/v1/webhooks')
|
||||
class WebhookManagementRouterGroup(group.RouterGroup):
|
||||
async def initialize(self) -> None:
|
||||
@self.route('', methods=['GET', 'POST'])
|
||||
async def _() -> str:
|
||||
if quart.request.method == 'GET':
|
||||
webhooks = await self.ap.webhook_service.get_webhooks()
|
||||
return self.success(data={'webhooks': webhooks})
|
||||
elif quart.request.method == 'POST':
|
||||
json_data = await quart.request.json
|
||||
name = json_data.get('name', '')
|
||||
url = json_data.get('url', '')
|
||||
description = json_data.get('description', '')
|
||||
enabled = json_data.get('enabled', True)
|
||||
|
||||
if not name:
|
||||
return self.http_status(400, -1, 'Name is required')
|
||||
if not url:
|
||||
return self.http_status(400, -1, 'URL is required')
|
||||
|
||||
webhook = await self.ap.webhook_service.create_webhook(name, url, description, enabled)
|
||||
return self.success(data={'webhook': webhook})
|
||||
|
||||
@self.route('/<int:webhook_id>', methods=['GET', 'PUT', 'DELETE'])
|
||||
async def _(webhook_id: int) -> str:
|
||||
if quart.request.method == 'GET':
|
||||
webhook = await self.ap.webhook_service.get_webhook(webhook_id)
|
||||
if webhook is None:
|
||||
return self.http_status(404, -1, 'Webhook not found')
|
||||
return self.success(data={'webhook': webhook})
|
||||
|
||||
elif quart.request.method == 'PUT':
|
||||
json_data = await quart.request.json
|
||||
name = json_data.get('name')
|
||||
url = json_data.get('url')
|
||||
description = json_data.get('description')
|
||||
enabled = json_data.get('enabled')
|
||||
|
||||
await self.ap.webhook_service.update_webhook(webhook_id, name, url, description, enabled)
|
||||
return self.success()
|
||||
|
||||
elif quart.request.method == 'DELETE':
|
||||
await self.ap.webhook_service.delete_webhook(webhook_id)
|
||||
return self.success()
|
||||
@@ -1,49 +1,57 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import quart
|
||||
import traceback
|
||||
|
||||
from .. import group
|
||||
|
||||
|
||||
@group.group_class('webhooks', '/api/v1/webhooks')
|
||||
class WebhooksRouterGroup(group.RouterGroup):
|
||||
@group.group_class('webhooks', '/bots')
|
||||
class WebhookRouterGroup(group.RouterGroup):
|
||||
async def initialize(self) -> None:
|
||||
@self.route('', methods=['GET', 'POST'])
|
||||
async def _() -> str:
|
||||
if quart.request.method == 'GET':
|
||||
webhooks = await self.ap.webhook_service.get_webhooks()
|
||||
return self.success(data={'webhooks': webhooks})
|
||||
elif quart.request.method == 'POST':
|
||||
json_data = await quart.request.json
|
||||
name = json_data.get('name', '')
|
||||
url = json_data.get('url', '')
|
||||
description = json_data.get('description', '')
|
||||
enabled = json_data.get('enabled', True)
|
||||
@self.route('/<bot_uuid>', methods=['GET', 'POST'], auth_type=group.AuthType.NONE)
|
||||
async def handle_webhook(bot_uuid: str):
|
||||
"""处理 bot webhook 回调(无子路径)"""
|
||||
return await self._dispatch_webhook(bot_uuid, '')
|
||||
|
||||
if not name:
|
||||
return self.http_status(400, -1, 'Name is required')
|
||||
if not url:
|
||||
return self.http_status(400, -1, 'URL is required')
|
||||
@self.route('/<bot_uuid>/<path:path>', methods=['GET', 'POST'], auth_type=group.AuthType.NONE)
|
||||
async def handle_webhook_with_path(bot_uuid: str, path: str):
|
||||
"""处理 bot webhook 回调(带子路径)"""
|
||||
return await self._dispatch_webhook(bot_uuid, path)
|
||||
|
||||
webhook = await self.ap.webhook_service.create_webhook(name, url, description, enabled)
|
||||
return self.success(data={'webhook': webhook})
|
||||
async def _dispatch_webhook(self, bot_uuid: str, path: str):
|
||||
"""分发 webhook 请求到对应的 bot adapter
|
||||
|
||||
@self.route('/<int:webhook_id>', methods=['GET', 'PUT', 'DELETE'])
|
||||
async def _(webhook_id: int) -> str:
|
||||
if quart.request.method == 'GET':
|
||||
webhook = await self.ap.webhook_service.get_webhook(webhook_id)
|
||||
if webhook is None:
|
||||
return self.http_status(404, -1, 'Webhook not found')
|
||||
return self.success(data={'webhook': webhook})
|
||||
Args:
|
||||
bot_uuid: Bot 的 UUID
|
||||
path: 子路径(如果有的话)
|
||||
|
||||
elif quart.request.method == 'PUT':
|
||||
json_data = await quart.request.json
|
||||
name = json_data.get('name')
|
||||
url = json_data.get('url')
|
||||
description = json_data.get('description')
|
||||
enabled = json_data.get('enabled')
|
||||
Returns:
|
||||
适配器返回的响应
|
||||
"""
|
||||
try:
|
||||
|
||||
runtime_bot = await self.ap.platform_mgr.get_bot_by_uuid(bot_uuid)
|
||||
|
||||
await self.ap.webhook_service.update_webhook(webhook_id, name, url, description, enabled)
|
||||
return self.success()
|
||||
if not runtime_bot:
|
||||
return quart.jsonify({'error': 'Bot not found'}), 404
|
||||
|
||||
elif quart.request.method == 'DELETE':
|
||||
await self.ap.webhook_service.delete_webhook(webhook_id)
|
||||
return self.success()
|
||||
if not runtime_bot.enable:
|
||||
return quart.jsonify({'error': 'Bot is disabled'}), 403
|
||||
|
||||
|
||||
if not hasattr(runtime_bot.adapter, 'handle_unified_webhook'):
|
||||
return quart.jsonify({'error': 'Adapter does not support unified webhook'}), 501
|
||||
|
||||
|
||||
response = await runtime_bot.adapter.handle_unified_webhook(
|
||||
bot_uuid=bot_uuid,
|
||||
path=path,
|
||||
request=quart.request,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Webhook dispatch error for bot {bot_uuid}: {traceback.format_exc()}')
|
||||
return quart.jsonify({'error': str(e)}), 500
|
||||
|
||||
@@ -92,7 +92,11 @@ class HTTPController:
|
||||
|
||||
@self.quart_app.route('/')
|
||||
async def index():
|
||||
return await quart.send_from_directory(frontend_path, 'index.html', mimetype='text/html')
|
||||
response = await quart.send_from_directory(frontend_path, 'index.html', mimetype='text/html')
|
||||
response.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate'
|
||||
response.headers['Pragma'] = 'no-cache'
|
||||
response.headers['Expires'] = '0'
|
||||
return response
|
||||
|
||||
@self.quart_app.route('/<path:path>')
|
||||
async def static_file(path: str):
|
||||
|
||||
@@ -58,6 +58,16 @@ class BotService:
|
||||
if runtime_bot is not None:
|
||||
adapter_runtime_values['bot_account_id'] = runtime_bot.adapter.bot_account_id
|
||||
|
||||
# Webhook URL for unified webhook adapters (independent of bot running state)
|
||||
if persistence_bot['adapter'] in ['wecom', 'wecombot', 'officialaccount', 'qqofficial', 'slack', 'wecomcs', 'LINE']:
|
||||
webhook_prefix = self.ap.instance_config.data['api'].get('webhook_prefix', 'http://127.0.0.1:5300')
|
||||
webhook_url = f'/bots/{bot_uuid}'
|
||||
adapter_runtime_values['webhook_url'] = webhook_url
|
||||
adapter_runtime_values['webhook_full_url'] = f'{webhook_prefix}{webhook_url}'
|
||||
else:
|
||||
adapter_runtime_values['webhook_url'] = None
|
||||
adapter_runtime_values['webhook_full_url'] = None
|
||||
|
||||
persistence_bot['adapter_runtime_values'] = adapter_runtime_values
|
||||
|
||||
return persistence_bot
|
||||
|
||||
@@ -74,7 +74,16 @@ class KnowledgeService:
|
||||
# Only internal KBs support file storage
|
||||
if runtime_kb.get_type() != 'internal':
|
||||
raise Exception('Only internal knowledge bases support file storage')
|
||||
return await runtime_kb.store_file(file_id)
|
||||
result = await runtime_kb.store_file(file_id)
|
||||
|
||||
# Update the KB's updated_at timestamp
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_rag.KnowledgeBase)
|
||||
.values(updated_at=sqlalchemy.func.now())
|
||||
.where(persistence_rag.KnowledgeBase.uuid == kb_uuid)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def retrieve_knowledge_base(self, kb_uuid: str, query: str) -> list[dict]:
|
||||
"""检索知识库"""
|
||||
@@ -103,6 +112,13 @@ class KnowledgeService:
|
||||
raise Exception('Only internal knowledge bases support file deletion')
|
||||
await runtime_kb.delete_file(file_id)
|
||||
|
||||
# Update the KB's updated_at timestamp
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_rag.KnowledgeBase)
|
||||
.values(updated_at=sqlalchemy.func.now())
|
||||
.where(persistence_rag.KnowledgeBase.uuid == kb_uuid)
|
||||
)
|
||||
|
||||
async def delete_knowledge_base(self, kb_uuid: str) -> None:
|
||||
"""删除知识库"""
|
||||
await self.ap.rag_mgr.delete_knowledge_base(kb_uuid)
|
||||
|
||||
@@ -8,6 +8,7 @@ class KnowledgeBase(Base):
|
||||
name = sqlalchemy.Column(sqlalchemy.String, index=True)
|
||||
description = sqlalchemy.Column(sqlalchemy.Text)
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, default=sqlalchemy.func.now())
|
||||
updated_at = sqlalchemy.Column(sqlalchemy.DateTime, default=sqlalchemy.func.now(), onupdate=sqlalchemy.func.now())
|
||||
embedding_model_uuid = sqlalchemy.Column(sqlalchemy.String, default='')
|
||||
top_k = sqlalchemy.Column(sqlalchemy.Integer, default=5)
|
||||
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from .. import migration
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
from ...entity.persistence import pipeline as persistence_pipeline
|
||||
import json
|
||||
|
||||
|
||||
@migration.migration_class(2)
|
||||
@@ -11,30 +10,45 @@ class DBMigrateCombineQuoteMsgConfig(migration.DBMigration):
|
||||
|
||||
async def upgrade(self):
|
||||
"""Upgrade"""
|
||||
# read all pipelines
|
||||
pipelines = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline))
|
||||
# Read all pipelines using raw SQL
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('SELECT uuid, config FROM legacy_pipelines')
|
||||
)
|
||||
pipelines = result.fetchall()
|
||||
|
||||
for pipeline in pipelines:
|
||||
serialized_pipeline = self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)
|
||||
current_version = self.ap.ver_mgr.get_current_version()
|
||||
|
||||
config = serialized_pipeline['config']
|
||||
for pipeline_row in pipelines:
|
||||
uuid = pipeline_row[0]
|
||||
config = json.loads(pipeline_row[1]) if isinstance(pipeline_row[1], str) else pipeline_row[1]
|
||||
|
||||
# Ensure 'trigger' exists
|
||||
if 'trigger' not in config:
|
||||
config['trigger'] = {}
|
||||
|
||||
# Ensure 'misc' exists in 'trigger'
|
||||
if 'misc' not in config['trigger']:
|
||||
config['trigger']['misc'] = {}
|
||||
|
||||
# Add 'combine-quote-message' if not exists
|
||||
if 'combine-quote-message' not in config['trigger']['misc']:
|
||||
config['trigger']['misc']['combine-quote-message'] = False
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_pipeline.LegacyPipeline)
|
||||
.where(persistence_pipeline.LegacyPipeline.uuid == serialized_pipeline['uuid'])
|
||||
.values(
|
||||
{
|
||||
'config': config,
|
||||
'for_version': self.ap.ver_mgr.get_current_version(),
|
||||
}
|
||||
# Update using raw SQL with compatibility for both SQLite and PostgreSQL
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET config = :config::jsonb, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': json.dumps(config), 'for_version': current_version, 'uuid': uuid},
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET config = :config, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': json.dumps(config), 'for_version': current_version, 'uuid': uuid},
|
||||
)
|
||||
)
|
||||
|
||||
async def downgrade(self):
|
||||
"""Downgrade"""
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from .. import migration
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
from ...entity.persistence import pipeline as persistence_pipeline
|
||||
import json
|
||||
|
||||
|
||||
@migration.migration_class(3)
|
||||
@@ -11,14 +10,23 @@ class DBMigrateN8nConfig(migration.DBMigration):
|
||||
|
||||
async def upgrade(self):
|
||||
"""Upgrade"""
|
||||
# read all pipelines
|
||||
pipelines = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline))
|
||||
# Read all pipelines using raw SQL
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('SELECT uuid, config FROM legacy_pipelines')
|
||||
)
|
||||
pipelines = result.fetchall()
|
||||
|
||||
for pipeline in pipelines:
|
||||
serialized_pipeline = self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)
|
||||
current_version = self.ap.ver_mgr.get_current_version()
|
||||
|
||||
config = serialized_pipeline['config']
|
||||
for pipeline_row in pipelines:
|
||||
uuid = pipeline_row[0]
|
||||
config = json.loads(pipeline_row[1]) if isinstance(pipeline_row[1], str) else pipeline_row[1]
|
||||
|
||||
# Ensure 'ai' exists
|
||||
if 'ai' not in config:
|
||||
config['ai'] = {}
|
||||
|
||||
# Add 'n8n-service-api' if not exists
|
||||
if 'n8n-service-api' not in config['ai']:
|
||||
config['ai']['n8n-service-api'] = {
|
||||
'webhook-url': 'http://your-n8n-webhook-url',
|
||||
@@ -33,16 +41,21 @@ class DBMigrateN8nConfig(migration.DBMigration):
|
||||
'output-key': 'response',
|
||||
}
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_pipeline.LegacyPipeline)
|
||||
.where(persistence_pipeline.LegacyPipeline.uuid == serialized_pipeline['uuid'])
|
||||
.values(
|
||||
{
|
||||
'config': config,
|
||||
'for_version': self.ap.ver_mgr.get_current_version(),
|
||||
}
|
||||
# Update using raw SQL with compatibility for both SQLite and PostgreSQL
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET config = :config::jsonb, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': json.dumps(config), 'for_version': current_version, 'uuid': uuid},
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET config = :config, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': json.dumps(config), 'for_version': current_version, 'uuid': uuid},
|
||||
)
|
||||
)
|
||||
|
||||
async def downgrade(self):
|
||||
"""Downgrade"""
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from .. import migration
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
from ...entity.persistence import pipeline as persistence_pipeline
|
||||
import json
|
||||
|
||||
|
||||
@migration.migration_class(4)
|
||||
@@ -11,27 +10,43 @@ class DBMigrateRAGKBUUID(migration.DBMigration):
|
||||
|
||||
async def upgrade(self):
|
||||
"""升级"""
|
||||
# read all pipelines
|
||||
pipelines = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline))
|
||||
# Read all pipelines using raw SQL
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('SELECT uuid, config FROM legacy_pipelines')
|
||||
)
|
||||
pipelines = result.fetchall()
|
||||
|
||||
for pipeline in pipelines:
|
||||
serialized_pipeline = self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)
|
||||
current_version = self.ap.ver_mgr.get_current_version()
|
||||
|
||||
config = serialized_pipeline['config']
|
||||
for pipeline_row in pipelines:
|
||||
uuid = pipeline_row[0]
|
||||
config = json.loads(pipeline_row[1]) if isinstance(pipeline_row[1], str) else pipeline_row[1]
|
||||
|
||||
# Ensure nested structure exists
|
||||
if 'ai' not in config:
|
||||
config['ai'] = {}
|
||||
if 'local-agent' not in config['ai']:
|
||||
config['ai']['local-agent'] = {}
|
||||
|
||||
# Add 'knowledge-base' if not exists
|
||||
if 'knowledge-base' not in config['ai']['local-agent']:
|
||||
config['ai']['local-agent']['knowledge-base'] = ''
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_pipeline.LegacyPipeline)
|
||||
.where(persistence_pipeline.LegacyPipeline.uuid == serialized_pipeline['uuid'])
|
||||
.values(
|
||||
{
|
||||
'config': config,
|
||||
'for_version': self.ap.ver_mgr.get_current_version(),
|
||||
}
|
||||
# Update using raw SQL with compatibility for both SQLite and PostgreSQL
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET config = :config::jsonb, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': json.dumps(config), 'for_version': current_version, 'uuid': uuid},
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET config = :config, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': json.dumps(config), 'for_version': current_version, 'uuid': uuid},
|
||||
)
|
||||
)
|
||||
|
||||
async def downgrade(self):
|
||||
"""降级"""
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from .. import migration
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
from ...entity.persistence import pipeline as persistence_pipeline
|
||||
import json
|
||||
|
||||
|
||||
@migration.migration_class(5)
|
||||
@@ -11,27 +10,43 @@ class DBMigratePipelineRemoveCotConfig(migration.DBMigration):
|
||||
|
||||
async def upgrade(self):
|
||||
"""Upgrade"""
|
||||
# read all pipelines
|
||||
pipelines = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline))
|
||||
# Read all pipelines using raw SQL
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('SELECT uuid, config FROM legacy_pipelines')
|
||||
)
|
||||
pipelines = result.fetchall()
|
||||
|
||||
for pipeline in pipelines:
|
||||
serialized_pipeline = self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)
|
||||
current_version = self.ap.ver_mgr.get_current_version()
|
||||
|
||||
config = serialized_pipeline['config']
|
||||
for pipeline_row in pipelines:
|
||||
uuid = pipeline_row[0]
|
||||
config = json.loads(pipeline_row[1]) if isinstance(pipeline_row[1], str) else pipeline_row[1]
|
||||
|
||||
# Ensure nested structure exists
|
||||
if 'output' not in config:
|
||||
config['output'] = {}
|
||||
if 'misc' not in config['output']:
|
||||
config['output']['misc'] = {}
|
||||
|
||||
# Add 'remove-think' if not exists
|
||||
if 'remove-think' not in config['output']['misc']:
|
||||
config['output']['misc']['remove-think'] = False
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_pipeline.LegacyPipeline)
|
||||
.where(persistence_pipeline.LegacyPipeline.uuid == serialized_pipeline['uuid'])
|
||||
.values(
|
||||
{
|
||||
'config': config,
|
||||
'for_version': self.ap.ver_mgr.get_current_version(),
|
||||
}
|
||||
# Update using raw SQL with compatibility for both SQLite and PostgreSQL
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET config = :config::jsonb, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': json.dumps(config), 'for_version': current_version, 'uuid': uuid},
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET config = :config, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': json.dumps(config), 'for_version': current_version, 'uuid': uuid},
|
||||
)
|
||||
)
|
||||
|
||||
async def downgrade(self):
|
||||
"""Downgrade"""
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from .. import migration
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
from ...entity.persistence import pipeline as persistence_pipeline
|
||||
import json
|
||||
|
||||
|
||||
@migration.migration_class(6)
|
||||
@@ -11,14 +10,23 @@ class DBMigrateLangflowApiConfig(migration.DBMigration):
|
||||
|
||||
async def upgrade(self):
|
||||
"""Upgrade"""
|
||||
# read all pipelines
|
||||
pipelines = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline))
|
||||
# Read all pipelines using raw SQL
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('SELECT uuid, config FROM legacy_pipelines')
|
||||
)
|
||||
pipelines = result.fetchall()
|
||||
|
||||
for pipeline in pipelines:
|
||||
serialized_pipeline = self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)
|
||||
current_version = self.ap.ver_mgr.get_current_version()
|
||||
|
||||
config = serialized_pipeline['config']
|
||||
for pipeline_row in pipelines:
|
||||
uuid = pipeline_row[0]
|
||||
config = json.loads(pipeline_row[1]) if isinstance(pipeline_row[1], str) else pipeline_row[1]
|
||||
|
||||
# Ensure 'ai' exists
|
||||
if 'ai' not in config:
|
||||
config['ai'] = {}
|
||||
|
||||
# Add 'langflow-api' if not exists
|
||||
if 'langflow-api' not in config['ai']:
|
||||
config['ai']['langflow-api'] = {
|
||||
'base-url': 'http://localhost:7860',
|
||||
@@ -29,16 +37,21 @@ class DBMigrateLangflowApiConfig(migration.DBMigration):
|
||||
'tweaks': '{}',
|
||||
}
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_pipeline.LegacyPipeline)
|
||||
.where(persistence_pipeline.LegacyPipeline.uuid == serialized_pipeline['uuid'])
|
||||
.values(
|
||||
{
|
||||
'config': config,
|
||||
'for_version': self.ap.ver_mgr.get_current_version(),
|
||||
}
|
||||
# Update using raw SQL with compatibility for both SQLite and PostgreSQL
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET config = :config::jsonb, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': json.dumps(config), 'for_version': current_version, 'uuid': uuid},
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET config = :config, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': json.dumps(config), 'for_version': current_version, 'uuid': uuid},
|
||||
)
|
||||
)
|
||||
|
||||
async def downgrade(self):
|
||||
"""Downgrade"""
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from .. import migration
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
from ...entity.persistence import pipeline as persistence_pipeline
|
||||
import json
|
||||
|
||||
|
||||
@migration.migration_class(10)
|
||||
@@ -11,16 +10,20 @@ class DBMigratePipelineMultiKnowledgeBase(migration.DBMigration):
|
||||
|
||||
async def upgrade(self):
|
||||
"""Upgrade"""
|
||||
# read all pipelines
|
||||
pipelines = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline))
|
||||
# Read all pipelines using raw SQL
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('SELECT uuid, config FROM legacy_pipelines')
|
||||
)
|
||||
pipelines = result.fetchall()
|
||||
|
||||
for pipeline in pipelines:
|
||||
serialized_pipeline = self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)
|
||||
current_version = self.ap.ver_mgr.get_current_version()
|
||||
|
||||
config = serialized_pipeline['config']
|
||||
for pipeline_row in pipelines:
|
||||
uuid = pipeline_row[0]
|
||||
config = json.loads(pipeline_row[1]) if isinstance(pipeline_row[1], str) else pipeline_row[1]
|
||||
|
||||
# Convert knowledge-base from string to array
|
||||
if 'local-agent' in config['ai']:
|
||||
if 'ai' in config and 'local-agent' in config['ai']:
|
||||
current_kb = config['ai']['local-agent'].get('knowledge-base', '')
|
||||
|
||||
# If it's already a list, skip
|
||||
@@ -37,29 +40,38 @@ class DBMigratePipelineMultiKnowledgeBase(migration.DBMigration):
|
||||
if 'knowledge-base' in config['ai']['local-agent']:
|
||||
del config['ai']['local-agent']['knowledge-base']
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_pipeline.LegacyPipeline)
|
||||
.where(persistence_pipeline.LegacyPipeline.uuid == serialized_pipeline['uuid'])
|
||||
.values(
|
||||
{
|
||||
'config': config,
|
||||
'for_version': self.ap.ver_mgr.get_current_version(),
|
||||
}
|
||||
)
|
||||
)
|
||||
# Update using raw SQL with compatibility for both SQLite and PostgreSQL
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET config = :config::jsonb, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': json.dumps(config), 'for_version': current_version, 'uuid': uuid},
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET config = :config, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': json.dumps(config), 'for_version': current_version, 'uuid': uuid},
|
||||
)
|
||||
|
||||
async def downgrade(self):
|
||||
"""Downgrade"""
|
||||
# read all pipelines
|
||||
pipelines = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline))
|
||||
# Read all pipelines using raw SQL
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('SELECT uuid, config FROM legacy_pipelines')
|
||||
)
|
||||
pipelines = result.fetchall()
|
||||
|
||||
for pipeline in pipelines:
|
||||
serialized_pipeline = self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)
|
||||
current_version = self.ap.ver_mgr.get_current_version()
|
||||
|
||||
config = serialized_pipeline['config']
|
||||
for pipeline_row in pipelines:
|
||||
uuid = pipeline_row[0]
|
||||
config = json.loads(pipeline_row[1]) if isinstance(pipeline_row[1], str) else pipeline_row[1]
|
||||
|
||||
# Convert knowledge-bases from array back to string
|
||||
if 'local-agent' in config['ai']:
|
||||
if 'ai' in config and 'local-agent' in config['ai']:
|
||||
current_kbs = config['ai']['local-agent'].get('knowledge-bases', [])
|
||||
|
||||
# If it's already a string, skip
|
||||
@@ -76,13 +88,18 @@ class DBMigratePipelineMultiKnowledgeBase(migration.DBMigration):
|
||||
if 'knowledge-bases' in config['ai']['local-agent']:
|
||||
del config['ai']['local-agent']['knowledge-bases']
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_pipeline.LegacyPipeline)
|
||||
.where(persistence_pipeline.LegacyPipeline.uuid == serialized_pipeline['uuid'])
|
||||
.values(
|
||||
{
|
||||
'config': config,
|
||||
'for_version': self.ap.ver_mgr.get_current_version(),
|
||||
}
|
||||
)
|
||||
)
|
||||
# Update using raw SQL with compatibility for both SQLite and PostgreSQL
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET config = :config::jsonb, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': json.dumps(config), 'for_version': current_version, 'uuid': uuid},
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET config = :config, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': json.dumps(config), 'for_version': current_version, 'uuid': uuid},
|
||||
)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from .. import migration
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
from ...entity.persistence import pipeline as persistence_pipeline
|
||||
import json
|
||||
|
||||
|
||||
@migration.migration_class(11)
|
||||
@@ -11,29 +10,45 @@ class DBMigrateDifyApiConfig(migration.DBMigration):
|
||||
|
||||
async def upgrade(self):
|
||||
"""Upgrade"""
|
||||
# read all pipelines
|
||||
pipelines = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline))
|
||||
# Read all pipelines using raw SQL
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('SELECT uuid, config FROM legacy_pipelines')
|
||||
)
|
||||
pipelines = result.fetchall()
|
||||
|
||||
for pipeline in pipelines:
|
||||
serialized_pipeline = self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)
|
||||
current_version = self.ap.ver_mgr.get_current_version()
|
||||
|
||||
config = serialized_pipeline['config']
|
||||
for pipeline_row in pipelines:
|
||||
uuid = pipeline_row[0]
|
||||
config = json.loads(pipeline_row[1]) if isinstance(pipeline_row[1], str) else pipeline_row[1]
|
||||
|
||||
# Ensure nested structure exists
|
||||
if 'ai' not in config:
|
||||
config['ai'] = {}
|
||||
if 'dify-service-api' not in config['ai']:
|
||||
config['ai']['dify-service-api'] = {}
|
||||
|
||||
# Add 'base-prompt' if not exists
|
||||
if 'base-prompt' not in config['ai']['dify-service-api']:
|
||||
config['ai']['dify-service-api']['base-prompt'] = (
|
||||
'When the file content is readable, please read the content of this file. When the file is an image, describe the content of this image.',
|
||||
)
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_pipeline.LegacyPipeline)
|
||||
.where(persistence_pipeline.LegacyPipeline.uuid == serialized_pipeline['uuid'])
|
||||
.values(
|
||||
{
|
||||
'config': config,
|
||||
'for_version': self.ap.ver_mgr.get_current_version(),
|
||||
}
|
||||
# Update using raw SQL with compatibility for both SQLite and PostgreSQL
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET config = :config::jsonb, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': json.dumps(config), 'for_version': current_version, 'uuid': uuid},
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET config = :config, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': json.dumps(config), 'for_version': current_version, 'uuid': uuid},
|
||||
)
|
||||
)
|
||||
|
||||
async def downgrade(self):
|
||||
"""Downgrade"""
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from .. import migration
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
from ...entity.persistence import pipeline as persistence_pipeline
|
||||
import json
|
||||
|
||||
|
||||
@migration.migration_class(12)
|
||||
@@ -11,14 +10,25 @@ class DBMigratePipelineExtensionsEnableAll(migration.DBMigration):
|
||||
|
||||
async def upgrade(self):
|
||||
"""Upgrade"""
|
||||
# read all pipelines
|
||||
pipelines = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline))
|
||||
# Read all pipelines using raw SQL
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('SELECT uuid, extensions_preferences FROM legacy_pipelines')
|
||||
)
|
||||
pipelines = result.fetchall()
|
||||
|
||||
for pipeline in pipelines:
|
||||
serialized_pipeline = self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)
|
||||
current_version = self.ap.ver_mgr.get_current_version()
|
||||
|
||||
extensions_preferences = serialized_pipeline['extensions_preferences']
|
||||
for pipeline_row in pipelines:
|
||||
uuid = pipeline_row[0]
|
||||
extensions_preferences = (
|
||||
json.loads(pipeline_row[1]) if isinstance(pipeline_row[1], str) else pipeline_row[1]
|
||||
)
|
||||
|
||||
# Ensure extensions_preferences is a dict
|
||||
if extensions_preferences is None:
|
||||
extensions_preferences = {}
|
||||
|
||||
# Add 'enable_all_plugins' if not exists
|
||||
if 'enable_all_plugins' not in extensions_preferences:
|
||||
if 'plugins' in extensions_preferences:
|
||||
extensions_preferences['enable_all_plugins'] = False
|
||||
@@ -26,6 +36,7 @@ class DBMigratePipelineExtensionsEnableAll(migration.DBMigration):
|
||||
extensions_preferences['enable_all_plugins'] = True
|
||||
extensions_preferences['plugins'] = []
|
||||
|
||||
# Add 'enable_all_mcp_servers' if not exists
|
||||
if 'enable_all_mcp_servers' not in extensions_preferences:
|
||||
if 'mcp_servers' in extensions_preferences:
|
||||
extensions_preferences['enable_all_mcp_servers'] = False
|
||||
@@ -33,14 +44,29 @@ class DBMigratePipelineExtensionsEnableAll(migration.DBMigration):
|
||||
extensions_preferences['enable_all_mcp_servers'] = True
|
||||
extensions_preferences['mcp_servers'] = []
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_pipeline.LegacyPipeline)
|
||||
.where(persistence_pipeline.LegacyPipeline.uuid == serialized_pipeline['uuid'])
|
||||
.values(
|
||||
extensions_preferences=extensions_preferences,
|
||||
for_version=self.ap.ver_mgr.get_current_version(),
|
||||
# Update using raw SQL with compatibility for both SQLite and PostgreSQL
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET extensions_preferences = :extensions_preferences::jsonb, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{
|
||||
'extensions_preferences': json.dumps(extensions_preferences),
|
||||
'for_version': current_version,
|
||||
'uuid': uuid,
|
||||
},
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET extensions_preferences = :extensions_preferences, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{
|
||||
'extensions_preferences': json.dumps(extensions_preferences),
|
||||
'for_version': current_version,
|
||||
'uuid': uuid,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
async def downgrade(self):
|
||||
"""Downgrade"""
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
import sqlalchemy
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class(13)
|
||||
class DBMigrateKnowledgeBaseUpdatedAt(migration.DBMigration):
|
||||
"""Add updated_at field to knowledge_bases table"""
|
||||
|
||||
async def upgrade(self):
|
||||
"""Upgrade"""
|
||||
# Get all column names from the table
|
||||
columns = []
|
||||
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
"SELECT column_name FROM information_schema.columns WHERE table_name = 'knowledge_bases';"
|
||||
)
|
||||
)
|
||||
all_result = result.fetchall()
|
||||
columns = [row[0] for row in all_result]
|
||||
else:
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.text('PRAGMA table_info(knowledge_bases);'))
|
||||
all_result = result.fetchall()
|
||||
columns = [row[1] for row in all_result]
|
||||
|
||||
# Check and add updated_at column
|
||||
if 'updated_at' not in columns:
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'ALTER TABLE knowledge_bases ADD COLUMN updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP'
|
||||
)
|
||||
)
|
||||
else:
|
||||
# SQLite doesn't support DEFAULT CURRENT_TIMESTAMP in ALTER TABLE
|
||||
# Add column without default first
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('ALTER TABLE knowledge_bases ADD COLUMN updated_at DATETIME')
|
||||
)
|
||||
|
||||
# Set initial updated_at values to created_at for existing records
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('UPDATE knowledge_bases SET updated_at = created_at WHERE updated_at IS NULL')
|
||||
)
|
||||
|
||||
async def downgrade(self):
|
||||
"""Downgrade"""
|
||||
pass
|
||||
@@ -237,6 +237,7 @@ class RuntimePipeline:
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
message_event=query.message_event,
|
||||
message_chain=query.message_chain,
|
||||
)
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ from .. import stage, entities
|
||||
from langbot_plugin.api.entities.builtin.provider import message as provider_message
|
||||
import langbot_plugin.api.entities.events as events
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
|
||||
|
||||
@stage.stage_class('PreProcessor')
|
||||
@@ -75,18 +75,25 @@ class PreProcessor(stage.PipelineStage):
|
||||
self.ap.logger.debug(f'Bound MCP servers: {bound_mcp_servers}')
|
||||
self.ap.logger.debug(f'Use funcs: {query.use_funcs}')
|
||||
|
||||
# Extract sender name from message event
|
||||
sender_name = ''
|
||||
if isinstance(query.message_event, (platform_events.FriendMessage, platform_events.GroupMessage)):
|
||||
sender_name = query.message_event.sender.get_name()
|
||||
|
||||
if isinstance(query.message_event, platform_events.GroupMessage):
|
||||
sender_name = query.message_event.sender.member_name
|
||||
elif isinstance(query.message_event, platform_events.FriendMessage):
|
||||
sender_name = query.message_event.sender.nickname
|
||||
|
||||
variables = {
|
||||
'launcher_type': query.session.launcher_type.value,
|
||||
'launcher_id': query.session.launcher_id,
|
||||
'sender_id': query.sender_id,
|
||||
'session_id': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
'conversation_id': conversation.uuid,
|
||||
'msg_create_time': (
|
||||
int(query.message_event.time) if query.message_event.time else int(datetime.datetime.now().timestamp())
|
||||
),
|
||||
'sender_id': str(query.sender_id),
|
||||
'group_name': query.message_event.group.name
|
||||
if isinstance(query.message_event, platform_events.GroupMessage)
|
||||
else '',
|
||||
'sender_name': sender_name,
|
||||
}
|
||||
query.variables.update(variables)
|
||||
@@ -119,6 +126,12 @@ class PreProcessor(stage.PipelineStage):
|
||||
):
|
||||
if me.base64 is not None:
|
||||
content_list.append(provider_message.ContentElement.from_image_base64(me.base64))
|
||||
elif isinstance(me, platform_message.Voice):
|
||||
# 转成文件链接,让下游 runner 上传到目标模型
|
||||
if me.base64:
|
||||
content_list.append(provider_message.ContentElement.from_file_base64(me.base64, 'voice.silk'))
|
||||
elif me.url:
|
||||
content_list.append(provider_message.ContentElement.from_file_url(me.url, 'voice'))
|
||||
elif isinstance(me, platform_message.File):
|
||||
# if me.url is not None:
|
||||
content_list.append(provider_message.ContentElement.from_file_url(me.url, me.name))
|
||||
|
||||
@@ -40,6 +40,7 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
text_message=str(query.message_chain),
|
||||
message_event=query.message_event,
|
||||
message_chain=query.message_chain,
|
||||
query=query,
|
||||
)
|
||||
@@ -75,7 +76,7 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
runner = r(self.ap, query.pipeline_config)
|
||||
break
|
||||
else:
|
||||
raise ValueError(f'未找到请求运行器: {query.pipeline_config["ai"]["runner"]["runner"]}')
|
||||
raise ValueError(f'Request Runner not found: {query.pipeline_config["ai"]["runner"]["runner"]}')
|
||||
if is_stream:
|
||||
resp_message_id = uuid.uuid4()
|
||||
|
||||
@@ -90,7 +91,9 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
await query.adapter.create_message_card(str(resp_message_id), query.message_event)
|
||||
is_create_card = True
|
||||
query.resp_messages.append(result)
|
||||
self.ap.logger.info(f'对话({query.query_id})流式响应: {self.cut_str(result.readable_str())}')
|
||||
self.ap.logger.info(
|
||||
f'Conversation({query.query_id}) Streaming Response: {self.cut_str(result.readable_str())}'
|
||||
)
|
||||
|
||||
if result.content is not None:
|
||||
text_length += len(result.content)
|
||||
@@ -101,7 +104,9 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
async for result in runner.run(query):
|
||||
query.resp_messages.append(result)
|
||||
|
||||
self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
|
||||
self.ap.logger.info(
|
||||
f'Conversation({query.query_id}) Response: {self.cut_str(result.readable_str())}'
|
||||
)
|
||||
|
||||
if result.content is not None:
|
||||
text_length += len(result.content)
|
||||
@@ -112,7 +117,7 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
|
||||
query.session.using_conversation.messages.extend(query.resp_messages)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}')
|
||||
self.ap.logger.error(f'Conversation({query.query_id}) Request Failed: {type(e).__name__} {str(e)}')
|
||||
traceback.print_exc()
|
||||
|
||||
hide_exception_info = query.pipeline_config['output']['misc']['hide-exception']
|
||||
|
||||
@@ -66,22 +66,27 @@ class RuntimeBot:
|
||||
message_session_id=f'person_{event.sender.id}',
|
||||
)
|
||||
|
||||
# Push to webhooks
|
||||
# Push to webhooks and check if pipeline should be skipped
|
||||
skip_pipeline = False
|
||||
if hasattr(self.ap, 'webhook_pusher') and self.ap.webhook_pusher:
|
||||
asyncio.create_task(
|
||||
self.ap.webhook_pusher.push_person_message(event, self.bot_entity.uuid, adapter.__class__.__name__)
|
||||
skip_pipeline = await self.ap.webhook_pusher.push_person_message(
|
||||
event, self.bot_entity.uuid, adapter.__class__.__name__
|
||||
)
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
bot_uuid=self.bot_entity.uuid,
|
||||
launcher_type=provider_session.LauncherTypes.PERSON,
|
||||
launcher_id=event.sender.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain,
|
||||
adapter=adapter,
|
||||
pipeline_uuid=self.bot_entity.use_pipeline_uuid,
|
||||
)
|
||||
# Only add to query pool if no webhook requested to skip pipeline
|
||||
if not skip_pipeline:
|
||||
await self.ap.query_pool.add_query(
|
||||
bot_uuid=self.bot_entity.uuid,
|
||||
launcher_type=provider_session.LauncherTypes.PERSON,
|
||||
launcher_id=event.sender.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain,
|
||||
adapter=adapter,
|
||||
pipeline_uuid=self.bot_entity.use_pipeline_uuid,
|
||||
)
|
||||
else:
|
||||
await self.logger.info(f'Pipeline skipped for person message due to webhook response')
|
||||
|
||||
async def on_group_message(
|
||||
event: platform_events.GroupMessage,
|
||||
@@ -97,22 +102,27 @@ class RuntimeBot:
|
||||
message_session_id=f'group_{event.group.id}',
|
||||
)
|
||||
|
||||
# Push to webhooks
|
||||
# Push to webhooks and check if pipeline should be skipped
|
||||
skip_pipeline = False
|
||||
if hasattr(self.ap, 'webhook_pusher') and self.ap.webhook_pusher:
|
||||
asyncio.create_task(
|
||||
self.ap.webhook_pusher.push_group_message(event, self.bot_entity.uuid, adapter.__class__.__name__)
|
||||
skip_pipeline = await self.ap.webhook_pusher.push_group_message(
|
||||
event, self.bot_entity.uuid, adapter.__class__.__name__
|
||||
)
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
bot_uuid=self.bot_entity.uuid,
|
||||
launcher_type=provider_session.LauncherTypes.GROUP,
|
||||
launcher_id=event.group.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain,
|
||||
adapter=adapter,
|
||||
pipeline_uuid=self.bot_entity.use_pipeline_uuid,
|
||||
)
|
||||
# Only add to query pool if no webhook requested to skip pipeline
|
||||
if not skip_pipeline:
|
||||
await self.ap.query_pool.add_query(
|
||||
bot_uuid=self.bot_entity.uuid,
|
||||
launcher_type=provider_session.LauncherTypes.GROUP,
|
||||
launcher_id=event.group.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain,
|
||||
adapter=adapter,
|
||||
pipeline_uuid=self.bot_entity.use_pipeline_uuid,
|
||||
)
|
||||
else:
|
||||
await self.logger.info(f'Pipeline skipped for group message due to webhook response')
|
||||
|
||||
self.adapter.register_listener(platform_events.FriendMessage, on_friend_message)
|
||||
self.adapter.register_listener(platform_events.GroupMessage, on_group_message)
|
||||
@@ -245,6 +255,10 @@ class PlatformManager:
|
||||
logger,
|
||||
)
|
||||
|
||||
# 如果 adapter 支持 set_bot_uuid 方法,设置 bot_uuid(用于统一 webhook)
|
||||
if hasattr(adapter_inst, 'set_bot_uuid'):
|
||||
adapter_inst.set_bot_uuid(bot_entity.uuid)
|
||||
|
||||
runtime_bot = RuntimeBot(ap=self.ap, bot_entity=bot_entity, adapter=adapter_inst, logger=logger)
|
||||
|
||||
await runtime_bot.initialize()
|
||||
|
||||
@@ -327,9 +327,6 @@ class AiocqhttpEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title=event.sender['title'] if 'title' in event.sender else '',
|
||||
join_timestamp=0,
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0,
|
||||
),
|
||||
message_chain=yiri_chain,
|
||||
time=event.time,
|
||||
|
||||
@@ -119,9 +119,6 @@ class DingTalkEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title='',
|
||||
join_timestamp=0,
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0,
|
||||
)
|
||||
time = event.incoming_message.create_at
|
||||
return platform_events.GroupMessage(
|
||||
|
||||
@@ -8,6 +8,9 @@ import base64
|
||||
import uuid
|
||||
import os
|
||||
import datetime
|
||||
|
||||
# 使用BytesIO创建文件对象,避免路径问题
|
||||
import io
|
||||
import asyncio
|
||||
from enum import Enum
|
||||
|
||||
@@ -594,7 +597,7 @@ class DiscordMessageConverter(abstract_platform_adapter.AbstractMessageConverter
|
||||
break
|
||||
|
||||
text_string = ''
|
||||
image_files = []
|
||||
files = []
|
||||
|
||||
for ele in message_chain:
|
||||
if isinstance(ele, platform_message.Image):
|
||||
@@ -668,22 +671,67 @@ class DiscordMessageConverter(abstract_platform_adapter.AbstractMessageConverter
|
||||
continue # 跳过读取失败的文件
|
||||
|
||||
if image_bytes:
|
||||
# 使用BytesIO创建文件对象,避免路径问题
|
||||
import io
|
||||
|
||||
image_files.append(discord.File(fp=io.BytesIO(image_bytes), filename=filename))
|
||||
files.append(discord.File(fp=io.BytesIO(image_bytes), filename=filename))
|
||||
elif isinstance(ele, platform_message.Plain):
|
||||
text_string += ele.text
|
||||
elif isinstance(ele, platform_message.Voice):
|
||||
file_bytes = None
|
||||
filename = f'{uuid.uuid4()}.mp3'
|
||||
if ele.base64:
|
||||
if ele.base64.startswith('data:'):
|
||||
data_header = ele.base64.split(',')[0]
|
||||
if 'wav' in data_header:
|
||||
filename = f'{uuid.uuid4()}.wav'
|
||||
elif 'mp3' in data_header:
|
||||
filename = f'{uuid.uuid4()}.mp3'
|
||||
elif 'ogg' in data_header:
|
||||
filename = f'{uuid.uuid4()}.ogg'
|
||||
elif 'm4a' in data_header:
|
||||
filename = f'{uuid.uuid4()}.m4a'
|
||||
elif 'aac' in data_header:
|
||||
filename = f'{uuid.uuid4()}.aac'
|
||||
elif 'flac' in data_header:
|
||||
filename = f'{uuid.uuid4()}.flac'
|
||||
elif 'alac' in data_header:
|
||||
filename = f'{uuid.uuid4()}.alac'
|
||||
elif 'opus' in data_header:
|
||||
filename = f'{uuid.uuid4()}.opus'
|
||||
elif 'webm' in data_header:
|
||||
filename = f'{uuid.uuid4()}.webm'
|
||||
|
||||
file_base64 = ele.base64.split(',')[-1]
|
||||
file_bytes = base64.b64decode(file_base64)
|
||||
elif ele.url:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(ele.url) as response:
|
||||
file_bytes = await response.read()
|
||||
if file_bytes:
|
||||
files.append(discord.File(fp=io.BytesIO(file_bytes), filename=filename))
|
||||
elif isinstance(ele, platform_message.File):
|
||||
file_bytes = None
|
||||
filename = f'{uuid.uuid4()}.{ele.name.split(".")[-1]}'
|
||||
if ele.base64:
|
||||
if ele.base64.startswith('data:'):
|
||||
file_base64 = ele.base64.split(',')[1]
|
||||
file_bytes = base64.b64decode(file_base64)
|
||||
else:
|
||||
file_bytes = base64.b64decode(ele.base64)
|
||||
elif ele.url:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(ele.url) as response:
|
||||
file_bytes = await response.read()
|
||||
if file_bytes:
|
||||
files.append(discord.File(fp=io.BytesIO(file_bytes), filename=filename))
|
||||
elif isinstance(ele, platform_message.Forward):
|
||||
for node in ele.node_list:
|
||||
(
|
||||
node_text,
|
||||
node_images,
|
||||
node_files,
|
||||
) = await DiscordMessageConverter.yiri2target(node.message_chain)
|
||||
text_string += node_text
|
||||
image_files.extend(node_images)
|
||||
files.extend(node_files)
|
||||
|
||||
return text_string, image_files
|
||||
return text_string, files
|
||||
|
||||
@staticmethod
|
||||
async def target2yiri(message: discord.Message) -> platform_message.MessageChain:
|
||||
@@ -769,9 +817,6 @@ class DiscordEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title='',
|
||||
join_timestamp=0,
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0,
|
||||
),
|
||||
message_chain=message_chain,
|
||||
time=event.created_at.timestamp(),
|
||||
@@ -993,7 +1038,7 @@ class DiscordAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
await self.voice_manager.cleanup_inactive_connections()
|
||||
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
msg_to_send, image_files = await self.message_converter.yiri2target(message)
|
||||
msg_to_send, files = await self.message_converter.yiri2target(message)
|
||||
|
||||
try:
|
||||
# 获取频道对象
|
||||
@@ -1006,8 +1051,8 @@ class DiscordAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
'content': msg_to_send,
|
||||
}
|
||||
|
||||
if len(image_files) > 0:
|
||||
args['files'] = image_files
|
||||
if len(files) > 0:
|
||||
args['files'] = files
|
||||
|
||||
await channel.send(**args)
|
||||
|
||||
@@ -1021,15 +1066,16 @@ class DiscordAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
):
|
||||
msg_to_send, image_files = await self.message_converter.yiri2target(message)
|
||||
msg_to_send, files = await self.message_converter.yiri2target(message)
|
||||
|
||||
assert isinstance(message_source.source_platform_object, discord.Message)
|
||||
|
||||
args = {
|
||||
'content': msg_to_send,
|
||||
}
|
||||
|
||||
if len(image_files) > 0:
|
||||
args['files'] = image_files
|
||||
if len(files) > 0:
|
||||
args['files'] = files
|
||||
|
||||
if quote_origin:
|
||||
args['reference'] = message_source.source_platform_object
|
||||
|
||||
BIN
src/langbot/pkg/platform/sources/kook.png
Normal file
BIN
src/langbot/pkg/platform/sources/kook.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 14 KiB |
680
src/langbot/pkg/platform/sources/kook.py
Normal file
680
src/langbot/pkg/platform/sources/kook.py
Normal file
@@ -0,0 +1,680 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import asyncio
|
||||
import json
|
||||
import base64
|
||||
import zlib
|
||||
import traceback
|
||||
import time
|
||||
|
||||
import aiohttp
|
||||
import websockets
|
||||
import pydantic
|
||||
|
||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
|
||||
|
||||
|
||||
class KookMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
"""Convert between LangBot MessageChain and KOOK message format"""
|
||||
|
||||
@staticmethod
|
||||
async def yiri2target(message_chain: platform_message.MessageChain) -> tuple[str, int]:
|
||||
"""
|
||||
Convert LangBot MessageChain to KOOK message format
|
||||
|
||||
Returns:
|
||||
tuple: (content, message_type)
|
||||
- content: message content string
|
||||
- message_type: 1=text, 2=image, 4=file, 9=KMarkdown
|
||||
"""
|
||||
content_parts = []
|
||||
message_type = 1 # Default to text
|
||||
|
||||
for component in message_chain:
|
||||
if isinstance(component, platform_message.Plain):
|
||||
content_parts.append(component.text)
|
||||
elif isinstance(component, platform_message.At):
|
||||
# KOOK mention format: (met)user_id(met)
|
||||
if component.target:
|
||||
content_parts.append(f'(met){component.target}(met)')
|
||||
elif isinstance(component, platform_message.AtAll):
|
||||
# KOOK @all format: (met)all(met)
|
||||
content_parts.append('(met)all(met)')
|
||||
elif isinstance(component, platform_message.Image):
|
||||
# For images, we need to upload first via KOOK's asset API
|
||||
# For now, we'll send the image URL if available
|
||||
if component.url:
|
||||
content_parts.append(component.url)
|
||||
message_type = 2 # Image message type
|
||||
elif isinstance(component, platform_message.Forward):
|
||||
# Handle forward messages by concatenating content
|
||||
for node in component.node_list:
|
||||
forward_content, _ = await KookMessageConverter.yiri2target(node.message_chain)
|
||||
content_parts.append(forward_content)
|
||||
# Ignore Source and other components
|
||||
|
||||
content = ''.join(content_parts)
|
||||
return content, message_type
|
||||
|
||||
@staticmethod
|
||||
async def target2yiri(kook_message: dict, bot_account_id: str = '') -> platform_message.MessageChain:
|
||||
"""
|
||||
Convert KOOK message format to LangBot MessageChain
|
||||
|
||||
Args:
|
||||
kook_message: KOOK message event data dict
|
||||
bot_account_id: Bot's account ID for handling role mentions
|
||||
"""
|
||||
components = []
|
||||
|
||||
msg_type = kook_message.get('type', 1)
|
||||
content = kook_message.get('content', '')
|
||||
extra = kook_message.get('extra', {})
|
||||
|
||||
# Handle mentions
|
||||
mentions = extra.get('mention', [])
|
||||
mention_all = extra.get('mention_all', False)
|
||||
mention_roles = extra.get('mention_roles', [])
|
||||
|
||||
if mention_all:
|
||||
components.append(platform_message.AtAll())
|
||||
|
||||
for mention_id in mentions:
|
||||
components.append(platform_message.At(target=str(mention_id)))
|
||||
|
||||
# Handle role mentions (when bot is mentioned via role)
|
||||
# In KOOK, when a role that the bot has is mentioned, we receive it as a role mention
|
||||
# We need to convert this to an At with the bot's account ID for the pipeline to recognize it
|
||||
if mention_roles and bot_account_id:
|
||||
# Add an At component with the bot's account ID when any role is mentioned
|
||||
# This is because KOOK bots are often assigned roles and @role mentions should trigger responses
|
||||
components.append(platform_message.At(target=bot_account_id))
|
||||
|
||||
# Strip mention patterns from content
|
||||
# Remove user mention patterns: (met)USER_ID(met)
|
||||
for mention_id in mentions:
|
||||
content = content.replace(f'(met){mention_id}(met)', '')
|
||||
|
||||
# Remove @all pattern
|
||||
if mention_all:
|
||||
content = content.replace('(met)all(met)', '')
|
||||
|
||||
# Remove role mention patterns: (rol)ROLE_ID(rol)
|
||||
for role_id in mention_roles:
|
||||
content = content.replace(f'(rol){role_id}(rol)', '')
|
||||
|
||||
# Clean up extra whitespace
|
||||
content = content.strip()
|
||||
|
||||
# Handle different message types
|
||||
if msg_type == 1: # Text message
|
||||
if content:
|
||||
components.append(platform_message.Plain(text=content))
|
||||
elif msg_type == 2: # Image message
|
||||
# Image content is typically a URL
|
||||
if content:
|
||||
# Download image and convert to base64
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(content) as response:
|
||||
if response.status == 200:
|
||||
image_bytes = await response.read()
|
||||
image_base64 = base64.b64encode(image_bytes).decode('utf-8')
|
||||
# Detect image format
|
||||
content_type = response.headers.get('Content-Type', 'image/png')
|
||||
components.append(
|
||||
platform_message.Image(base64=f'data:{content_type};base64,{image_base64}')
|
||||
)
|
||||
except Exception:
|
||||
# If download fails, just add as plain text
|
||||
components.append(platform_message.Plain(text=f'[Image: {content}]'))
|
||||
elif msg_type == 4: # File message
|
||||
# For file messages, content is typically the file URL
|
||||
attachments = extra.get('attachments', {})
|
||||
file_name = attachments.get('name', 'file')
|
||||
components.append(platform_message.File(url=content, name=file_name))
|
||||
elif msg_type == 8: # Audio message
|
||||
# For audio messages, content is typically the audio URL
|
||||
attachments = extra.get('attachments', {})
|
||||
components.append(platform_message.Voice(url=content))
|
||||
elif msg_type == 9: # KMarkdown message
|
||||
# Note: content is already stripped of mention patterns above
|
||||
if content:
|
||||
components.append(platform_message.Plain(text=content))
|
||||
elif msg_type == 10: # Card message
|
||||
# Card messages are complex, for now just indicate it's a card
|
||||
components.append(platform_message.Plain(text='[Card Message]'))
|
||||
else:
|
||||
# Other message types, just use content as plain text
|
||||
if content:
|
||||
components.append(platform_message.Plain(text=content))
|
||||
|
||||
return platform_message.MessageChain(components)
|
||||
|
||||
|
||||
class KookEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
"""Convert between LangBot events and KOOK events"""
|
||||
|
||||
@staticmethod
|
||||
async def yiri2target(event: platform_events.MessageEvent):
|
||||
"""Convert LangBot event to KOOK event (not implemented)"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
async def target2yiri(kook_event: dict, bot_account_id: str = '') -> platform_events.MessageEvent:
|
||||
"""
|
||||
Convert KOOK event to LangBot MessageEvent
|
||||
|
||||
Args:
|
||||
kook_event: KOOK event data dict containing channel_type, type, etc.
|
||||
bot_account_id: Bot's account ID for handling role mentions
|
||||
|
||||
Returns:
|
||||
FriendMessage or GroupMessage depending on channel_type
|
||||
"""
|
||||
channel_type = kook_event.get('channel_type')
|
||||
author_id = kook_event.get('author_id')
|
||||
target_id = kook_event.get('target_id')
|
||||
msg_timestamp = kook_event.get('msg_timestamp', int(time.time() * 1000))
|
||||
extra = kook_event.get('extra', {})
|
||||
|
||||
# Convert message to MessageChain
|
||||
message_chain = await KookMessageConverter.target2yiri(kook_event, bot_account_id)
|
||||
|
||||
# Convert timestamp from milliseconds to seconds
|
||||
event_time = msg_timestamp / 1000.0
|
||||
|
||||
if channel_type == 'PERSON':
|
||||
# Direct/Private message
|
||||
author = extra.get('author', {})
|
||||
author_name = author.get('nickname', author.get('username', str(author_id)))
|
||||
|
||||
return platform_events.FriendMessage(
|
||||
sender=platform_entities.Friend(
|
||||
id=str(author_id),
|
||||
nickname=author_name,
|
||||
remark=str(author_id),
|
||||
),
|
||||
message_chain=message_chain,
|
||||
time=event_time,
|
||||
source_platform_object=kook_event,
|
||||
)
|
||||
elif channel_type == 'GROUP':
|
||||
# Guild/Server channel message
|
||||
author = extra.get('author', {})
|
||||
author_name = author.get('nickname', author.get('username', str(author_id)))
|
||||
|
||||
# guild_id = extra.get('guild_id', '')
|
||||
channel_name = extra.get('channel_name', str(target_id))
|
||||
|
||||
return platform_events.GroupMessage(
|
||||
sender=platform_entities.GroupMember(
|
||||
id=str(author_id),
|
||||
member_name=author_name,
|
||||
permission=platform_entities.Permission.Member,
|
||||
group=platform_entities.Group(
|
||||
id=str(target_id), # Channel ID
|
||||
name=channel_name,
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title='',
|
||||
),
|
||||
message_chain=message_chain,
|
||||
time=event_time,
|
||||
source_platform_object=kook_event,
|
||||
)
|
||||
else:
|
||||
# Fallback to FriendMessage for unknown channel types
|
||||
return platform_events.FriendMessage(
|
||||
sender=platform_entities.Friend(
|
||||
id=str(author_id),
|
||||
nickname=str(author_id),
|
||||
remark=str(author_id),
|
||||
),
|
||||
message_chain=message_chain,
|
||||
time=event_time,
|
||||
source_platform_object=kook_event,
|
||||
)
|
||||
|
||||
|
||||
class KookAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
"""KOOK platform adapter for LangBot"""
|
||||
|
||||
config: dict
|
||||
message_converter: KookMessageConverter = KookMessageConverter()
|
||||
event_converter: KookEventConverter = KookEventConverter()
|
||||
listeners: typing.Dict[
|
||||
typing.Type[platform_events.Event],
|
||||
typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
|
||||
] = {}
|
||||
|
||||
# WebSocket connection
|
||||
ws: typing.Optional[websockets.WebSocketClientProtocol] = pydantic.Field(exclude=True, default=None)
|
||||
ws_task: typing.Optional[asyncio.Task] = pydantic.Field(exclude=True, default=None)
|
||||
heartbeat_task: typing.Optional[asyncio.Task] = pydantic.Field(exclude=True, default=None)
|
||||
running: bool = pydantic.Field(exclude=True, default=False)
|
||||
|
||||
# Connection state
|
||||
session_id: str = pydantic.Field(exclude=True, default='')
|
||||
current_sn: int = pydantic.Field(exclude=True, default=0)
|
||||
gateway_url: str = pydantic.Field(exclude=True, default='')
|
||||
|
||||
# HTTP session
|
||||
http_session: typing.Optional[aiohttp.ClientSession] = pydantic.Field(exclude=True, default=None)
|
||||
|
||||
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
|
||||
# Debug: Track init
|
||||
with open('/tmp/kook_adapter_init.txt', 'w') as f:
|
||||
f.write(f'KOOK adapter __init__ called at {time.time()}\n')
|
||||
|
||||
# Validate required config
|
||||
if 'token' not in config:
|
||||
raise Exception('KOOK adapter requires "token" in config')
|
||||
|
||||
super().__init__(
|
||||
config=config,
|
||||
logger=logger,
|
||||
bot_account_id='', # Will be set after connection
|
||||
listeners={},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def _get_gateway_url(self) -> str:
|
||||
"""Get WebSocket gateway URL from KOOK API"""
|
||||
base_url = 'https://www.kookapp.cn/api/v3/gateway/index'
|
||||
|
||||
# Always use compression for better performance
|
||||
params = {'compress': 1}
|
||||
|
||||
headers = {
|
||||
'Authorization': f'Bot {self.config["token"]}',
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(base_url, params=params, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
if data.get('code') == 0:
|
||||
gateway_url = data['data']['url']
|
||||
return gateway_url
|
||||
else:
|
||||
raise Exception(f'Failed to get gateway URL: {data.get("message")}')
|
||||
else:
|
||||
raise Exception(f'Failed to get gateway URL: HTTP {response.status}')
|
||||
|
||||
async def _get_bot_user_info(self) -> dict:
|
||||
"""Get bot's own user information from KOOK API"""
|
||||
base_url = 'https://www.kookapp.cn/api/v3/user/me'
|
||||
|
||||
headers = {
|
||||
'Authorization': f'Bot {self.config["token"]}',
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(base_url, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
if data.get('code') == 0:
|
||||
user_info = data['data']
|
||||
return user_info
|
||||
else:
|
||||
raise Exception(f'Failed to get bot user info: {data.get("message")}')
|
||||
else:
|
||||
raise Exception(f'Failed to get bot user info: HTTP {response.status}')
|
||||
|
||||
async def _handle_hello(self, data: dict):
|
||||
"""Handle HELLO signal (signal 1)"""
|
||||
session_id = data.get('session_id', '')
|
||||
self.session_id = session_id
|
||||
await self.logger.info(f'KOOK WebSocket HELLO received, session_id: {session_id}')
|
||||
|
||||
async def _handle_event(self, data: dict, sn: int):
|
||||
"""Handle EVENT signal (signal 0)"""
|
||||
self.current_sn = max(self.current_sn, sn)
|
||||
|
||||
# Check if this is a message event
|
||||
event_type = data.get('type')
|
||||
channel_type = data.get('channel_type')
|
||||
author_id = data.get('author_id')
|
||||
|
||||
# Ignore messages from bot itself to prevent infinite loops
|
||||
if self.bot_account_id and str(author_id) == self.bot_account_id:
|
||||
return
|
||||
|
||||
# Only process text messages (type 1, 2, 4, 8, 9, 10) in GROUP or PERSON channels
|
||||
if event_type in [1, 2, 4, 8, 9, 10] and channel_type in ['GROUP', 'PERSON']:
|
||||
try:
|
||||
# Convert to LangBot event
|
||||
lb_event = await self.event_converter.target2yiri(data, self.bot_account_id)
|
||||
|
||||
# Call registered listener
|
||||
event_class = type(lb_event)
|
||||
if event_class in self.listeners:
|
||||
await self.listeners[event_class](lb_event, self)
|
||||
except Exception as e:
|
||||
await self.logger.error(f'Error handling KOOK event: {e}\n{traceback.format_exc()}')
|
||||
|
||||
async def _handle_pong(self, data: dict):
|
||||
"""Handle PONG signal (signal 3)"""
|
||||
# PONG received, connection is healthy
|
||||
pass
|
||||
|
||||
async def _heartbeat_loop(self):
|
||||
"""Send PING every 30 seconds"""
|
||||
try:
|
||||
while self.running and self.ws:
|
||||
await asyncio.sleep(30)
|
||||
|
||||
if self.ws:
|
||||
try:
|
||||
ping_msg = {
|
||||
's': 2, # PING signal
|
||||
'sn': self.current_sn,
|
||||
}
|
||||
await self.ws.send(json.dumps(ping_msg))
|
||||
except Exception:
|
||||
# Connection closed or send failed, exit loop
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
await self.logger.error(f'Heartbeat error: {e}')
|
||||
|
||||
async def _websocket_loop(self):
|
||||
"""Main WebSocket event loop"""
|
||||
retry_count = 0
|
||||
max_retries = 3
|
||||
|
||||
while self.running and retry_count < max_retries:
|
||||
try:
|
||||
# Get gateway URL if not already retrieved
|
||||
if not self.gateway_url:
|
||||
self.gateway_url = await self._get_gateway_url()
|
||||
|
||||
# Connect to WebSocket
|
||||
async with websockets.connect(self.gateway_url) as ws:
|
||||
await self.logger.info(f'Connected to KOOK WebSocket: {self.gateway_url}')
|
||||
self.ws = ws
|
||||
|
||||
# Start heartbeat
|
||||
self.heartbeat_task = asyncio.create_task(self._heartbeat_loop())
|
||||
|
||||
# Wait for HELLO within 6 seconds
|
||||
try:
|
||||
hello_msg = await asyncio.wait_for(ws.recv(), timeout=6.0)
|
||||
|
||||
# Handle compressed messages (same as main message loop)
|
||||
if isinstance(hello_msg, bytes):
|
||||
# Decompress if compressed
|
||||
try:
|
||||
hello_msg = zlib.decompress(hello_msg).decode('utf-8')
|
||||
except Exception:
|
||||
# Not compressed or decompression failed
|
||||
hello_msg = hello_msg.decode('utf-8')
|
||||
|
||||
hello_data = json.loads(hello_msg)
|
||||
|
||||
if hello_data.get('s') == 1: # HELLO signal
|
||||
await self._handle_hello(hello_data['d'])
|
||||
else:
|
||||
raise Exception(f'Expected HELLO signal, got signal {hello_data.get("s")}')
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception('Did not receive HELLO within 6 seconds')
|
||||
|
||||
# Reset retry count on successful connection
|
||||
retry_count = 0
|
||||
|
||||
# Main message loop
|
||||
async for message in ws:
|
||||
if isinstance(message, bytes):
|
||||
# Decompress if compressed
|
||||
try:
|
||||
message = zlib.decompress(message).decode('utf-8')
|
||||
except Exception:
|
||||
# Not compressed or decompression failed
|
||||
message = message.decode('utf-8')
|
||||
|
||||
try:
|
||||
msg_data = json.loads(message)
|
||||
signal = msg_data.get('s')
|
||||
|
||||
if signal == 0: # EVENT
|
||||
data = msg_data.get('d', {})
|
||||
sn = msg_data.get('sn', 0)
|
||||
await self._handle_event(data, sn)
|
||||
elif signal == 3: # PONG
|
||||
await self._handle_pong(msg_data.get('d', {}))
|
||||
elif signal == 5: # RECONNECT
|
||||
# await self.logger.info('Received RECONNECT signal')
|
||||
break # Break to reconnect
|
||||
elif signal == 6: # RESUME ACK
|
||||
# await self.logger.info('Resume successful')
|
||||
pass
|
||||
except json.JSONDecodeError:
|
||||
await self.logger.error(f'Failed to parse message: {message}')
|
||||
except Exception as e:
|
||||
await self.logger.error(f'Error processing message: {e}\n{traceback.format_exc()}')
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
await self.logger.warning('KOOK WebSocket connection closed, reconnecting...')
|
||||
retry_count += 1
|
||||
await asyncio.sleep(2**retry_count) # Exponential backoff
|
||||
except Exception as e:
|
||||
await self.logger.error(f'KOOK WebSocket error: {e}\n{traceback.format_exc()}')
|
||||
retry_count += 1
|
||||
await asyncio.sleep(2**retry_count)
|
||||
finally:
|
||||
# Stop heartbeat
|
||||
if self.heartbeat_task:
|
||||
self.heartbeat_task.cancel()
|
||||
try:
|
||||
await self.heartbeat_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self.ws = None
|
||||
|
||||
if retry_count >= max_retries:
|
||||
await self.logger.error(f'Failed to connect after {max_retries} retries')
|
||||
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
"""Send a message to a channel or user"""
|
||||
content, msg_type = await self.message_converter.yiri2target(message)
|
||||
|
||||
# Determine endpoint based on target_type
|
||||
if target_type == 'GROUP':
|
||||
# Send to channel
|
||||
url = 'https://www.kookapp.cn/api/v3/message/create'
|
||||
payload = {
|
||||
'target_id': target_id,
|
||||
'content': content,
|
||||
'type': msg_type,
|
||||
}
|
||||
else: # PERSON or default
|
||||
# Send direct message
|
||||
url = 'https://www.kookapp.cn/api/v3/direct-message/create'
|
||||
payload = {
|
||||
'target_id': target_id,
|
||||
'content': content,
|
||||
'type': msg_type,
|
||||
}
|
||||
|
||||
headers = {
|
||||
'Authorization': f'Bot {self.config["token"]}',
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
|
||||
try:
|
||||
if not self.http_session:
|
||||
self.http_session = aiohttp.ClientSession()
|
||||
|
||||
async with self.http_session.post(url, json=payload, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
if result.get('code') == 0:
|
||||
await self.logger.debug(f'Message sent successfully to {target_id}')
|
||||
else:
|
||||
await self.logger.error(f'Failed to send message: {result.get("message")}')
|
||||
else:
|
||||
await self.logger.error(f'Failed to send message: HTTP {response.status}')
|
||||
except Exception as e:
|
||||
await self.logger.error(f'Error sending message: {e}')
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
):
|
||||
"""Reply to a message"""
|
||||
content, msg_type = await self.message_converter.yiri2target(message)
|
||||
|
||||
kook_event = message_source.source_platform_object
|
||||
channel_type = kook_event.get('channel_type')
|
||||
target_id = kook_event.get('target_id')
|
||||
msg_id = kook_event.get('msg_id')
|
||||
|
||||
# Determine endpoint based on channel_type
|
||||
if channel_type == 'GROUP':
|
||||
url = 'https://www.kookapp.cn/api/v3/message/create'
|
||||
payload = {
|
||||
'target_id': target_id,
|
||||
'content': content,
|
||||
'type': msg_type,
|
||||
}
|
||||
else: # PERSON
|
||||
url = 'https://www.kookapp.cn/api/v3/direct-message/create'
|
||||
# For direct messages, we need the chat_code or target_id
|
||||
author_id = kook_event.get('author_id')
|
||||
extra = kook_event.get('extra', {})
|
||||
chat_code = extra.get('code', '')
|
||||
|
||||
payload = {
|
||||
'content': content,
|
||||
'type': msg_type,
|
||||
}
|
||||
|
||||
if chat_code:
|
||||
payload['chat_code'] = chat_code
|
||||
else:
|
||||
payload['target_id'] = str(author_id)
|
||||
|
||||
# Add quote if requested
|
||||
if quote_origin and msg_id:
|
||||
payload['quote'] = msg_id
|
||||
|
||||
payload['reply_msg_id'] = msg_id
|
||||
|
||||
headers = {
|
||||
'Authorization': f'Bot {self.config["token"]}',
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
|
||||
try:
|
||||
if not self.http_session:
|
||||
self.http_session = aiohttp.ClientSession()
|
||||
|
||||
async with self.http_session.post(url, json=payload, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
if result.get('code') == 0:
|
||||
await self.logger.debug('Reply sent successfully')
|
||||
else:
|
||||
await self.logger.error(f'Failed to send reply: {result.get("message")}')
|
||||
else:
|
||||
await self.logger.error(f'Failed to send reply: HTTP {response.status}')
|
||||
except Exception as e:
|
||||
await self.logger.error(f'Error sending reply: {e}')
|
||||
|
||||
async def is_muted(self, group_id: int) -> bool:
|
||||
"""Check if bot is muted in a group (not implemented for KOOK)"""
|
||||
return False
|
||||
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
"""Register an event listener"""
|
||||
self.listeners[event_type] = callback
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
"""Unregister an event listener"""
|
||||
self.listeners.pop(event_type, None)
|
||||
|
||||
async def run_async(self):
|
||||
"""Start the KOOK adapter"""
|
||||
# Debug: Track run_async
|
||||
with open('/tmp/kook_adapter_run.txt', 'w') as f:
|
||||
f.write(f'KOOK adapter run_async called at {time.time()}\n')
|
||||
|
||||
self.running = True
|
||||
|
||||
try:
|
||||
# Create HTTP session
|
||||
self.http_session = aiohttp.ClientSession()
|
||||
|
||||
await self.logger.info('Starting KOOK adapter')
|
||||
|
||||
# Get bot's user information and set bot_account_id
|
||||
try:
|
||||
bot_info = await self._get_bot_user_info()
|
||||
self.bot_account_id = str(bot_info.get('id', ''))
|
||||
except Exception as e:
|
||||
await self.logger.error(f'Failed to get bot user info: {e}')
|
||||
# Continue anyway, but bot will process its own messages
|
||||
|
||||
# Start WebSocket connection
|
||||
self.ws_task = asyncio.create_task(self._websocket_loop())
|
||||
|
||||
# Keep running
|
||||
await self.ws_task
|
||||
except Exception as e:
|
||||
await self.logger.error(f'KOOK adapter error: {e}\n{traceback.format_exc()}')
|
||||
finally:
|
||||
self.running = False
|
||||
|
||||
async def kill(self) -> bool:
|
||||
"""Stop the KOOK adapter"""
|
||||
self.running = False
|
||||
|
||||
# Cancel tasks
|
||||
if self.heartbeat_task:
|
||||
self.heartbeat_task.cancel()
|
||||
try:
|
||||
await self.heartbeat_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self.ws_task:
|
||||
self.ws_task.cancel()
|
||||
try:
|
||||
await self.ws_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Close WebSocket
|
||||
if self.ws:
|
||||
try:
|
||||
await self.ws.close()
|
||||
except Exception:
|
||||
pass # Already closed or error during close
|
||||
|
||||
# Close HTTP session
|
||||
if self.http_session:
|
||||
await self.http_session.close()
|
||||
|
||||
await self.logger.info('KOOK adapter stopped')
|
||||
return True
|
||||
24
src/langbot/pkg/platform/sources/kook.yaml
Normal file
24
src/langbot/pkg/platform/sources/kook.yaml
Normal file
@@ -0,0 +1,24 @@
|
||||
apiVersion: v1
|
||||
kind: MessagePlatformAdapter
|
||||
metadata:
|
||||
name: kook
|
||||
label:
|
||||
en_US: KOOK
|
||||
zh_Hans: KOOK
|
||||
description:
|
||||
en_US: KOOK Adapter (formerly KaiHeiLa)
|
||||
zh_Hans: KOOK 适配器(原开黑啦),支持频道消息和私聊消息
|
||||
icon: kook.png
|
||||
spec:
|
||||
config:
|
||||
- name: token
|
||||
label:
|
||||
en_US: Bot Token
|
||||
zh_Hans: 机器人令牌
|
||||
type: string
|
||||
required: true
|
||||
default: ""
|
||||
execution:
|
||||
python:
|
||||
path: ./kook.py
|
||||
attr: KookAdapter
|
||||
@@ -55,9 +55,7 @@ class AESCipher(object):
|
||||
|
||||
class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
@staticmethod
|
||||
async def upload_image_to_lark(
|
||||
msg: platform_message.Image, api_client: lark_oapi.Client
|
||||
) -> typing.Optional[str]:
|
||||
async def upload_image_to_lark(msg: platform_message.Image, api_client: lark_oapi.Client) -> typing.Optional[str]:
|
||||
"""Upload an image to Lark and return the image_key, or None if upload fails."""
|
||||
image_bytes = None
|
||||
|
||||
@@ -95,7 +93,9 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
return None
|
||||
|
||||
if image_bytes is None:
|
||||
print(f'No image data available for Image message (url={msg.url}, base64={bool(msg.base64)}, path={msg.path})')
|
||||
print(
|
||||
f'No image data available for Image message (url={msg.url}, base64={bool(msg.base64)}, path={msg.path})'
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -113,10 +113,7 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
request = (
|
||||
CreateImageRequest.builder()
|
||||
.request_body(
|
||||
CreateImageRequestBody.builder()
|
||||
.image_type('message')
|
||||
.image(open(temp_file_path, 'rb'))
|
||||
.build()
|
||||
CreateImageRequestBody.builder().image_type('message').image(open(temp_file_path, 'rb')).build()
|
||||
)
|
||||
.build()
|
||||
)
|
||||
@@ -143,7 +140,7 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
message_chain: platform_message.MessageChain, api_client: lark_oapi.Client
|
||||
) -> typing.Tuple[list, list]:
|
||||
"""Convert message chain to Lark format.
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (text_elements, image_keys):
|
||||
- text_elements: List of paragraphs for post message format
|
||||
@@ -159,24 +156,24 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
async def process_text_with_images(text: str) -> typing.Tuple[str, list]:
|
||||
"""Extract Markdown images from text and return cleaned text + image URLs."""
|
||||
extracted_urls = []
|
||||
|
||||
|
||||
# Find all Markdown images
|
||||
matches = list(markdown_image_pattern.finditer(text))
|
||||
if not matches:
|
||||
return text, []
|
||||
|
||||
|
||||
# Extract URLs and remove image syntax from text
|
||||
cleaned_text = text
|
||||
for match in reversed(matches): # Reverse to maintain correct positions
|
||||
url = match.group(2)
|
||||
extracted_urls.insert(0, url) # Insert at beginning since we're going in reverse
|
||||
# Replace image syntax with empty string or a placeholder
|
||||
cleaned_text = cleaned_text[:match.start()] + cleaned_text[match.end():]
|
||||
|
||||
cleaned_text = cleaned_text[: match.start()] + cleaned_text[match.end() :]
|
||||
|
||||
# Clean up multiple consecutive newlines that might result from removing images
|
||||
cleaned_text = re.sub(r'\n{3,}', '\n\n', cleaned_text)
|
||||
cleaned_text = cleaned_text.strip()
|
||||
|
||||
|
||||
return cleaned_text, extracted_urls
|
||||
|
||||
for msg in message_chain:
|
||||
@@ -189,14 +186,14 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
text = msg.text.encode('latin1').decode('utf-8')
|
||||
except UnicodeError:
|
||||
text = msg.text.encode('utf-8', errors='replace').decode('utf-8')
|
||||
|
||||
|
||||
# Check for and extract Markdown images from text
|
||||
cleaned_text, extracted_urls = await process_text_with_images(text)
|
||||
|
||||
|
||||
# Add cleaned text if not empty
|
||||
if cleaned_text:
|
||||
pending_paragraph.append({'tag': 'md', 'text': cleaned_text})
|
||||
|
||||
|
||||
# Process extracted image URLs
|
||||
for url in extracted_urls:
|
||||
# Create a temporary Image message to upload
|
||||
@@ -204,7 +201,7 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
image_key = await LarkMessageConverter.upload_image_to_lark(temp_image, api_client)
|
||||
if image_key:
|
||||
image_keys.append(image_key)
|
||||
|
||||
|
||||
elif isinstance(msg, platform_message.At):
|
||||
pending_paragraph.append({'tag': 'at', 'user_id': msg.target, 'style': []})
|
||||
elif isinstance(msg, platform_message.AtAll):
|
||||
@@ -300,6 +297,10 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
message_content['content'] = new_list
|
||||
elif message.message_type == 'image':
|
||||
message_content['content'] = [{'tag': 'img', 'image_key': message_content['image_key'], 'style': []}]
|
||||
elif message.message_type == 'file':
|
||||
message_content['content'] = [
|
||||
{'tag': 'file', 'file_key': message_content['file_key'], 'file_name': message_content['file_name']}
|
||||
]
|
||||
|
||||
for ele in message_content['content']:
|
||||
if ele['tag'] == 'text':
|
||||
@@ -330,6 +331,33 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
image_format = response.raw.headers['content-type']
|
||||
|
||||
lb_msg_list.append(platform_message.Image(base64=f'data:{image_format};base64,{image_base64}'))
|
||||
elif ele['tag'] == 'file':
|
||||
file_key = ele['file_key']
|
||||
file_name = ele['file_name']
|
||||
|
||||
request: GetMessageResourceRequest = (
|
||||
GetMessageResourceRequest.builder()
|
||||
.message_id(message.message_id)
|
||||
.file_key(file_key)
|
||||
.type('file')
|
||||
.build()
|
||||
)
|
||||
|
||||
response: GetMessageResourceResponse = await api_client.im.v1.message_resource.aget(request)
|
||||
|
||||
if not response.success():
|
||||
raise Exception(
|
||||
f'client.im.v1.message_resource.get failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}'
|
||||
)
|
||||
|
||||
file_bytes = response.file.read()
|
||||
file_base64 = base64.b64encode(file_bytes).decode()
|
||||
|
||||
file_format = response.raw.headers['content-type']
|
||||
|
||||
lb_msg_list.append(
|
||||
platform_message.File(base64=f'data:{file_format};base64,{file_base64}', name=file_name)
|
||||
)
|
||||
|
||||
return platform_message.MessageChain(lb_msg_list)
|
||||
|
||||
@@ -369,9 +397,6 @@ class LarkEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title='',
|
||||
join_timestamp=0,
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0,
|
||||
),
|
||||
message_chain=message_chain,
|
||||
time=event.event.message.create_time,
|
||||
|
||||
@@ -437,9 +437,6 @@ class GewechatEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title='',
|
||||
join_timestamp=0,
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0,
|
||||
),
|
||||
message_chain=message_chain,
|
||||
time=event['Data']['CreateTime'],
|
||||
|
||||
@@ -153,9 +153,6 @@ class NakuruProjectEventConverter(abstract_platform_adapter.AbstractEventConvert
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title=event.sender.title,
|
||||
join_timestamp=0,
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0,
|
||||
),
|
||||
message_chain=yiri_chain,
|
||||
time=event.time,
|
||||
|
||||
@@ -279,11 +279,6 @@ class OfficialEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title='',
|
||||
join_timestamp=int(
|
||||
datetime.datetime.strptime(event.member.joined_at, '%Y-%m-%dT%H:%M:%S%z').timestamp()
|
||||
),
|
||||
last_speak_timestamp=datetime.datetime.now().timestamp(),
|
||||
mute_time_remaining=0,
|
||||
),
|
||||
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id),
|
||||
time=int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp()),
|
||||
@@ -312,9 +307,6 @@ class OfficialEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title='',
|
||||
join_timestamp=int(0),
|
||||
last_speak_timestamp=datetime.datetime.now().timestamp(),
|
||||
mute_time_remaining=0,
|
||||
),
|
||||
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id),
|
||||
time=int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp()),
|
||||
|
||||
@@ -108,9 +108,6 @@ class LINEEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title='',
|
||||
join_timestamp=0,
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0,
|
||||
),
|
||||
message_chain=message_chain,
|
||||
time=event.timestamp,
|
||||
@@ -121,6 +118,7 @@ class LINEEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
class LINEAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
bot: MessagingApi
|
||||
api_client: ApiClient
|
||||
parser: WebhookParser
|
||||
|
||||
bot_account_id: str # 用于在流水线中识别at是否是本bot,直接以bot_name作为标识
|
||||
message_converter: LINEMessageConverter
|
||||
@@ -132,7 +130,7 @@ class LINEAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
]
|
||||
|
||||
config: dict
|
||||
quart_app: quart.Quart
|
||||
bot_uuid: str = None
|
||||
|
||||
card_id_dict: dict[str, str] # 消息id到卡片id的映射,便于创建卡片后的发送消息到指定卡片
|
||||
|
||||
@@ -149,7 +147,6 @@ class LINEAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
super().__init__(
|
||||
config=config,
|
||||
logger=logger,
|
||||
quart_app=quart.Quart(__name__),
|
||||
listeners={},
|
||||
card_id_dict={},
|
||||
seq=1,
|
||||
@@ -163,29 +160,6 @@ class LINEAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
bot_account_id=bot_account_id,
|
||||
)
|
||||
|
||||
@self.quart_app.route('/line/callback', methods=['POST'])
|
||||
async def line_callback():
|
||||
try:
|
||||
signature = quart.request.headers.get('X-Line-Signature')
|
||||
body = await quart.request.get_data(as_text=True)
|
||||
events = parser.parse(body, signature) # 解密解析消息
|
||||
|
||||
try:
|
||||
# print(events)
|
||||
lb_event = await self.event_converter.target2yiri(events[0], self.api_client)
|
||||
if lb_event.__class__ in self.listeners:
|
||||
await self.listeners[lb_event.__class__](lb_event, self)
|
||||
except InvalidSignatureError:
|
||||
self.logger.info(
|
||||
f'Invalid signature. Please check your channel access token/channel secret.{traceback.format_exc()}'
|
||||
)
|
||||
return quart.Response('Invalid signature', status=400)
|
||||
|
||||
return {'code': 200, 'message': 'ok'}
|
||||
except Exception:
|
||||
await self.logger.error(f'Error in LINE callback: {traceback.format_exc()}')
|
||||
return {'code': 500, 'message': 'error'}
|
||||
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
pass
|
||||
|
||||
@@ -236,18 +210,60 @@ class LINEAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
):
|
||||
self.listeners.pop(event_type)
|
||||
|
||||
async def run_async(self):
|
||||
port = self.config['port']
|
||||
def set_bot_uuid(self, bot_uuid: str):
|
||||
"""设置 bot UUID(用于生成 webhook URL)"""
|
||||
self.bot_uuid = bot_uuid
|
||||
|
||||
async def shutdown_trigger_placeholder():
|
||||
async def handle_unified_webhook(self, bot_uuid: str, path: str, request):
|
||||
"""处理统一 webhook 请求。
|
||||
|
||||
Args:
|
||||
bot_uuid: Bot 的 UUID
|
||||
path: 子路径(如果有的话)
|
||||
request: Quart Request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
try:
|
||||
signature = request.headers.get('X-Line-Signature')
|
||||
body = await request.get_data(as_text=True)
|
||||
|
||||
# Check if signature header exists
|
||||
if not signature:
|
||||
await self.logger.warning('Missing X-Line-Signature header')
|
||||
return quart.Response('Missing X-Line-Signature header', status=400)
|
||||
|
||||
try:
|
||||
events = self.parser.parse(body, signature) # 解密解析消息
|
||||
except InvalidSignatureError:
|
||||
await self.logger.info(
|
||||
f'Invalid signature. Please check your channel access token/channel secret.{traceback.format_exc()}'
|
||||
)
|
||||
return quart.Response('Invalid signature', status=400)
|
||||
|
||||
# 处理事件
|
||||
if events and len(events) > 0:
|
||||
lb_event = await self.event_converter.target2yiri(events[0], self.api_client)
|
||||
if lb_event.__class__ in self.listeners:
|
||||
await self.listeners[lb_event.__class__](lb_event, self)
|
||||
|
||||
return {'code': 200, 'message': 'ok'}
|
||||
except Exception:
|
||||
await self.logger.error(f'Error in LINE callback: {traceback.format_exc()}')
|
||||
print(traceback.format_exc())
|
||||
return {'code': 500, 'message': 'error'}
|
||||
|
||||
async def run_async(self):
|
||||
# 统一 webhook 模式下,不启动独立的 Quart 应用
|
||||
# 保持运行但不启动独立端口
|
||||
|
||||
# 打印 webhook 回调地址
|
||||
async def keep_alive():
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
await self.quart_app.run_task(
|
||||
host='0.0.0.0',
|
||||
port=port,
|
||||
shutdown_trigger=shutdown_trigger_placeholder,
|
||||
)
|
||||
await keep_alive()
|
||||
|
||||
async def kill(self) -> bool:
|
||||
pass
|
||||
|
||||
@@ -22,18 +22,6 @@ spec:
|
||||
type: string
|
||||
required: true
|
||||
default: ""
|
||||
- name: port
|
||||
label:
|
||||
en_US: Webhook Port
|
||||
zh_Hans: Webhook端口
|
||||
description:
|
||||
en_US: Only valid when webhook mode is enabled, please fill in the webhook port
|
||||
zh_Hans: 请填写 Webhook 端口
|
||||
ja_JP: Webhookポートを入力してください
|
||||
zh_Hant: 請填寫 Webhook 端口
|
||||
type: integer
|
||||
required: true
|
||||
default: 2287
|
||||
- name: channel_secret
|
||||
label:
|
||||
en_US: Channel secret
|
||||
|
||||
@@ -11,7 +11,7 @@ from langbot.libs.official_account_api.api import OAClientForLongerResponse
|
||||
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
from langbot.pkg.platform.logger import EventLogger
|
||||
from ..logger import EventLogger
|
||||
|
||||
|
||||
class OAMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
@@ -58,13 +58,16 @@ class OfficialAccountAdapter(abstract_platform_adapter.AbstractMessagePlatformAd
|
||||
message_converter: OAMessageConverter = OAMessageConverter()
|
||||
event_converter: OAEventConverter = OAEventConverter()
|
||||
bot: typing.Union[OAClient, OAClientForLongerResponse] = pydantic.Field(exclude=True)
|
||||
bot_uuid: str = None
|
||||
|
||||
def __init__(self, config: dict, logger: EventLogger):
|
||||
# 校验必填项
|
||||
required_keys = ['token', 'EncodingAESKey', 'AppSecret', 'AppID', 'Mode']
|
||||
missing_keys = [k for k in required_keys if k not in config]
|
||||
if missing_keys:
|
||||
raise Exception(f'OfficialAccount 缺少配置项: {missing_keys}')
|
||||
|
||||
# 创建运行时 bot 对象,始终使用统一 webhook 模式
|
||||
if config['Mode'] == 'drop':
|
||||
bot = OAClient(
|
||||
token=config['token'],
|
||||
@@ -72,6 +75,7 @@ class OfficialAccountAdapter(abstract_platform_adapter.AbstractMessagePlatformAd
|
||||
Appsecret=config['AppSecret'],
|
||||
AppID=config['AppID'],
|
||||
logger=logger,
|
||||
unified_mode=True,
|
||||
)
|
||||
elif config['Mode'] == 'passive':
|
||||
bot = OAClientForLongerResponse(
|
||||
@@ -81,6 +85,7 @@ class OfficialAccountAdapter(abstract_platform_adapter.AbstractMessagePlatformAd
|
||||
AppID=config['AppID'],
|
||||
LoadingMessage=config.get('LoadingMessage', ''),
|
||||
logger=logger,
|
||||
unified_mode=True,
|
||||
)
|
||||
else:
|
||||
raise KeyError('请设置微信公众号通信模式')
|
||||
@@ -129,16 +134,32 @@ class OfficialAccountAdapter(abstract_platform_adapter.AbstractMessagePlatformAd
|
||||
elif event_type == platform_events.GroupMessage:
|
||||
pass
|
||||
|
||||
def set_bot_uuid(self, bot_uuid: str):
|
||||
"""设置 bot UUID(用于生成 webhook URL)"""
|
||||
self.bot_uuid = bot_uuid
|
||||
|
||||
async def handle_unified_webhook(self, bot_uuid: str, path: str, request):
|
||||
"""处理统一 webhook 请求。
|
||||
|
||||
Args:
|
||||
bot_uuid: Bot 的 UUID
|
||||
path: 子路径(如果有的话)
|
||||
request: Quart Request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
return await self.bot.handle_unified_webhook(request)
|
||||
|
||||
async def run_async(self):
|
||||
async def shutdown_trigger_placeholder():
|
||||
# 统一 webhook 模式下,不启动独立的 Quart 应用
|
||||
# 保持运行但不启动独立端口
|
||||
|
||||
async def keep_alive():
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
await self.bot.run_task(
|
||||
host=self.config['host'],
|
||||
port=self.config['port'],
|
||||
shutdown_trigger=shutdown_trigger_placeholder,
|
||||
)
|
||||
await keep_alive()
|
||||
|
||||
async def kill(self) -> bool:
|
||||
return False
|
||||
|
||||
@@ -53,23 +53,6 @@ spec:
|
||||
type: string
|
||||
required: true
|
||||
default: "AI正在思考中,请发送任意内容获取回复。"
|
||||
- name: host
|
||||
label:
|
||||
en_US: Host
|
||||
zh_Hans: 监听主机
|
||||
description:
|
||||
en_US: The host that Official Account listens on for Webhook connections.
|
||||
zh_Hans: 微信公众号监听的主机,除非你知道自己在做什么,否则请写 0.0.0.0
|
||||
type: string
|
||||
required: true
|
||||
default: 0.0.0.0
|
||||
- name: port
|
||||
label:
|
||||
en_US: Port
|
||||
zh_Hans: 监听端口
|
||||
type: integer
|
||||
required: true
|
||||
default: 2287
|
||||
execution:
|
||||
python:
|
||||
path: ./officialaccount.py
|
||||
|
||||
@@ -11,8 +11,8 @@ import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||
from langbot.libs.qq_official_api.api import QQOfficialClient
|
||||
from langbot.libs.qq_official_api.qqofficialevent import QQOfficialEvent
|
||||
from langbot.pkg.utils import image
|
||||
from langbot.pkg.platform.logger import EventLogger
|
||||
from ...utils import image
|
||||
from ..logger import EventLogger
|
||||
|
||||
|
||||
class QQOfficialMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
@@ -94,9 +94,6 @@ class QQOfficialEventConverter(abstract_platform_adapter.AbstractEventConverter)
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title='',
|
||||
join_timestamp=0,
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0,
|
||||
)
|
||||
time = int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp())
|
||||
return platform_events.GroupMessage(
|
||||
@@ -117,9 +114,6 @@ class QQOfficialEventConverter(abstract_platform_adapter.AbstractEventConverter)
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title='',
|
||||
join_timestamp=0,
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0,
|
||||
)
|
||||
time = int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp())
|
||||
return platform_events.GroupMessage(
|
||||
@@ -134,11 +128,14 @@ class QQOfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter
|
||||
bot: QQOfficialClient
|
||||
config: dict
|
||||
bot_account_id: str
|
||||
bot_uuid: str = None
|
||||
message_converter: QQOfficialMessageConverter = QQOfficialMessageConverter()
|
||||
event_converter: QQOfficialEventConverter = QQOfficialEventConverter()
|
||||
|
||||
def __init__(self, config: dict, logger: EventLogger):
|
||||
bot = QQOfficialClient(app_id=config['appid'], secret=config['secret'], token=config['token'], logger=logger)
|
||||
bot = QQOfficialClient(
|
||||
app_id=config['appid'], secret=config['secret'], token=config['token'], logger=logger, unified_mode=True
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
config=config,
|
||||
@@ -223,16 +220,32 @@ class QQOfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter
|
||||
self.bot.on_message('GROUP_AT_MESSAGE_CREATE')(on_message)
|
||||
self.bot.on_message('AT_MESSAGE_CREATE')(on_message)
|
||||
|
||||
def set_bot_uuid(self, bot_uuid: str):
|
||||
"""设置 bot UUID(用于生成 webhook URL)"""
|
||||
self.bot_uuid = bot_uuid
|
||||
|
||||
async def handle_unified_webhook(self, bot_uuid: str, path: str, request):
|
||||
"""处理统一 webhook 请求。
|
||||
|
||||
Args:
|
||||
bot_uuid: Bot 的 UUID
|
||||
path: 子路径(如果有的话)
|
||||
request: Quart Request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
return await self.bot.handle_unified_webhook(request)
|
||||
|
||||
async def run_async(self):
|
||||
async def shutdown_trigger_placeholder():
|
||||
# 统一 webhook 模式下,不启动独立的 Quart 应用
|
||||
# 保持运行但不启动独立端口
|
||||
|
||||
async def keep_alive():
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
await self.bot.run_task(
|
||||
host='0.0.0.0',
|
||||
port=self.config['port'],
|
||||
shutdown_trigger=shutdown_trigger_placeholder,
|
||||
)
|
||||
await keep_alive()
|
||||
|
||||
async def kill(self) -> bool:
|
||||
return False
|
||||
|
||||
@@ -25,13 +25,6 @@ spec:
|
||||
type: string
|
||||
required: true
|
||||
default: ""
|
||||
- name: port
|
||||
label:
|
||||
en_US: Port
|
||||
zh_Hans: 监听端口
|
||||
type: integer
|
||||
required: true
|
||||
default: 2284
|
||||
- name: token
|
||||
label:
|
||||
en_US: Token
|
||||
|
||||
@@ -76,9 +76,6 @@ class SlackEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
id=event.channel_id, name='MEMBER', permission=platform_entities.Permission.Member
|
||||
),
|
||||
special_title='',
|
||||
join_timestamp=0,
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0,
|
||||
)
|
||||
time = int(datetime.datetime.utcnow().timestamp())
|
||||
return platform_events.GroupMessage(
|
||||
@@ -97,13 +94,12 @@ class SlackEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
class SlackAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
bot: SlackClient
|
||||
bot_account_id: str
|
||||
bot_uuid: str = None
|
||||
message_converter: SlackMessageConverter = SlackMessageConverter()
|
||||
event_converter: SlackEventConverter = SlackEventConverter()
|
||||
config: dict
|
||||
|
||||
def __init__(self, config: dict, logger: EventLogger):
|
||||
self.config = config
|
||||
self.logger = logger
|
||||
required_keys = [
|
||||
'bot_token',
|
||||
'signing_secret',
|
||||
@@ -112,8 +108,15 @@ class SlackAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
if missing_keys:
|
||||
raise command_errors.ParamNotEnoughError('Slack机器人缺少相关配置项,请查看文档或联系管理员')
|
||||
|
||||
self.bot = SlackClient(
|
||||
bot_token=self.config['bot_token'], signing_secret=self.config['signing_secret'], logger=self.logger
|
||||
bot = SlackClient(
|
||||
bot_token=config['bot_token'], signing_secret=config['signing_secret'], logger=logger, unified_mode=True
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
config=config,
|
||||
logger=logger,
|
||||
bot=bot,
|
||||
bot_account_id=config['bot_token'],
|
||||
)
|
||||
|
||||
async def reply_message(
|
||||
@@ -165,16 +168,31 @@ class SlackAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
elif event_type == platform_events.GroupMessage:
|
||||
self.bot.on_message('channel')(on_message)
|
||||
|
||||
def set_bot_uuid(self, bot_uuid: str):
|
||||
"""设置 bot UUID(用于生成 webhook URL)"""
|
||||
self.bot_uuid = bot_uuid
|
||||
|
||||
async def handle_unified_webhook(self, bot_uuid: str, path: str, request):
|
||||
"""处理统一 webhook 请求。
|
||||
|
||||
Args:
|
||||
bot_uuid: Bot 的 UUID
|
||||
path: 子路径(如果有的话)
|
||||
request: Quart Request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
return await self.bot.handle_unified_webhook(request)
|
||||
|
||||
async def run_async(self):
|
||||
async def shutdown_trigger_placeholder():
|
||||
# 统一 webhook 模式下,不启动独立的 Quart 应用
|
||||
# 保持运行但不启动独立端口
|
||||
async def keep_alive():
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
await self.bot.run_task(
|
||||
host='0.0.0.0',
|
||||
port=self.config['port'],
|
||||
shutdown_trigger=shutdown_trigger_placeholder,
|
||||
)
|
||||
await keep_alive()
|
||||
|
||||
async def kill(self) -> bool:
|
||||
return False
|
||||
|
||||
@@ -25,13 +25,6 @@ spec:
|
||||
type: string
|
||||
required: true
|
||||
default: ""
|
||||
- name: port
|
||||
label:
|
||||
en_US: Port
|
||||
zh_Hans: 监听端口
|
||||
type: int
|
||||
required: true
|
||||
default: 2288
|
||||
execution:
|
||||
python:
|
||||
path: ./slack.py
|
||||
|
||||
@@ -120,9 +120,6 @@ class TelegramEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title='',
|
||||
join_timestamp=0,
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0,
|
||||
),
|
||||
message_chain=lb_message,
|
||||
time=event.message.date.timestamp(),
|
||||
|
||||
@@ -97,8 +97,6 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
# 推送到所有相关连接
|
||||
await self.outbound_message_queue.put(message_data)
|
||||
|
||||
await self.logger.info(f'Send message to {target_id}: {message}')
|
||||
|
||||
return message_data
|
||||
|
||||
async def reply_message(
|
||||
@@ -117,7 +115,7 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
|
||||
# 从message_source获取pipeline_uuid和connection_id
|
||||
pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid
|
||||
# session_type = 'group' if isinstance(message_source, platform_events.GroupMessage) else 'person'
|
||||
session_type = 'group' if isinstance(message_source, platform_events.GroupMessage) else 'person'
|
||||
|
||||
# 生成新的消息ID
|
||||
msg_id = len(session.get_message_list(pipeline_uuid)) + 1
|
||||
@@ -134,13 +132,15 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
# 保存到历史记录
|
||||
session.get_message_list(pipeline_uuid).append(message_data)
|
||||
|
||||
# 直接广播到所有该pipeline的连接
|
||||
# 直接广播到所有该pipeline的连接,包含session_type信息
|
||||
await ws_connection_manager.broadcast_to_pipeline(
|
||||
pipeline_uuid,
|
||||
{
|
||||
'type': 'response',
|
||||
'session_type': session_type,
|
||||
'data': message_data.model_dump(),
|
||||
},
|
||||
session_type=session_type,
|
||||
)
|
||||
|
||||
return message_data.model_dump()
|
||||
@@ -162,6 +162,7 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
)
|
||||
|
||||
pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid
|
||||
session_type = 'group' if isinstance(message_source, platform_events.GroupMessage) else 'person'
|
||||
message_list = session.get_message_list(pipeline_uuid)
|
||||
|
||||
# 检查是否是新的流式消息(通过bot_message对象判断)
|
||||
@@ -197,13 +198,15 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
if is_final and bot_message.tool_calls is None:
|
||||
message_list[-1] = message_data
|
||||
|
||||
# 直接广播到所有该pipeline的连接
|
||||
# 直接广播到所有该pipeline的连接,包含session_type信息
|
||||
await ws_connection_manager.broadcast_to_pipeline(
|
||||
pipeline_uuid,
|
||||
{
|
||||
'type': 'response',
|
||||
'session_type': session_type,
|
||||
'data': message_data.model_dump(),
|
||||
},
|
||||
session_type=session_type,
|
||||
)
|
||||
|
||||
return message_data.model_dump()
|
||||
@@ -237,7 +240,6 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
|
||||
async def run_async(self):
|
||||
"""运行适配器"""
|
||||
await self.logger.info('WebSocket适配器已启动')
|
||||
|
||||
try:
|
||||
while True:
|
||||
@@ -253,12 +255,11 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
except asyncio.CancelledError:
|
||||
await self.logger.info('WebSocket适配器已停止')
|
||||
raise
|
||||
|
||||
async def kill(self):
|
||||
"""停止适配器"""
|
||||
await self.logger.info('WebSocket适配器正在停止')
|
||||
pass
|
||||
|
||||
async def _process_image_components(self, message_chain_obj: list):
|
||||
"""
|
||||
@@ -344,13 +345,15 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
)
|
||||
use_session.get_message_list(pipeline_uuid).append(user_message)
|
||||
|
||||
# 广播用户消息到所有连接(包括发送者)
|
||||
# 广播用户消息到所有连接(包括发送者),包含session_type信息
|
||||
await ws_connection_manager.broadcast_to_pipeline(
|
||||
pipeline_uuid,
|
||||
{
|
||||
'type': 'user_message',
|
||||
'session_type': session_type,
|
||||
'data': user_message.model_dump(),
|
||||
},
|
||||
session_type=session_type,
|
||||
)
|
||||
|
||||
# 添加消息源
|
||||
|
||||
@@ -134,9 +134,20 @@ class WebSocketConnectionManager:
|
||||
connection_ids = self.session_connections.get(session_type, set())
|
||||
return [self.connections[cid] for cid in connection_ids if cid in self.connections]
|
||||
|
||||
async def broadcast_to_pipeline(self, pipeline_uuid: str, message: dict):
|
||||
"""向指定流水线的所有连接广播消息"""
|
||||
async def broadcast_to_pipeline(self, pipeline_uuid: str, message: dict, session_type: str = None):
|
||||
"""向指定流水线的所有连接广播消息
|
||||
|
||||
Args:
|
||||
pipeline_uuid: 流水线UUID
|
||||
message: 要广播的消息
|
||||
session_type: 可选的会话类型过滤器,如果提供则只向匹配的session_type连接广播
|
||||
"""
|
||||
connections = await self.get_connections_by_pipeline(pipeline_uuid)
|
||||
|
||||
# 如果指定了session_type,只向匹配的连接广播
|
||||
if session_type is not None:
|
||||
connections = [conn for conn in connections if conn.session_type == session_type]
|
||||
|
||||
tasks = []
|
||||
for conn in connections:
|
||||
tasks.append(self.send_to_connection(conn.connection_id, message))
|
||||
|
||||
@@ -501,9 +501,6 @@ class WeChatPadEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title='',
|
||||
join_timestamp=0,
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0,
|
||||
),
|
||||
message_chain=message_chain,
|
||||
time=event['create_time'],
|
||||
|
||||
@@ -8,8 +8,8 @@ import datetime
|
||||
from langbot.libs.wecom_api.api import WecomClient
|
||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||
from langbot.libs.wecom_api.wecomevent import WecomEvent
|
||||
from langbot.pkg.utils import image
|
||||
from langbot.pkg.platform.logger import EventLogger
|
||||
from ...utils import image
|
||||
from ..logger import EventLogger
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||
@@ -35,6 +35,20 @@ class WecomMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
'media_id': await bot.get_media_id(msg),
|
||||
}
|
||||
)
|
||||
elif type(msg) is platform_message.Voice:
|
||||
content_list.append(
|
||||
{
|
||||
'type': 'voice',
|
||||
'media_id': await bot.get_media_id(msg),
|
||||
}
|
||||
)
|
||||
elif type(msg) is platform_message.File:
|
||||
content_list.append(
|
||||
{
|
||||
'type': 'file',
|
||||
'media_id': await bot.get_media_id(msg),
|
||||
}
|
||||
)
|
||||
elif type(msg) is platform_message.Forward:
|
||||
for node in msg.node_list:
|
||||
content_list.extend((await WecomMessageConverter.yiri2target(node.message_chain, bot)))
|
||||
@@ -131,6 +145,7 @@ class WecomAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
message_converter: WecomMessageConverter = WecomMessageConverter()
|
||||
event_converter: WecomEventConverter = WecomEventConverter()
|
||||
config: dict
|
||||
bot_uuid: str = None
|
||||
|
||||
def __init__(self, config: dict, logger: EventLogger):
|
||||
# 校验必填项
|
||||
@@ -141,11 +156,12 @@ class WecomAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
'EncodingAESKey',
|
||||
'contacts_secret',
|
||||
]
|
||||
|
||||
missing_keys = [key for key in required_keys if key not in config]
|
||||
if missing_keys:
|
||||
raise Exception(f'Wecom 缺少配置项: {missing_keys}')
|
||||
|
||||
# 创建运行时 bot 对象
|
||||
# 创建运行时 bot 对象,始终使用统一 webhook 模式
|
||||
bot = WecomClient(
|
||||
corpid=config['corpid'],
|
||||
secret=config['secret'],
|
||||
@@ -153,6 +169,7 @@ class WecomAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
EncodingAESKey=config['EncodingAESKey'],
|
||||
contacts_secret=config['contacts_secret'],
|
||||
logger=logger,
|
||||
unified_mode=True,
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
@@ -162,6 +179,10 @@ class WecomAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
bot_account_id='',
|
||||
)
|
||||
|
||||
def set_bot_uuid(self, bot_uuid: str):
|
||||
"""设置 bot UUID(用于生成 webhook URL)"""
|
||||
self.bot_uuid = bot_uuid
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
@@ -178,11 +199,12 @@ class WecomAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
await self.bot.send_private_msg(fixed_user_id, Wecom_event.agent_id, content['content'])
|
||||
elif content['type'] == 'image':
|
||||
await self.bot.send_image(fixed_user_id, Wecom_event.agent_id, content['media_id'])
|
||||
elif content['type'] == 'voice':
|
||||
await self.bot.send_voice(fixed_user_id, Wecom_event.agent_id, content['media_id'])
|
||||
elif content['type'] == 'file':
|
||||
await self.bot.send_file(fixed_user_id, Wecom_event.agent_id, content['media_id'])
|
||||
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
"""企业微信目前只有发送给个人的方法,
|
||||
构造target_id的方式为前半部分为账户id,后半部分为agent_id,中间使用“|”符号隔开。
|
||||
"""
|
||||
content_list = await WecomMessageConverter.yiri2target(message, self.bot)
|
||||
parts = target_id.split('|')
|
||||
user_id = parts[0]
|
||||
@@ -193,6 +215,10 @@ class WecomAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
await self.bot.send_private_msg(user_id, agent_id, content['content'])
|
||||
if content['type'] == 'image':
|
||||
await self.bot.send_image(user_id, agent_id, content['media'])
|
||||
if content['type'] == 'voice':
|
||||
await self.bot.send_voice(user_id, agent_id, content['media'])
|
||||
if content['type'] == 'file':
|
||||
await self.bot.send_file(user_id, agent_id, content['media'])
|
||||
|
||||
def register_listener(
|
||||
self,
|
||||
@@ -214,16 +240,25 @@ class WecomAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
elif event_type == platform_events.GroupMessage:
|
||||
pass
|
||||
|
||||
async def handle_unified_webhook(self, bot_uuid: str, path: str, request):
|
||||
"""处理统一 webhook 请求。
|
||||
|
||||
Args:
|
||||
bot_uuid: Bot 的 UUID
|
||||
path: 子路径(如果有的话)
|
||||
request: Quart Request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
return await self.bot.handle_unified_webhook(request)
|
||||
|
||||
async def run_async(self):
|
||||
async def shutdown_trigger_placeholder():
|
||||
async def keep_alive():
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
await self.bot.run_task(
|
||||
host=self.config['host'],
|
||||
port=self.config['port'],
|
||||
shutdown_trigger=shutdown_trigger_placeholder,
|
||||
)
|
||||
await keep_alive()
|
||||
|
||||
async def kill(self) -> bool:
|
||||
return False
|
||||
|
||||
@@ -11,23 +11,6 @@ metadata:
|
||||
icon: wecom.png
|
||||
spec:
|
||||
config:
|
||||
- name: host
|
||||
label:
|
||||
en_US: Host
|
||||
zh_Hans: 监听主机
|
||||
description:
|
||||
en_US: Webhook host, unless you know what you're doing, please write 0.0.0.0
|
||||
zh_Hans: Webhook 监听主机,除非你知道自己在做什么,否则请写 0.0.0.0
|
||||
type: string
|
||||
required: true
|
||||
default: "0.0.0.0"
|
||||
- name: port
|
||||
label:
|
||||
en_US: Port
|
||||
zh_Hans: 监听端口
|
||||
type: integer
|
||||
required: true
|
||||
default: 2290
|
||||
- name: corpid
|
||||
label:
|
||||
en_US: Corpid
|
||||
@@ -38,21 +21,21 @@ spec:
|
||||
- name: secret
|
||||
label:
|
||||
en_US: Secret
|
||||
zh_Hans: 密钥
|
||||
zh_Hans: 密钥 (Secret)
|
||||
type: string
|
||||
required: true
|
||||
default: ""
|
||||
- name: token
|
||||
label:
|
||||
en_US: Token
|
||||
zh_Hans: 令牌
|
||||
zh_Hans: 令牌 (Token)
|
||||
type: string
|
||||
required: true
|
||||
default: ""
|
||||
- name: EncodingAESKey
|
||||
label:
|
||||
en_US: EncodingAESKey
|
||||
zh_Hans: 消息加解密密钥
|
||||
zh_Hans: 消息加解密密钥 (EncodingAESKey)
|
||||
type: string
|
||||
required: true
|
||||
default: ""
|
||||
|
||||
@@ -8,7 +8,7 @@ import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platf
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||
from langbot.pkg.platform.logger import EventLogger
|
||||
from ..logger import EventLogger
|
||||
from langbot.libs.wecom_ai_bot_api.wecombotevent import WecomBotEvent
|
||||
from langbot.libs.wecom_ai_bot_api.api import WecomBotClient
|
||||
|
||||
@@ -28,9 +28,105 @@ class WecomBotMessageConverter(abstract_platform_adapter.AbstractMessageConverte
|
||||
if event.type == 'group':
|
||||
yiri_msg_list.append(platform_message.At(target=event.ai_bot_id))
|
||||
yiri_msg_list.append(platform_message.Source(id=event.message_id, time=datetime.datetime.now()))
|
||||
yiri_msg_list.append(platform_message.Plain(text=event.content))
|
||||
if event.picurl != '':
|
||||
yiri_msg_list.append(platform_message.Image(base64=event.picurl))
|
||||
|
||||
if event.content:
|
||||
yiri_msg_list.append(platform_message.Plain(text=event.content))
|
||||
|
||||
images = []
|
||||
if event.images:
|
||||
images.extend([img for img in event.images if img])
|
||||
if not images and event.picurl:
|
||||
images.append(event.picurl)
|
||||
for image_base64 in images:
|
||||
if image_base64:
|
||||
yiri_msg_list.append(platform_message.Image(base64=image_base64))
|
||||
|
||||
file_info = event.file or {}
|
||||
if file_info:
|
||||
file_url = (
|
||||
file_info.get('download_url')
|
||||
or file_info.get('url')
|
||||
or file_info.get('fileurl')
|
||||
or file_info.get('path')
|
||||
)
|
||||
file_base64 = file_info.get('base64')
|
||||
file_name = file_info.get('filename') or file_info.get('name')
|
||||
file_size = file_info.get('filesize') or file_info.get('size')
|
||||
file_data = file_url or file_base64
|
||||
if file_data or file_name:
|
||||
file_kwargs = {}
|
||||
if file_data:
|
||||
file_kwargs['url'] = file_data
|
||||
if file_name:
|
||||
file_kwargs['name'] = file_name
|
||||
if file_size is not None:
|
||||
file_kwargs['size'] = file_size
|
||||
try:
|
||||
yiri_msg_list.append(platform_message.File(**file_kwargs))
|
||||
except Exception:
|
||||
# 兜底
|
||||
yiri_msg_list.append(platform_message.Unknown(text='[file message unsupported]'))
|
||||
|
||||
voice_info = event.voice or {}
|
||||
if voice_info:
|
||||
voice_payload = voice_info.get('base64') or voice_info.get('url')
|
||||
if voice_payload:
|
||||
if voice_info.get('base64') and not voice_payload.startswith('data:'):
|
||||
voice_payload = f'data:audio/mpeg;base64,{voice_info.get("base64")}'
|
||||
try:
|
||||
yiri_msg_list.append(platform_message.Voice(base64=voice_payload))
|
||||
except Exception:
|
||||
try:
|
||||
voice_kwargs = {'url': voice_payload}
|
||||
voice_name = voice_info.get('filename') or voice_info.get('name')
|
||||
voice_size = voice_info.get('filesize') or voice_info.get('size')
|
||||
if voice_name:
|
||||
voice_kwargs['name'] = voice_name
|
||||
if voice_size is not None:
|
||||
voice_kwargs['size'] = voice_size
|
||||
yiri_msg_list.append(platform_message.File(**voice_kwargs))
|
||||
except Exception:
|
||||
yiri_msg_list.append(platform_message.Unknown(text='[voice message unsupported]'))
|
||||
|
||||
video_info = event.video or {}
|
||||
if video_info:
|
||||
video_payload = (
|
||||
video_info.get('base64')
|
||||
or video_info.get('url')
|
||||
or video_info.get('download_url')
|
||||
or video_info.get('fileurl')
|
||||
)
|
||||
if video_payload:
|
||||
video_kwargs = {'url': video_payload}
|
||||
video_name = video_info.get('filename') or video_info.get('name')
|
||||
video_size = video_info.get('filesize') or video_info.get('size')
|
||||
if video_name:
|
||||
video_kwargs['name'] = video_name
|
||||
if video_size is not None:
|
||||
video_kwargs['size'] = video_size
|
||||
try:
|
||||
# 没有专门的视频类型,沿用 File 传递给上层
|
||||
yiri_msg_list.append(platform_message.File(**video_kwargs))
|
||||
except Exception:
|
||||
yiri_msg_list.append(platform_message.Unknown(text='[video message unsupported]'))
|
||||
|
||||
if event.msgtype == 'link' and event.link:
|
||||
link = event.link
|
||||
summary = '\n'.join(
|
||||
filter(
|
||||
None,
|
||||
[link.get('title', ''), link.get('description') or link.get('digest', ''), link.get('url', '')],
|
||||
)
|
||||
)
|
||||
if summary:
|
||||
yiri_msg_list.append(platform_message.Plain(text=summary))
|
||||
|
||||
has_content_element = any(
|
||||
not isinstance(element, (platform_message.Source, platform_message.At)) for element in yiri_msg_list
|
||||
)
|
||||
if not has_content_element:
|
||||
fallback_type = event.msgtype or 'unknown'
|
||||
yiri_msg_list.append(platform_message.Unknown(text=f'[unsupported wecom msgtype: {fallback_type}]'))
|
||||
chain = platform_message.MessageChain(yiri_msg_list)
|
||||
|
||||
return chain
|
||||
@@ -67,9 +163,6 @@ class WecomBotEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title='',
|
||||
join_timestamp=0,
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0,
|
||||
)
|
||||
time = datetime.datetime.now().timestamp()
|
||||
return platform_events.GroupMessage(
|
||||
@@ -88,19 +181,20 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
message_converter: WecomBotMessageConverter = WecomBotMessageConverter()
|
||||
event_converter: WecomBotEventConverter = WecomBotEventConverter()
|
||||
config: dict
|
||||
bot_uuid: str = None
|
||||
|
||||
def __init__(self, config: dict, logger: EventLogger):
|
||||
required_keys = ['Token', 'EncodingAESKey', 'Corpid', 'BotId', 'port']
|
||||
required_keys = ['Token', 'EncodingAESKey', 'Corpid', 'BotId']
|
||||
missing_keys = [key for key in required_keys if key not in config]
|
||||
if missing_keys:
|
||||
raise Exception(f'WecomBot 缺少配置项: {missing_keys}')
|
||||
|
||||
# 创建运行时 bot 对象
|
||||
bot = WecomBotClient(
|
||||
Token=config['Token'],
|
||||
EnCodingAESKey=config['EncodingAESKey'],
|
||||
Corpid=config['Corpid'],
|
||||
logger=logger,
|
||||
unified_mode=True,
|
||||
)
|
||||
bot_account_id = config['BotId']
|
||||
|
||||
@@ -189,16 +283,32 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
except Exception:
|
||||
print(traceback.format_exc())
|
||||
|
||||
def set_bot_uuid(self, bot_uuid: str):
|
||||
"""设置 bot UUID(用于生成 webhook URL)"""
|
||||
self.bot_uuid = bot_uuid
|
||||
|
||||
async def handle_unified_webhook(self, bot_uuid: str, path: str, request):
|
||||
"""处理统一 webhook 请求。
|
||||
|
||||
Args:
|
||||
bot_uuid: Bot 的 UUID
|
||||
path: 子路径(如果有的话)
|
||||
request: Quart Request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
return await self.bot.handle_unified_webhook(request)
|
||||
|
||||
async def run_async(self):
|
||||
async def shutdown_trigger_placeholder():
|
||||
# 统一 webhook 模式下,不启动独立的 Quart 应用
|
||||
# 保持运行但不启动独立端口
|
||||
|
||||
async def keep_alive():
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
await self.bot.run_task(
|
||||
host='0.0.0.0',
|
||||
port=self.config['port'],
|
||||
shutdown_trigger=shutdown_trigger_placeholder,
|
||||
)
|
||||
await keep_alive()
|
||||
|
||||
async def kill(self) -> bool:
|
||||
return False
|
||||
|
||||
@@ -11,13 +11,6 @@ metadata:
|
||||
icon: wecombot.png
|
||||
spec:
|
||||
config:
|
||||
- name: port
|
||||
label:
|
||||
en_US: Port
|
||||
zh_Hans: 监听端口
|
||||
type: integer
|
||||
required: true
|
||||
default: 2291
|
||||
- name: Corpid
|
||||
label:
|
||||
en_US: Corpid
|
||||
@@ -28,14 +21,14 @@ spec:
|
||||
- name: Token
|
||||
label:
|
||||
en_US: Token
|
||||
zh_Hans: 令牌
|
||||
zh_Hans: 令牌 (Token)
|
||||
type: string
|
||||
required: true
|
||||
default: ""
|
||||
- name: EncodingAESKey
|
||||
label:
|
||||
en_US: EncodingAESKey
|
||||
zh_Hans: 消息加解密密钥
|
||||
zh_Hans: 消息加解密密钥 (EncodingAESKey)
|
||||
type: string
|
||||
required: true
|
||||
default: ""
|
||||
|
||||
@@ -121,6 +121,7 @@ class WecomCSAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
bot: WecomCSClient = pydantic.Field(exclude=True)
|
||||
message_converter: WecomMessageConverter = WecomMessageConverter()
|
||||
event_converter: WecomEventConverter = WecomEventConverter()
|
||||
bot_uuid: str = None
|
||||
|
||||
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger):
|
||||
required_keys = [
|
||||
@@ -139,6 +140,7 @@ class WecomCSAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
token=config['token'],
|
||||
EncodingAESKey=config['EncodingAESKey'],
|
||||
logger=logger,
|
||||
unified_mode=True,
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
@@ -170,6 +172,10 @@ class WecomCSAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
pass
|
||||
|
||||
def set_bot_uuid(self, bot_uuid: str):
|
||||
"""设置 bot UUID(用于生成 webhook URL)"""
|
||||
self.bot_uuid = bot_uuid
|
||||
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
@@ -190,16 +196,28 @@ class WecomCSAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
elif event_type == platform_events.GroupMessage:
|
||||
pass
|
||||
|
||||
async def handle_unified_webhook(self, bot_uuid: str, path: str, request):
|
||||
"""处理统一 webhook 请求。
|
||||
|
||||
Args:
|
||||
bot_uuid: Bot 的 UUID
|
||||
path: 子路径(如果有的话)
|
||||
request: Quart Request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
return await self.bot.handle_unified_webhook(request)
|
||||
|
||||
async def run_async(self):
|
||||
async def shutdown_trigger_placeholder():
|
||||
# 统一 webhook 模式下,不启动独立的 Quart 应用
|
||||
# 保持运行但不启动独立端口
|
||||
|
||||
async def keep_alive():
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
await self.bot.run_task(
|
||||
host='0.0.0.0',
|
||||
port=self.config['port'],
|
||||
shutdown_trigger=shutdown_trigger_placeholder,
|
||||
)
|
||||
await keep_alive()
|
||||
|
||||
async def kill(self) -> bool:
|
||||
return False
|
||||
|
||||
@@ -11,13 +11,6 @@ metadata:
|
||||
icon: wecom.png
|
||||
spec:
|
||||
config:
|
||||
- name: port
|
||||
label:
|
||||
en_US: Port
|
||||
zh_Hans: 监听端口
|
||||
type: int
|
||||
required: true
|
||||
default: 2289
|
||||
- name: corpid
|
||||
label:
|
||||
en_US: Corpid
|
||||
|
||||
@@ -22,12 +22,16 @@ class WebhookPusher:
|
||||
self.ap = ap
|
||||
self.logger = self.ap.logger
|
||||
|
||||
async def push_person_message(self, event: platform_events.FriendMessage, bot_uuid: str, adapter_name: str) -> None:
|
||||
"""Push person message event to webhooks"""
|
||||
async def push_person_message(self, event: platform_events.FriendMessage, bot_uuid: str, adapter_name: str) -> bool:
|
||||
"""Push person message event to webhooks
|
||||
|
||||
Returns:
|
||||
bool: True if any webhook responded with skip_pipeline=true, False otherwise
|
||||
"""
|
||||
try:
|
||||
webhooks = await self.ap.webhook_service.get_enabled_webhooks()
|
||||
if not webhooks:
|
||||
return
|
||||
return False
|
||||
|
||||
# Build payload
|
||||
payload = {
|
||||
@@ -47,17 +51,30 @@ class WebhookPusher:
|
||||
|
||||
# Push to all webhooks asynchronously
|
||||
tasks = [self._push_to_webhook(webhook['url'], payload) for webhook in webhooks]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Check if any webhook responded with skip_pipeline=true
|
||||
for result in results:
|
||||
if isinstance(result, dict) and result.get('skip_pipeline') is True:
|
||||
self.logger.info(f'Webhook responded with skip_pipeline=true, skipping pipeline for person message')
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f'Failed to push person message to webhooks: {e}')
|
||||
return False
|
||||
|
||||
async def push_group_message(self, event: platform_events.GroupMessage, bot_uuid: str, adapter_name: str) -> None:
|
||||
"""Push group message event to webhooks"""
|
||||
async def push_group_message(self, event: platform_events.GroupMessage, bot_uuid: str, adapter_name: str) -> bool:
|
||||
"""Push group message event to webhooks
|
||||
|
||||
Returns:
|
||||
bool: True if any webhook responded with skip_pipeline=true, False otherwise
|
||||
"""
|
||||
try:
|
||||
webhooks = await self.ap.webhook_service.get_enabled_webhooks()
|
||||
if not webhooks:
|
||||
return
|
||||
return False
|
||||
|
||||
# Build payload
|
||||
payload = {
|
||||
@@ -81,13 +98,26 @@ class WebhookPusher:
|
||||
|
||||
# Push to all webhooks asynchronously
|
||||
tasks = [self._push_to_webhook(webhook['url'], payload) for webhook in webhooks]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Check if any webhook responded with skip_pipeline=true
|
||||
for result in results:
|
||||
if isinstance(result, dict) and result.get('skip_pipeline') is True:
|
||||
self.logger.info(f'Webhook responded with skip_pipeline=true, skipping pipeline for group message')
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f'Failed to push group message to webhooks: {e}')
|
||||
return False
|
||||
|
||||
async def _push_to_webhook(self, url: str, payload: dict) -> None:
|
||||
"""Push payload to a single webhook URL"""
|
||||
async def _push_to_webhook(self, url: str, payload: dict) -> dict | None:
|
||||
"""Push payload to a single webhook URL
|
||||
|
||||
Returns:
|
||||
dict | None: The response JSON if successful, None otherwise
|
||||
"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
@@ -98,9 +128,17 @@ class WebhookPusher:
|
||||
) as response:
|
||||
if response.status >= 400:
|
||||
self.logger.warning(f'Webhook {url} returned status {response.status}')
|
||||
return None
|
||||
else:
|
||||
self.logger.debug(f'Successfully pushed to webhook {url}')
|
||||
try:
|
||||
return await response.json()
|
||||
except Exception as json_error:
|
||||
self.logger.debug(f'Failed to parse JSON response from webhook {url}: {json_error}')
|
||||
return None
|
||||
except asyncio.TimeoutError:
|
||||
self.logger.warning(f'Timeout pushing to webhook {url}')
|
||||
return None
|
||||
except Exception as e:
|
||||
self.logger.warning(f'Error pushing to webhook {url}: {e}')
|
||||
return None
|
||||
|
||||
@@ -385,6 +385,12 @@ class PluginRuntimeConnector:
|
||||
async def get_plugin_assets(self, plugin_author: str, plugin_name: str, filepath: str) -> dict[str, Any]:
|
||||
return await self.handler.get_plugin_assets(plugin_author, plugin_name, filepath)
|
||||
|
||||
async def get_debug_info(self) -> dict[str, Any]:
|
||||
"""Get debug information including debug key and WS URL"""
|
||||
if not self.is_enable_plugin:
|
||||
return {}
|
||||
return await self.handler.get_debug_info()
|
||||
|
||||
async def emit_event(
|
||||
self,
|
||||
event: events.BaseEventModel,
|
||||
|
||||
@@ -139,6 +139,8 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
|
||||
message_chain_obj = platform_message.MessageChain.model_validate(message_chain)
|
||||
|
||||
self.ap.logger.debug(f'Reply message: {message_chain_obj.model_dump(serialize_as_any=False)}')
|
||||
|
||||
await query.adapter.reply_message(
|
||||
query.message_event,
|
||||
message_chain_obj,
|
||||
@@ -563,7 +565,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
'event_context': event_context,
|
||||
'include_plugins': include_plugins,
|
||||
},
|
||||
timeout=60,
|
||||
timeout=180,
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -758,3 +760,12 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
timeout=30,
|
||||
)
|
||||
return result
|
||||
|
||||
async def get_debug_info(self) -> dict[str, Any]:
|
||||
"""Get debug information including debug key and WS URL"""
|
||||
result = await self.call_action(
|
||||
LangBotToRuntimeAction.GET_DEBUG_INFO,
|
||||
{},
|
||||
timeout=10,
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -4,6 +4,7 @@ import typing
|
||||
import json
|
||||
import uuid
|
||||
import base64
|
||||
import mimetypes
|
||||
|
||||
|
||||
from langbot.pkg.provider import runner
|
||||
@@ -12,6 +13,7 @@ import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
from langbot.pkg.utils import image
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
from langbot.libs.dify_service_api.v1 import client, errors
|
||||
import httpx
|
||||
|
||||
|
||||
@runner.runner_class('dify-service-api')
|
||||
@@ -70,14 +72,43 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
content = f'<think>\n{thinking_content}\n</think>\n{content}'.strip()
|
||||
return content, thinking_content
|
||||
|
||||
async def _preprocess_user_message(self, query: pipeline_query.Query) -> tuple[str, list[str]]:
|
||||
"""预处理用户消息,提取纯文本,并将图片上传到 Dify 服务
|
||||
async def _preprocess_user_message(self, query: pipeline_query.Query) -> tuple[str, list[dict]]:
|
||||
"""预处理用户消息,提取纯文本,并将图片/文件上传到 Dify 服务
|
||||
|
||||
Returns:
|
||||
tuple[str, list[str]]: 纯文本和图片的 Dify 服务图片 ID
|
||||
tuple[str, list[dict]]: 纯文本和上传后的文件描述(包含 type 与 id)
|
||||
"""
|
||||
plain_text = ''
|
||||
file_ids = []
|
||||
upload_files: list[dict] = []
|
||||
user_tag = f'{query.session.launcher_type.value}_{query.session.launcher_id}'
|
||||
|
||||
async def upload_file_bytes(file_name: str, file_bytes: bytes, content_type: str) -> str:
|
||||
file_name = file_name or 'file'
|
||||
content_type = content_type or 'application/octet-stream'
|
||||
file = (file_name, file_bytes, content_type)
|
||||
resp = await self.dify_client.upload_file(file, user_tag)
|
||||
return resp['id']
|
||||
|
||||
async def download_file(file_url: str) -> tuple[bytes, str]:
|
||||
"""Download file from url (supports data url)."""
|
||||
|
||||
async with httpx.AsyncClient() as client_session:
|
||||
resp = await client_session.get(file_url)
|
||||
resp.raise_for_status()
|
||||
content_type = (
|
||||
resp.headers.get('content-type') or mimetypes.guess_type(file_url)[0] or 'application/octet-stream'
|
||||
)
|
||||
return resp.content, content_type
|
||||
|
||||
def _detect_file_type(content_type: str) -> str:
|
||||
"""Map MIME to dify file type."""
|
||||
if content_type and content_type.startswith('image/'):
|
||||
return 'image'
|
||||
if content_type and content_type.startswith('audio/'):
|
||||
return 'audio'
|
||||
if content_type and content_type.startswith('video/'):
|
||||
return 'video'
|
||||
return 'document'
|
||||
|
||||
if isinstance(query.user_message.content, list):
|
||||
for ce in query.user_message.content:
|
||||
@@ -86,30 +117,36 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
elif ce.type == 'image_base64':
|
||||
image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)
|
||||
file_bytes = base64.b64decode(image_b64)
|
||||
file = ('img.png', file_bytes, f'image/{image_format}')
|
||||
file_upload_resp = await self.dify_client.upload_file(
|
||||
file,
|
||||
f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
)
|
||||
image_id = file_upload_resp['id']
|
||||
file_ids.append(image_id)
|
||||
# elif ce.type == "file_url":
|
||||
# file_bytes = base64.b64decode(ce.file_url)
|
||||
# file_upload_resp = await self.dify_client.upload_file(
|
||||
# file_bytes,
|
||||
# f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
# )
|
||||
# file_id = file_upload_resp['id']
|
||||
# file_ids.append(file_id)
|
||||
image_id = await upload_file_bytes(f'img.{image_format}', file_bytes, f'image/{image_format}')
|
||||
upload_files.append({'type': 'image', 'id': image_id})
|
||||
elif ce.type == 'file_url':
|
||||
file_url = getattr(ce, 'file_url', None)
|
||||
file_name = getattr(ce, 'file_name', None) or 'file'
|
||||
try:
|
||||
file_bytes, content_type = await download_file(file_url)
|
||||
file_id = await upload_file_bytes(file_name, file_bytes, content_type)
|
||||
file_type = _detect_file_type(content_type)
|
||||
upload_files.append({'type': file_type, 'id': file_id})
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(f'dify file upload failed: {e}')
|
||||
elif ce.type == 'file_base64':
|
||||
file_name = getattr(ce, 'file_name', None) or 'file'
|
||||
|
||||
header, b64_data = ce.file_base64.split(',', 1)
|
||||
content_type = 'application/octet-stream'
|
||||
if ';' in header:
|
||||
content_type = header.split(';')[0][5:] or content_type
|
||||
file_bytes = base64.b64decode(b64_data)
|
||||
file_id = await upload_file_bytes(file_name, file_bytes, content_type)
|
||||
file_type = _detect_file_type(content_type)
|
||||
upload_files.append({'type': file_type, 'id': file_id})
|
||||
|
||||
elif isinstance(query.user_message.content, str):
|
||||
plain_text = query.user_message.content
|
||||
# plain_text = "When the file content is readable, please read the content of this file. When the file is an image, describe the content of this image." if file_ids and not plain_text else plain_text
|
||||
# plain_text = "The user message type cannot be parsed." if not file_ids and not plain_text else plain_text
|
||||
# plain_text = plain_text if plain_text else "When the file content is readable, please read the content of this file. When the file is an image, describe the content of this image."
|
||||
# print(self.pipeline_config['ai'])
|
||||
|
||||
plain_text = plain_text if plain_text else self.pipeline_config['ai']['dify-service-api']['base-prompt']
|
||||
|
||||
return plain_text, file_ids
|
||||
return plain_text, upload_files
|
||||
|
||||
async def _chat_messages(
|
||||
self, query: pipeline_query.Query
|
||||
@@ -118,14 +155,15 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
cov_id = query.session.using_conversation.uuid or ''
|
||||
query.variables['conversation_id'] = cov_id
|
||||
|
||||
plain_text, image_ids = await self._preprocess_user_message(query)
|
||||
plain_text, upload_files = await self._preprocess_user_message(query)
|
||||
|
||||
files = [
|
||||
{
|
||||
'type': 'image',
|
||||
'upload_file_id': image_id,
|
||||
'type': f['type'],
|
||||
'transfer_method': 'local_file',
|
||||
'upload_file_id': f['id'],
|
||||
}
|
||||
for image_id in image_ids
|
||||
for f in upload_files
|
||||
]
|
||||
|
||||
mode = 'basic' # 标记是基础编排还是工作流编排
|
||||
@@ -183,15 +221,15 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
cov_id = query.session.using_conversation.uuid or ''
|
||||
query.variables['conversation_id'] = cov_id
|
||||
|
||||
plain_text, image_ids = await self._preprocess_user_message(query)
|
||||
plain_text, upload_files = await self._preprocess_user_message(query)
|
||||
|
||||
files = [
|
||||
{
|
||||
'type': 'image',
|
||||
'type': f['type'],
|
||||
'transfer_method': 'local_file',
|
||||
'upload_file_id': image_id,
|
||||
'upload_file_id': f['id'],
|
||||
}
|
||||
for image_id in image_ids
|
||||
for f in upload_files
|
||||
]
|
||||
|
||||
ignored_events = []
|
||||
@@ -280,15 +318,15 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
|
||||
query.variables['conversation_id'] = query.session.using_conversation.uuid
|
||||
|
||||
plain_text, image_ids = await self._preprocess_user_message(query)
|
||||
plain_text, upload_files = await self._preprocess_user_message(query)
|
||||
|
||||
files = [
|
||||
{
|
||||
'type': 'image',
|
||||
'type': f['type'],
|
||||
'transfer_method': 'local_file',
|
||||
'upload_file_id': image_id,
|
||||
'upload_file_id': f['id'],
|
||||
}
|
||||
for image_id in image_ids
|
||||
for f in upload_files
|
||||
]
|
||||
|
||||
ignored_events = ['text_chunk', 'workflow_started']
|
||||
@@ -298,8 +336,6 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
'langbot_session_id': query.variables['session_id'],
|
||||
'langbot_conversation_id': query.variables['conversation_id'],
|
||||
'langbot_msg_create_time': query.variables['msg_create_time'],
|
||||
'langbot_sender_id': query.variables.get('sender_id', ''),
|
||||
'langbot_sender_name': query.variables.get('sender_name', ''),
|
||||
}
|
||||
|
||||
inputs.update(query.variables)
|
||||
@@ -354,15 +390,15 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
cov_id = query.session.using_conversation.uuid or ''
|
||||
query.variables['conversation_id'] = cov_id
|
||||
|
||||
plain_text, image_ids = await self._preprocess_user_message(query)
|
||||
plain_text, upload_files = await self._preprocess_user_message(query)
|
||||
|
||||
files = [
|
||||
{
|
||||
'type': 'image',
|
||||
'type': f['type'],
|
||||
'transfer_method': 'local_file',
|
||||
'upload_file_id': image_id,
|
||||
'upload_file_id': f['id'],
|
||||
}
|
||||
for image_id in image_ids
|
||||
for f in upload_files
|
||||
]
|
||||
|
||||
basic_mode_pending_chunk = ''
|
||||
@@ -438,15 +474,15 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
cov_id = query.session.using_conversation.uuid or ''
|
||||
query.variables['conversation_id'] = cov_id
|
||||
|
||||
plain_text, image_ids = await self._preprocess_user_message(query)
|
||||
plain_text, upload_files = await self._preprocess_user_message(query)
|
||||
|
||||
files = [
|
||||
{
|
||||
'type': 'image',
|
||||
'type': f['type'],
|
||||
'transfer_method': 'local_file',
|
||||
'upload_file_id': image_id,
|
||||
'upload_file_id': f['id'],
|
||||
}
|
||||
for image_id in image_ids
|
||||
for f in upload_files
|
||||
]
|
||||
|
||||
ignored_events = []
|
||||
@@ -560,15 +596,15 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
|
||||
query.variables['conversation_id'] = query.session.using_conversation.uuid
|
||||
|
||||
plain_text, image_ids = await self._preprocess_user_message(query)
|
||||
plain_text, upload_files = await self._preprocess_user_message(query)
|
||||
|
||||
files = [
|
||||
{
|
||||
'type': 'image',
|
||||
'type': f['type'],
|
||||
'transfer_method': 'local_file',
|
||||
'upload_file_id': image_id,
|
||||
'upload_file_id': f['id'],
|
||||
}
|
||||
for image_id in image_ids
|
||||
for f in upload_files
|
||||
]
|
||||
|
||||
ignored_events = ['workflow_started']
|
||||
@@ -578,8 +614,6 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
'langbot_session_id': query.variables['session_id'],
|
||||
'langbot_conversation_id': query.variables['conversation_id'],
|
||||
'langbot_msg_create_time': query.variables['msg_create_time'],
|
||||
'langbot_sender_id': query.variables.get('sender_id', ''),
|
||||
'langbot_sender_name': query.variables.get('sender_name', ''),
|
||||
}
|
||||
|
||||
inputs.update(query.variables)
|
||||
|
||||
@@ -94,7 +94,6 @@ class LangflowAPIRunner(runner.RequestRunner):
|
||||
if is_stream:
|
||||
# 流式请求
|
||||
async with client.stream('POST', url, json=payload, headers=headers, timeout=120.0) as response:
|
||||
print(response)
|
||||
response.raise_for_status()
|
||||
|
||||
accumulated_content = ''
|
||||
|
||||
@@ -33,6 +33,7 @@ class SessionManager:
|
||||
session = provider_session.Session(
|
||||
launcher_type=query.launcher_type,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
)
|
||||
session._semaphore = asyncio.Semaphore(session_concurrency)
|
||||
self.session_list.append(session)
|
||||
|
||||
@@ -2,7 +2,7 @@ import langbot
|
||||
|
||||
semantic_version = f'v{langbot.__version__}'
|
||||
|
||||
required_database_version = 12
|
||||
required_database_version = 13
|
||||
"""Tag the version of the database schema, used to check if the database needs to be migrated"""
|
||||
|
||||
debug_mode = False
|
||||
|
||||
@@ -4,6 +4,8 @@ from ..core import app
|
||||
from .vdb import VectorDatabase
|
||||
from .vdbs.chroma import ChromaVectorDatabase
|
||||
from .vdbs.qdrant import QdrantVectorDatabase
|
||||
from .vdbs.milvus import MilvusVectorDatabase
|
||||
from .vdbs.pgvector_db import PgVectorDatabase
|
||||
|
||||
|
||||
class VectorDBManager:
|
||||
@@ -16,12 +18,47 @@ class VectorDBManager:
|
||||
async def initialize(self):
|
||||
kb_config = self.ap.instance_config.data.get('vdb')
|
||||
if kb_config:
|
||||
if kb_config.get('use') == 'chroma':
|
||||
vdb_type = kb_config.get('use')
|
||||
|
||||
if vdb_type == 'chroma':
|
||||
self.vector_db = ChromaVectorDatabase(self.ap)
|
||||
self.ap.logger.info('Initialized Chroma vector database backend.')
|
||||
elif kb_config.get('use') == 'qdrant':
|
||||
|
||||
elif vdb_type == 'qdrant':
|
||||
self.vector_db = QdrantVectorDatabase(self.ap)
|
||||
self.ap.logger.info('Initialized Qdrant vector database backend.')
|
||||
|
||||
elif vdb_type == 'milvus':
|
||||
# Get Milvus configuration
|
||||
milvus_config = kb_config.get('milvus', {})
|
||||
uri = milvus_config.get('uri', './data/milvus.db')
|
||||
token = milvus_config.get('token')
|
||||
self.vector_db = MilvusVectorDatabase(self.ap, uri=uri, token=token)
|
||||
self.ap.logger.info('Initialized Milvus vector database backend.')
|
||||
|
||||
elif vdb_type == 'pgvector':
|
||||
# Get pgvector configuration
|
||||
pgvector_config = kb_config.get('pgvector', {})
|
||||
connection_string = pgvector_config.get('connection_string')
|
||||
if connection_string:
|
||||
self.vector_db = PgVectorDatabase(self.ap, connection_string=connection_string)
|
||||
else:
|
||||
# Use individual parameters
|
||||
host = pgvector_config.get('host', 'localhost')
|
||||
port = pgvector_config.get('port', 5432)
|
||||
database = pgvector_config.get('database', 'langbot')
|
||||
user = pgvector_config.get('user', 'postgres')
|
||||
password = pgvector_config.get('password', 'postgres')
|
||||
self.vector_db = PgVectorDatabase(
|
||||
self.ap,
|
||||
host=host,
|
||||
port=port,
|
||||
database=database,
|
||||
user=user,
|
||||
password=password
|
||||
)
|
||||
self.ap.logger.info('Initialized pgvector database backend.')
|
||||
|
||||
else:
|
||||
self.vector_db = ChromaVectorDatabase(self.ap)
|
||||
self.ap.logger.warning('No valid vector database backend configured, defaulting to Chroma.')
|
||||
|
||||
249
src/langbot/pkg/vector/vdbs/milvus.py
Normal file
249
src/langbot/pkg/vector/vdbs/milvus.py
Normal file
@@ -0,0 +1,249 @@
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
from typing import Any, Dict
|
||||
from pymilvus import MilvusClient, DataType
|
||||
from langbot.pkg.vector.vdb import VectorDatabase
|
||||
from langbot.pkg.core import app
|
||||
|
||||
|
||||
class MilvusVectorDatabase(VectorDatabase):
|
||||
"""Milvus vector database implementation"""
|
||||
|
||||
def __init__(self, ap: app.Application, uri: str = "milvus.db", token: str = None):
|
||||
"""Initialize Milvus vector database
|
||||
|
||||
Args:
|
||||
ap: Application instance
|
||||
uri: Milvus connection URI. For local file: "milvus.db"
|
||||
For remote server: "http://localhost:19530"
|
||||
token: Optional authentication token for remote connections
|
||||
"""
|
||||
self.ap = ap
|
||||
self.uri = uri
|
||||
self.token = token
|
||||
self.client = None
|
||||
self._collections = {}
|
||||
self._initialize_client()
|
||||
|
||||
def _initialize_client(self):
|
||||
"""Initialize Milvus client connection"""
|
||||
try:
|
||||
if self.token:
|
||||
self.client = MilvusClient(uri=self.uri, token=self.token)
|
||||
else:
|
||||
self.client = MilvusClient(uri=self.uri)
|
||||
self.ap.logger.info(f"Connected to Milvus at {self.uri}")
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f"Failed to connect to Milvus: {e}")
|
||||
raise
|
||||
|
||||
async def get_or_create_collection(self, collection: str):
|
||||
"""Get or create a Milvus collection
|
||||
|
||||
Args:
|
||||
collection: Collection name (corresponds to knowledge base UUID)
|
||||
"""
|
||||
if collection in self._collections:
|
||||
return self._collections[collection]
|
||||
|
||||
# Check if collection exists
|
||||
has_collection = await asyncio.to_thread(
|
||||
self.client.has_collection, collection_name=collection
|
||||
)
|
||||
|
||||
if not has_collection:
|
||||
# Create collection with custom schema to support string IDs
|
||||
from pymilvus import CollectionSchema, FieldSchema, DataType
|
||||
|
||||
fields = [
|
||||
FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=255),
|
||||
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=1536),
|
||||
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
|
||||
FieldSchema(name="file_id", dtype=DataType.VARCHAR, max_length=255),
|
||||
FieldSchema(name="chunk_uuid", dtype=DataType.VARCHAR, max_length=255),
|
||||
]
|
||||
|
||||
schema = CollectionSchema(fields=fields, description="LangBot knowledge base vectors")
|
||||
|
||||
await asyncio.to_thread(
|
||||
self.client.create_collection,
|
||||
collection_name=collection,
|
||||
schema=schema,
|
||||
metric_type="COSINE",
|
||||
)
|
||||
|
||||
# Create index for vector field (required for loading/searching)
|
||||
index_params = {
|
||||
"metric_type": "COSINE",
|
||||
"index_type": "AUTOINDEX",
|
||||
"params": {}
|
||||
}
|
||||
await asyncio.to_thread(
|
||||
self.client.create_index,
|
||||
collection_name=collection,
|
||||
field_name="vector",
|
||||
index_params=index_params
|
||||
)
|
||||
|
||||
self.ap.logger.info(f"Created Milvus collection '{collection}' with index")
|
||||
else:
|
||||
self.ap.logger.info(f"Milvus collection '{collection}' already exists")
|
||||
|
||||
self._collections[collection] = collection
|
||||
return collection
|
||||
|
||||
async def add_embeddings(
|
||||
self,
|
||||
collection: str,
|
||||
ids: list[str],
|
||||
embeddings_list: list[list[float]],
|
||||
metadatas: list[dict[str, Any]],
|
||||
) -> None:
|
||||
"""Add vector embeddings to Milvus collection
|
||||
|
||||
Args:
|
||||
collection: Collection name
|
||||
ids: List of unique IDs for each vector
|
||||
embeddings_list: List of embedding vectors
|
||||
metadatas: List of metadata dictionaries for each vector
|
||||
"""
|
||||
await self.get_or_create_collection(collection)
|
||||
|
||||
# Prepare data in Milvus format
|
||||
data = []
|
||||
for i, vector_id in enumerate(ids):
|
||||
entry = {
|
||||
"id": vector_id,
|
||||
"vector": embeddings_list[i],
|
||||
}
|
||||
# Add metadata fields
|
||||
if metadatas and i < len(metadatas):
|
||||
metadata = metadatas[i]
|
||||
# Add common metadata fields
|
||||
if "text" in metadata:
|
||||
entry["text"] = metadata["text"]
|
||||
if "file_id" in metadata:
|
||||
entry["file_id"] = metadata["file_id"]
|
||||
if "uuid" in metadata:
|
||||
entry["chunk_uuid"] = metadata["uuid"]
|
||||
data.append(entry)
|
||||
|
||||
# Insert data into Milvus
|
||||
await asyncio.to_thread(
|
||||
self.client.insert,
|
||||
collection_name=collection,
|
||||
data=data
|
||||
)
|
||||
|
||||
# Load collection for searching (Milvus requires this)
|
||||
await asyncio.to_thread(
|
||||
self.client.load_collection,
|
||||
collection_name=collection
|
||||
)
|
||||
|
||||
self.ap.logger.info(f"Added {len(ids)} embeddings to Milvus collection '{collection}'")
|
||||
|
||||
async def search(
|
||||
self, collection: str, query_embedding: list[float], k: int = 5
|
||||
) -> Dict[str, Any]:
|
||||
"""Search for similar vectors in Milvus collection
|
||||
|
||||
Args:
|
||||
collection: Collection name
|
||||
query_embedding: Query vector
|
||||
k: Number of top results to return
|
||||
|
||||
Returns:
|
||||
Dictionary with search results in Chroma-compatible format
|
||||
"""
|
||||
await self.get_or_create_collection(collection)
|
||||
|
||||
# Perform search
|
||||
search_params = {
|
||||
"metric_type": "COSINE",
|
||||
"params": {}
|
||||
}
|
||||
|
||||
results = await asyncio.to_thread(
|
||||
self.client.search,
|
||||
collection_name=collection,
|
||||
data=[query_embedding],
|
||||
limit=k,
|
||||
search_params=search_params,
|
||||
output_fields=["text", "file_id", "chunk_uuid"]
|
||||
)
|
||||
|
||||
# Convert results to Chroma-compatible format
|
||||
# Milvus returns: [[ {id, distance, entity: {...}} ]]
|
||||
ids = []
|
||||
distances = []
|
||||
metadatas = []
|
||||
|
||||
if results and len(results) > 0:
|
||||
for hit in results[0]:
|
||||
ids.append(hit.get("id", ""))
|
||||
distances.append(hit.get("distance", 0.0))
|
||||
|
||||
# Build metadata from entity fields
|
||||
entity = hit.get("entity", {})
|
||||
metadata = {}
|
||||
if "text" in entity:
|
||||
metadata["text"] = entity["text"]
|
||||
if "file_id" in entity:
|
||||
metadata["file_id"] = entity["file_id"]
|
||||
if "chunk_uuid" in entity:
|
||||
metadata["uuid"] = entity["chunk_uuid"]
|
||||
metadatas.append(metadata)
|
||||
|
||||
# Return in Chroma-compatible format (nested lists)
|
||||
result = {
|
||||
"ids": [ids],
|
||||
"distances": [distances],
|
||||
"metadatas": [metadatas]
|
||||
}
|
||||
|
||||
self.ap.logger.info(
|
||||
f"Milvus search in '{collection}' returned {len(ids)} results"
|
||||
)
|
||||
return result
|
||||
|
||||
async def delete_by_file_id(self, collection: str, file_id: str) -> None:
|
||||
"""Delete vectors from collection by file_id
|
||||
|
||||
Args:
|
||||
collection: Collection name
|
||||
file_id: File ID to filter deletion
|
||||
"""
|
||||
await self.get_or_create_collection(collection)
|
||||
|
||||
# Delete entities matching the file_id
|
||||
await asyncio.to_thread(
|
||||
self.client.delete,
|
||||
collection_name=collection,
|
||||
filter=f'file_id == "{file_id}"'
|
||||
)
|
||||
self.ap.logger.info(
|
||||
f"Deleted embeddings from Milvus collection '{collection}' with file_id: {file_id}"
|
||||
)
|
||||
|
||||
async def delete_collection(self, collection: str):
|
||||
"""Delete a Milvus collection
|
||||
|
||||
Args:
|
||||
collection: Collection name to delete
|
||||
"""
|
||||
if collection in self._collections:
|
||||
del self._collections[collection]
|
||||
|
||||
# Check if collection exists before attempting deletion
|
||||
has_collection = await asyncio.to_thread(
|
||||
self.client.has_collection, collection_name=collection
|
||||
)
|
||||
|
||||
if has_collection:
|
||||
await asyncio.to_thread(
|
||||
self.client.drop_collection, collection_name=collection
|
||||
)
|
||||
self.ap.logger.info(f"Deleted Milvus collection '{collection}'")
|
||||
else:
|
||||
self.ap.logger.warning(f"Milvus collection '{collection}' not found")
|
||||
286
src/langbot/pkg/vector/vdbs/pgvector_db.py
Normal file
286
src/langbot/pkg/vector/vdbs/pgvector_db.py
Normal file
@@ -0,0 +1,286 @@
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
from typing import Any, Dict
|
||||
from sqlalchemy import create_engine, text, Column, String, Text
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker, Session
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from langbot.pkg.vector.vdb import VectorDatabase
|
||||
from langbot.pkg.core import app
|
||||
import uuid
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class PgVectorEntry(Base):
|
||||
"""SQLAlchemy model for pgvector entries"""
|
||||
__tablename__ = 'langbot_vectors'
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
collection = Column(String, index=True, nullable=False)
|
||||
embedding = Column(Vector(1536)) # Default dimension, will be created dynamically
|
||||
text = Column(Text)
|
||||
file_id = Column(String, index=True)
|
||||
chunk_uuid = Column(String)
|
||||
|
||||
|
||||
class PgVectorDatabase(VectorDatabase):
|
||||
"""PostgreSQL with pgvector extension database implementation"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ap: app.Application,
|
||||
connection_string: str = None,
|
||||
host: str = "localhost",
|
||||
port: int = 5432,
|
||||
database: str = "langbot",
|
||||
user: str = "postgres",
|
||||
password: str = "postgres"
|
||||
):
|
||||
"""Initialize pgvector database
|
||||
|
||||
Args:
|
||||
ap: Application instance
|
||||
connection_string: Full PostgreSQL connection string (overrides other params)
|
||||
host: PostgreSQL host
|
||||
port: PostgreSQL port
|
||||
database: Database name
|
||||
user: Database user
|
||||
password: Database password
|
||||
"""
|
||||
self.ap = ap
|
||||
|
||||
# Build connection string if not provided
|
||||
if connection_string:
|
||||
self.connection_string = connection_string
|
||||
else:
|
||||
self.connection_string = (
|
||||
f"postgresql+psycopg://{user}:{password}@{host}:{port}/{database}"
|
||||
)
|
||||
|
||||
self.async_connection_string = self.connection_string.replace(
|
||||
"postgresql://", "postgresql+asyncpg://"
|
||||
).replace(
|
||||
"postgresql+psycopg://", "postgresql+asyncpg://"
|
||||
)
|
||||
|
||||
self.engine = None
|
||||
self.async_engine = None
|
||||
self.SessionLocal = None
|
||||
self.AsyncSessionLocal = None
|
||||
self._collections = set()
|
||||
self._initialize_db()
|
||||
|
||||
def _initialize_db(self):
|
||||
"""Initialize database connection and create tables"""
|
||||
try:
|
||||
# Create async engine for async operations
|
||||
self.async_engine = create_async_engine(
|
||||
self.async_connection_string,
|
||||
echo=False,
|
||||
pool_pre_ping=True
|
||||
)
|
||||
self.AsyncSessionLocal = async_sessionmaker(
|
||||
self.async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False
|
||||
)
|
||||
|
||||
# Create sync engine for table creation
|
||||
sync_connection_string = self.connection_string.replace(
|
||||
"postgresql+asyncpg://", "postgresql+psycopg://"
|
||||
)
|
||||
self.engine = create_engine(sync_connection_string, echo=False)
|
||||
|
||||
# Create pgvector extension and tables
|
||||
with self.engine.connect() as conn:
|
||||
# Enable pgvector extension
|
||||
conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
|
||||
conn.commit()
|
||||
|
||||
# Create tables
|
||||
Base.metadata.create_all(self.engine)
|
||||
|
||||
self.ap.logger.info(f"Connected to PostgreSQL with pgvector")
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f"Failed to connect to PostgreSQL: {e}")
|
||||
raise
|
||||
|
||||
async def get_or_create_collection(self, collection: str):
|
||||
"""Get or create a collection (logical grouping in pgvector)
|
||||
|
||||
Args:
|
||||
collection: Collection name (knowledge base UUID)
|
||||
"""
|
||||
# In pgvector, collections are logical - we just track them
|
||||
if collection not in self._collections:
|
||||
self._collections.add(collection)
|
||||
self.ap.logger.info(f"Registered pgvector collection '{collection}'")
|
||||
return collection
|
||||
|
||||
async def add_embeddings(
|
||||
self,
|
||||
collection: str,
|
||||
ids: list[str],
|
||||
embeddings_list: list[list[float]],
|
||||
metadatas: list[dict[str, Any]],
|
||||
) -> None:
|
||||
"""Add vector embeddings to pgvector
|
||||
|
||||
Args:
|
||||
collection: Collection name
|
||||
ids: List of unique IDs for each vector
|
||||
embeddings_list: List of embedding vectors
|
||||
metadatas: List of metadata dictionaries
|
||||
"""
|
||||
await self.get_or_create_collection(collection)
|
||||
|
||||
async with self.AsyncSessionLocal() as session:
|
||||
try:
|
||||
for i, vector_id in enumerate(ids):
|
||||
metadata = metadatas[i] if i < len(metadatas) else {}
|
||||
|
||||
entry = PgVectorEntry(
|
||||
id=vector_id,
|
||||
collection=collection,
|
||||
embedding=embeddings_list[i],
|
||||
text=metadata.get("text", ""),
|
||||
file_id=metadata.get("file_id", ""),
|
||||
chunk_uuid=metadata.get("uuid", "")
|
||||
)
|
||||
session.add(entry)
|
||||
|
||||
await session.commit()
|
||||
self.ap.logger.info(
|
||||
f"Added {len(ids)} embeddings to pgvector collection '{collection}'"
|
||||
)
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
self.ap.logger.error(f"Error adding embeddings to pgvector: {e}")
|
||||
raise
|
||||
|
||||
async def search(
|
||||
self, collection: str, query_embedding: list[float], k: int = 5
|
||||
) -> Dict[str, Any]:
|
||||
"""Search for similar vectors using cosine distance
|
||||
|
||||
Args:
|
||||
collection: Collection name
|
||||
query_embedding: Query vector
|
||||
k: Number of top results to return
|
||||
|
||||
Returns:
|
||||
Dictionary with search results in Chroma-compatible format
|
||||
"""
|
||||
await self.get_or_create_collection(collection)
|
||||
|
||||
async with self.AsyncSessionLocal() as session:
|
||||
try:
|
||||
# Use cosine distance for similarity search
|
||||
from sqlalchemy import select, func
|
||||
|
||||
# Query for similar vectors
|
||||
stmt = (
|
||||
select(
|
||||
PgVectorEntry.id,
|
||||
PgVectorEntry.text,
|
||||
PgVectorEntry.file_id,
|
||||
PgVectorEntry.chunk_uuid,
|
||||
PgVectorEntry.embedding.cosine_distance(query_embedding).label('distance')
|
||||
)
|
||||
.filter(PgVectorEntry.collection == collection)
|
||||
.order_by(PgVectorEntry.embedding.cosine_distance(query_embedding))
|
||||
.limit(k)
|
||||
)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
rows = result.fetchall()
|
||||
|
||||
# Convert to Chroma-compatible format
|
||||
ids = []
|
||||
distances = []
|
||||
metadatas = []
|
||||
|
||||
for row in rows:
|
||||
ids.append(row.id)
|
||||
distances.append(float(row.distance))
|
||||
metadatas.append({
|
||||
"text": row.text or "",
|
||||
"file_id": row.file_id or "",
|
||||
"uuid": row.chunk_uuid or ""
|
||||
})
|
||||
|
||||
result_dict = {
|
||||
"ids": [ids],
|
||||
"distances": [distances],
|
||||
"metadatas": [metadatas]
|
||||
}
|
||||
|
||||
self.ap.logger.info(
|
||||
f"pgvector search in '{collection}' returned {len(ids)} results"
|
||||
)
|
||||
return result_dict
|
||||
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f"Error searching pgvector: {e}")
|
||||
raise
|
||||
|
||||
async def delete_by_file_id(self, collection: str, file_id: str) -> None:
|
||||
"""Delete vectors by file_id
|
||||
|
||||
Args:
|
||||
collection: Collection name
|
||||
file_id: File ID to filter deletion
|
||||
"""
|
||||
await self.get_or_create_collection(collection)
|
||||
|
||||
async with self.AsyncSessionLocal() as session:
|
||||
try:
|
||||
from sqlalchemy import delete
|
||||
|
||||
stmt = delete(PgVectorEntry).where(
|
||||
PgVectorEntry.collection == collection,
|
||||
PgVectorEntry.file_id == file_id
|
||||
)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
self.ap.logger.info(
|
||||
f"Deleted embeddings from pgvector collection '{collection}' with file_id: {file_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
self.ap.logger.error(f"Error deleting from pgvector: {e}")
|
||||
raise
|
||||
|
||||
async def delete_collection(self, collection: str):
|
||||
"""Delete all vectors in a collection
|
||||
|
||||
Args:
|
||||
collection: Collection name to delete
|
||||
"""
|
||||
if collection in self._collections:
|
||||
self._collections.remove(collection)
|
||||
|
||||
async with self.AsyncSessionLocal() as session:
|
||||
try:
|
||||
from sqlalchemy import delete
|
||||
|
||||
stmt = delete(PgVectorEntry).where(
|
||||
PgVectorEntry.collection == collection
|
||||
)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
self.ap.logger.info(f"Deleted pgvector collection '{collection}'")
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
self.ap.logger.error(f"Error deleting pgvector collection: {e}")
|
||||
raise
|
||||
|
||||
async def close(self):
|
||||
"""Close database connections"""
|
||||
if self.async_engine:
|
||||
await self.async_engine.dispose()
|
||||
if self.engine:
|
||||
self.engine.dispose()
|
||||
@@ -1,6 +1,7 @@
|
||||
admins: []
|
||||
api:
|
||||
port: 5300
|
||||
webhook_prefix: 'http://127.0.0.1:5300'
|
||||
command:
|
||||
enable: true
|
||||
prefix:
|
||||
@@ -35,6 +36,15 @@ vdb:
|
||||
host: localhost
|
||||
port: 6333
|
||||
api_key: ''
|
||||
milvus:
|
||||
uri: 'http://127.0.0.1:19530'
|
||||
token: ''
|
||||
pgvector:
|
||||
host: '127.0.0.1'
|
||||
port: 5433
|
||||
database: 'langbot'
|
||||
user: 'postgres'
|
||||
password: 'postgres'
|
||||
storage:
|
||||
use: local
|
||||
s3:
|
||||
@@ -48,3 +58,4 @@ plugin:
|
||||
runtime_ws_url: 'ws://langbot_plugin_runtime:5400/control/ws'
|
||||
enable_marketplace: true
|
||||
cloud_service_url: 'https://space.langbot.app'
|
||||
display_plugin_debug_url: 'ws://localhost:5401/plugin/debug/ws'
|
||||
|
||||
143
tests/unit_tests/config/test_webhook_display_prefix.py
Normal file
143
tests/unit_tests/config/test_webhook_display_prefix.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""
|
||||
Tests for webhook_prefix configuration
|
||||
"""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _apply_env_overrides_to_config(cfg: dict) -> dict:
|
||||
"""Apply environment variable overrides to data/config.yaml
|
||||
|
||||
Environment variables should be uppercase and use __ (double underscore)
|
||||
to represent nested keys. For example:
|
||||
- CONCURRENCY__PIPELINE overrides concurrency.pipeline
|
||||
- PLUGIN__RUNTIME_WS_URL overrides plugin.runtime_ws_url
|
||||
|
||||
Arrays and dict types are ignored.
|
||||
|
||||
Args:
|
||||
cfg: Configuration dictionary
|
||||
|
||||
Returns:
|
||||
Updated configuration dictionary
|
||||
"""
|
||||
|
||||
def convert_value(value: str, original_value: Any) -> Any:
|
||||
"""Convert string value to appropriate type based on original value
|
||||
|
||||
Args:
|
||||
value: String value from environment variable
|
||||
original_value: Original value to infer type from
|
||||
|
||||
Returns:
|
||||
Converted value (falls back to string if conversion fails)
|
||||
"""
|
||||
if isinstance(original_value, bool):
|
||||
return value.lower() in ('true', '1', 'yes', 'on')
|
||||
elif isinstance(original_value, int):
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
# If conversion fails, keep as string (user error, but non-breaking)
|
||||
return value
|
||||
elif isinstance(original_value, float):
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
# If conversion fails, keep as string (user error, but non-breaking)
|
||||
return value
|
||||
else:
|
||||
return value
|
||||
|
||||
# Process environment variables
|
||||
for env_key, env_value in os.environ.items():
|
||||
# Check if the environment variable is uppercase and contains __
|
||||
if not env_key.isupper():
|
||||
continue
|
||||
if '__' not in env_key:
|
||||
continue
|
||||
|
||||
# Convert environment variable name to config path
|
||||
# e.g., CONCURRENCY__PIPELINE -> ['concurrency', 'pipeline']
|
||||
keys = [key.lower() for key in env_key.split('__')]
|
||||
|
||||
# Navigate to the target value and validate the path
|
||||
current = cfg
|
||||
|
||||
for i, key in enumerate(keys):
|
||||
if not isinstance(current, dict) or key not in current:
|
||||
break
|
||||
|
||||
if i == len(keys) - 1:
|
||||
# At the final key - check if it's a scalar value
|
||||
if isinstance(current[key], (dict, list)):
|
||||
# Skip dict and list types
|
||||
pass
|
||||
else:
|
||||
# Valid scalar value - convert and set it
|
||||
converted_value = convert_value(env_value, current[key])
|
||||
current[key] = converted_value
|
||||
else:
|
||||
# Navigate deeper
|
||||
current = current[key]
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
class TestWebhookDisplayPrefix:
|
||||
"""Test webhook_prefix configuration functionality"""
|
||||
|
||||
def test_default_webhook_prefix(self):
|
||||
"""Test that the default webhook display prefix is correctly set"""
|
||||
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300'}}
|
||||
|
||||
# Should have the default value
|
||||
assert cfg['api']['webhook_prefix'] == 'http://127.0.0.1:5300'
|
||||
|
||||
def test_webhook_prefix_env_override(self):
|
||||
"""Test overriding webhook_prefix via environment variable"""
|
||||
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300'}}
|
||||
|
||||
# Set environment variable
|
||||
os.environ['API__WEBHOOK_PREFIX'] = 'https://example.com:8080'
|
||||
|
||||
result = _apply_env_overrides_to_config(cfg)
|
||||
|
||||
assert result['api']['webhook_prefix'] == 'https://example.com:8080'
|
||||
|
||||
# Cleanup
|
||||
del os.environ['API__WEBHOOK_PREFIX']
|
||||
|
||||
def test_webhook_prefix_with_custom_domain(self):
|
||||
"""Test webhook_prefix with custom domain"""
|
||||
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300'}}
|
||||
|
||||
# Set to a custom domain
|
||||
os.environ['API__WEBHOOK_PREFIX'] = 'https://bot.mycompany.com'
|
||||
|
||||
result = _apply_env_overrides_to_config(cfg)
|
||||
|
||||
assert result['api']['webhook_prefix'] == 'https://bot.mycompany.com'
|
||||
|
||||
# Cleanup
|
||||
del os.environ['API__WEBHOOK_PREFIX']
|
||||
|
||||
def test_webhook_prefix_with_subdirectory(self):
|
||||
"""Test webhook_prefix with subdirectory path"""
|
||||
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300'}}
|
||||
|
||||
# Set to a URL with subdirectory
|
||||
os.environ['API__WEBHOOK_PREFIX'] = 'https://example.com/langbot'
|
||||
|
||||
result = _apply_env_overrides_to_config(cfg)
|
||||
|
||||
assert result['api']['webhook_prefix'] == 'https://example.com/langbot'
|
||||
|
||||
# Cleanup
|
||||
del os.environ['API__WEBHOOK_PREFIX']
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
@@ -14,6 +14,8 @@ from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
|
||||
from langbot.pkg.pipeline import entities as pipeline_entities
|
||||
@@ -159,12 +161,18 @@ def sample_message_chain():
|
||||
|
||||
@pytest.fixture
|
||||
def sample_message_event(sample_message_chain):
|
||||
"""Provides sample message event"""
|
||||
event = Mock()
|
||||
event.sender = Mock()
|
||||
event.sender.id = 12345
|
||||
event.time = 1609459200 # 2021-01-01 00:00:00
|
||||
return event
|
||||
"""Provides sample message event (FriendMessage)"""
|
||||
sender = platform_entities.Friend(
|
||||
id=12345,
|
||||
nickname='TestUser',
|
||||
remark=None,
|
||||
)
|
||||
return platform_events.FriendMessage(
|
||||
type='FriendMessage',
|
||||
sender=sender,
|
||||
message_chain=sample_message_chain,
|
||||
time=1609459200, # 2021-01-01 00:00:00
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -1,110 +0,0 @@
|
||||
"""
|
||||
Test that preproc module adds sender_id and sender_name variables.
|
||||
These tests verify the variables are correctly extracted from message events.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sender_variables_friend_message():
|
||||
"""Test sender_id and sender_name are extracted from FriendMessage"""
|
||||
# Create mock Friend sender
|
||||
mock_sender = Mock()
|
||||
mock_sender.id = 'test_user_123'
|
||||
mock_sender.nickname = 'Test User'
|
||||
mock_sender.remark = 'Test Remark'
|
||||
mock_sender.get_name = Mock(return_value='Test User')
|
||||
|
||||
# Verify get_name returns nickname
|
||||
assert mock_sender.get_name() == 'Test User'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sender_variables_group_message():
|
||||
"""Test sender_id and sender_name are extracted from GroupMessage"""
|
||||
# Create mock GroupMember sender
|
||||
mock_sender = Mock()
|
||||
mock_sender.id = 'group_user_456'
|
||||
mock_sender.member_name = 'Group User Name'
|
||||
mock_sender.get_name = Mock(return_value='Group User Name')
|
||||
|
||||
# Verify get_name returns member_name
|
||||
assert mock_sender.get_name() == 'Group User Name'
|
||||
|
||||
|
||||
def test_sender_id_string_conversion():
|
||||
"""Test sender_id is converted to string"""
|
||||
# Test integer sender_id
|
||||
sender_id_int = 12345
|
||||
assert str(sender_id_int) == '12345'
|
||||
|
||||
# Test string sender_id
|
||||
sender_id_str = 'user_abc'
|
||||
assert str(sender_id_str) == 'user_abc'
|
||||
|
||||
|
||||
def test_sender_name_empty_fallback():
|
||||
"""Test sender_name defaults to empty string when not available"""
|
||||
# When sender has no name
|
||||
sender_name = ''
|
||||
assert sender_name == ''
|
||||
|
||||
|
||||
def test_variables_dict_structure():
|
||||
"""Test the variables dictionary has expected structure"""
|
||||
# Simulate what the variables dict should look like
|
||||
variables = {
|
||||
'session_id': 'person_12345',
|
||||
'conversation_id': 'conv-uuid-123',
|
||||
'msg_create_time': 1609459200,
|
||||
'sender_id': '12345',
|
||||
'sender_name': 'Test User',
|
||||
}
|
||||
|
||||
# Verify all expected keys are present
|
||||
assert 'session_id' in variables
|
||||
assert 'conversation_id' in variables
|
||||
assert 'msg_create_time' in variables
|
||||
assert 'sender_id' in variables
|
||||
assert 'sender_name' in variables
|
||||
|
||||
# Verify values
|
||||
assert variables['sender_id'] == '12345'
|
||||
assert variables['sender_name'] == 'Test User'
|
||||
|
||||
|
||||
def test_dify_workflow_inputs_structure():
|
||||
"""Test the Dify workflow inputs have expected legacy variables"""
|
||||
plain_text = 'Hello world'
|
||||
variables = {
|
||||
'session_id': 'person_12345',
|
||||
'conversation_id': 'conv-uuid-123',
|
||||
'msg_create_time': 1609459200,
|
||||
'sender_id': '12345',
|
||||
'sender_name': 'Test User',
|
||||
}
|
||||
|
||||
# Simulate Dify workflow inputs structure
|
||||
inputs = {
|
||||
'langbot_user_message_text': plain_text,
|
||||
'langbot_session_id': variables['session_id'],
|
||||
'langbot_conversation_id': variables['conversation_id'],
|
||||
'langbot_msg_create_time': variables['msg_create_time'],
|
||||
'langbot_sender_id': variables.get('sender_id', ''),
|
||||
'langbot_sender_name': variables.get('sender_name', ''),
|
||||
}
|
||||
inputs.update(variables)
|
||||
|
||||
# Verify all legacy variables are present
|
||||
assert inputs['langbot_user_message_text'] == 'Hello world'
|
||||
assert inputs['langbot_session_id'] == 'person_12345'
|
||||
assert inputs['langbot_conversation_id'] == 'conv-uuid-123'
|
||||
assert inputs['langbot_msg_create_time'] == 1609459200
|
||||
assert inputs['langbot_sender_id'] == '12345'
|
||||
assert inputs['langbot_sender_name'] == 'Test User'
|
||||
|
||||
# Verify regular variables are also present
|
||||
assert inputs['sender_id'] == '12345'
|
||||
assert inputs['sender_name'] == 'Test User'
|
||||
@@ -51,7 +51,7 @@
|
||||
"input-otp": "^1.4.2",
|
||||
"lodash": "^4.17.21",
|
||||
"lucide-react": "^0.507.0",
|
||||
"next": "15.4.7",
|
||||
"next": "15.4.8",
|
||||
"next-themes": "^0.4.6",
|
||||
"postcss": "^8.5.3",
|
||||
"react": "^19.0.0",
|
||||
@@ -74,11 +74,18 @@
|
||||
},
|
||||
"devDependencies": {
|
||||
"@eslint/eslintrc": "^3",
|
||||
"@types/debug": "^4.1.12",
|
||||
"@types/estree": "^1.0.8",
|
||||
"@types/estree-jsx": "^1.0.5",
|
||||
"@types/hast": "^3.0.4",
|
||||
"@types/lodash": "^4.17.16",
|
||||
"@types/mdast": "^4.0.4",
|
||||
"@types/ms": "^2.1.0",
|
||||
"@types/node": "^20",
|
||||
"@types/react": "^19",
|
||||
"@types/react-dom": "^19",
|
||||
"@types/react-syntax-highlighter": "^15.5.13",
|
||||
"@types/unist": "^3.0.3",
|
||||
"eslint": "^9",
|
||||
"eslint-config-next": "15.2.4",
|
||||
"eslint-config-prettier": "^10.1.2",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { useEffect, useState } from 'react';
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import {
|
||||
IChooseAdapterEntity,
|
||||
IPipelineEntity,
|
||||
@@ -112,11 +112,86 @@ export default function BotForm({
|
||||
IDynamicFormItemSchema[]
|
||||
>([]);
|
||||
const [, setIsLoading] = useState<boolean>(false);
|
||||
const [webhookUrl, setWebhookUrl] = useState<string>('');
|
||||
const webhookInputRef = React.useRef<HTMLInputElement>(null);
|
||||
|
||||
useEffect(() => {
|
||||
setBotFormValues();
|
||||
}, []);
|
||||
|
||||
// 复制到剪贴板的辅助函数 - 使用页面上的真实input元素
|
||||
const copyToClipboard = () => {
|
||||
console.log('[Copy] Attempting to copy from input element');
|
||||
|
||||
const inputElement = webhookInputRef.current;
|
||||
if (!inputElement) {
|
||||
console.error('[Copy] Input element not found');
|
||||
toast.error(t('common.copyFailed'));
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// 确保input元素可见且未被禁用
|
||||
inputElement.disabled = false;
|
||||
inputElement.readOnly = false;
|
||||
|
||||
// 聚焦并选中所有文本
|
||||
inputElement.focus();
|
||||
inputElement.select();
|
||||
|
||||
// 尝试使用现代API
|
||||
if (navigator.clipboard && navigator.clipboard.writeText) {
|
||||
console.log(
|
||||
'[Copy] Using Clipboard API with input value:',
|
||||
inputElement.value,
|
||||
);
|
||||
navigator.clipboard
|
||||
.writeText(inputElement.value)
|
||||
.then(() => {
|
||||
console.log('[Copy] Clipboard API success');
|
||||
inputElement.blur(); // 取消选中
|
||||
inputElement.readOnly = true;
|
||||
toast.success(t('bots.webhookUrlCopied'));
|
||||
})
|
||||
.catch((err) => {
|
||||
console.error(
|
||||
'[Copy] Clipboard API failed, trying execCommand:',
|
||||
err,
|
||||
);
|
||||
// 降级到execCommand
|
||||
const successful = document.execCommand('copy');
|
||||
console.log('[Copy] execCommand result:', successful);
|
||||
inputElement.blur();
|
||||
inputElement.readOnly = true;
|
||||
if (successful) {
|
||||
toast.success(t('bots.webhookUrlCopied'));
|
||||
} else {
|
||||
toast.error(t('common.copyFailed'));
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// 直接使用execCommand
|
||||
console.log(
|
||||
'[Copy] Using execCommand with input value:',
|
||||
inputElement.value,
|
||||
);
|
||||
const successful = document.execCommand('copy');
|
||||
console.log('[Copy] execCommand result:', successful);
|
||||
inputElement.blur();
|
||||
inputElement.readOnly = true;
|
||||
if (successful) {
|
||||
toast.success(t('bots.webhookUrlCopied'));
|
||||
} else {
|
||||
toast.error(t('common.copyFailed'));
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('[Copy] Copy failed:', err);
|
||||
inputElement.readOnly = true;
|
||||
toast.error(t('common.copyFailed'));
|
||||
}
|
||||
};
|
||||
|
||||
function setBotFormValues() {
|
||||
initBotFormComponent().then(() => {
|
||||
// 拉取初始化表单信息
|
||||
@@ -131,12 +206,20 @@ export default function BotForm({
|
||||
form.setValue('use_pipeline_uuid', val.use_pipeline_uuid || '');
|
||||
handleAdapterSelect(val.adapter);
|
||||
// dynamicForm.setFieldsValue(val.adapter_config);
|
||||
|
||||
// 设置 webhook 地址(如果有)
|
||||
if (val.webhook_full_url) {
|
||||
setWebhookUrl(val.webhook_full_url);
|
||||
} else {
|
||||
setWebhookUrl('');
|
||||
}
|
||||
})
|
||||
.catch((err) => {
|
||||
toast.error(t('bots.getBotConfigError') + err.message);
|
||||
});
|
||||
} else {
|
||||
form.reset();
|
||||
setWebhookUrl('');
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -209,7 +292,7 @@ export default function BotForm({
|
||||
|
||||
async function getBotConfig(
|
||||
botId: string,
|
||||
): Promise<z.infer<typeof formSchema>> {
|
||||
): Promise<z.infer<typeof formSchema> & { webhook_full_url?: string }> {
|
||||
return new Promise((resolve, reject) => {
|
||||
httpClient
|
||||
.getBot(botId)
|
||||
@@ -222,6 +305,10 @@ export default function BotForm({
|
||||
adapter_config: bot.adapter_config,
|
||||
enable: bot.enable ?? true,
|
||||
use_pipeline_uuid: bot.use_pipeline_uuid ?? '',
|
||||
webhook_full_url: bot.adapter_runtime_values
|
||||
? ((bot.adapter_runtime_values as Record<string, unknown>)
|
||||
.webhook_full_url as string)
|
||||
: undefined,
|
||||
});
|
||||
})
|
||||
.catch((err) => {
|
||||
@@ -360,51 +447,86 @@ export default function BotForm({
|
||||
<div className="space-y-4">
|
||||
{/* 是否启用 & 绑定流水线 仅在编辑模式 */}
|
||||
{initBotId && (
|
||||
<div className="flex items-center gap-6">
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="enable"
|
||||
render={({ field }) => (
|
||||
<FormItem className="flex flex-col justify-start gap-[0.8rem] h-[3.8rem]">
|
||||
<FormLabel>{t('common.enable')}</FormLabel>
|
||||
<FormControl>
|
||||
<Switch
|
||||
checked={field.value}
|
||||
onCheckedChange={field.onChange}
|
||||
/>
|
||||
</FormControl>
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<>
|
||||
<div className="flex items-center gap-6">
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="enable"
|
||||
render={({ field }) => (
|
||||
<FormItem className="flex flex-col justify-start gap-[0.8rem] h-[3.8rem]">
|
||||
<FormLabel>{t('common.enable')}</FormLabel>
|
||||
<FormControl>
|
||||
<Switch
|
||||
checked={field.value}
|
||||
onCheckedChange={field.onChange}
|
||||
/>
|
||||
</FormControl>
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="use_pipeline_uuid"
|
||||
render={({ field }) => (
|
||||
<FormItem className="flex flex-col justify-start gap-[0.8rem] h-[3.8rem]">
|
||||
<FormLabel>{t('bots.bindPipeline')}</FormLabel>
|
||||
<FormControl>
|
||||
<Select onValueChange={field.onChange} {...field}>
|
||||
<SelectTrigger className="bg-[#ffffff] dark:bg-[#2a2a2e]">
|
||||
<SelectValue
|
||||
placeholder={t('bots.selectPipeline')}
|
||||
/>
|
||||
</SelectTrigger>
|
||||
<SelectContent className="fixed z-[1000]">
|
||||
<SelectGroup>
|
||||
{pipelineNameList.map((item) => (
|
||||
<SelectItem key={item.value} value={item.value}>
|
||||
{item.label}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectGroup>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</FormControl>
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="use_pipeline_uuid"
|
||||
render={({ field }) => (
|
||||
<FormItem className="flex flex-col justify-start gap-[0.8rem] h-[3.8rem]">
|
||||
<FormLabel>{t('bots.bindPipeline')}</FormLabel>
|
||||
<FormControl>
|
||||
<Select onValueChange={field.onChange} {...field}>
|
||||
<SelectTrigger className="bg-[#ffffff] dark:bg-[#2a2a2e]">
|
||||
<SelectValue
|
||||
placeholder={t('bots.selectPipeline')}
|
||||
/>
|
||||
</SelectTrigger>
|
||||
<SelectContent className="fixed z-[1000]">
|
||||
<SelectGroup>
|
||||
{pipelineNameList.map((item) => (
|
||||
<SelectItem
|
||||
key={item.value}
|
||||
value={item.value}
|
||||
>
|
||||
{item.label}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectGroup>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</FormControl>
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Webhook 地址显示(统一 Webhook 模式) */}
|
||||
{webhookUrl && (
|
||||
<FormItem>
|
||||
<FormLabel>{t('bots.webhookUrl')}</FormLabel>
|
||||
<div className="flex items-center gap-2">
|
||||
<Input
|
||||
ref={webhookInputRef}
|
||||
value={webhookUrl}
|
||||
readOnly
|
||||
className="flex-1 bg-gray-50 dark:bg-gray-900"
|
||||
onClick={(e) => {
|
||||
// 点击输入框时自动全选
|
||||
(e.target as HTMLInputElement).select();
|
||||
}}
|
||||
/>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={copyToClipboard}
|
||||
>
|
||||
{t('common.copy')}
|
||||
</Button>
|
||||
</div>
|
||||
<p className="text-sm text-gray-500 mt-1">
|
||||
{t('bots.webhookUrlHint')}
|
||||
</p>
|
||||
</FormItem>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
|
||||
<FormField
|
||||
|
||||
@@ -44,12 +44,35 @@ export function BotLogCard({ botLog }: { botLog: BotLog }) {
|
||||
const strArr = str.split('');
|
||||
return strArr;
|
||||
}
|
||||
|
||||
// 根据日志级别返回对应的样式类
|
||||
function getLevelStyles(level: string) {
|
||||
switch (level.toLowerCase()) {
|
||||
case 'error':
|
||||
return 'bg-red-100 text-red-800 dark:bg-red-900/30 dark:text-red-400';
|
||||
case 'warning':
|
||||
return 'bg-orange-100 text-orange-800 dark:bg-orange-900/30 dark:text-orange-400';
|
||||
case 'info':
|
||||
return 'bg-blue-100 text-blue-800 dark:bg-blue-900/30 dark:text-blue-400';
|
||||
case 'debug':
|
||||
return 'bg-gray-100 text-gray-800 dark:bg-gray-900/30 dark:text-gray-400';
|
||||
default:
|
||||
return 'bg-gray-100 text-gray-800 dark:bg-gray-900/30 dark:text-gray-400';
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={`${styles.botLogCardContainer}`}>
|
||||
{/* 头部标签,时间 */}
|
||||
<div className={`${styles.cardTitleContainer}`}>
|
||||
<div className={`flex flex-row gap-4`}>
|
||||
<div className={`${styles.tag}`}>{botLog.level}</div>
|
||||
<div className={`flex flex-row gap-2 items-center`}>
|
||||
<div
|
||||
className={`px-2 py-1 rounded text-xs font-medium uppercase ${getLevelStyles(
|
||||
botLog.level,
|
||||
)}`}
|
||||
>
|
||||
{botLog.level}
|
||||
</div>
|
||||
{botLog.message_session_id && (
|
||||
<div
|
||||
className={`${styles.tag} ${styles.chatTag}`}
|
||||
@@ -60,6 +83,7 @@ export function BotLogCard({ botLog }: { botLog: BotLog }) {
|
||||
toast.success(t('common.copySuccess'));
|
||||
});
|
||||
}}
|
||||
title={t('common.clickToCopy')}
|
||||
>
|
||||
<svg
|
||||
className="icon"
|
||||
@@ -67,8 +91,8 @@ export function BotLogCard({ botLog }: { botLog: BotLog }) {
|
||||
version="1.1"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
p-id="1664"
|
||||
width="20"
|
||||
height="20"
|
||||
width="16"
|
||||
height="16"
|
||||
fill="currentColor"
|
||||
>
|
||||
<path
|
||||
@@ -87,7 +111,6 @@ export function BotLogCard({ botLog }: { botLog: BotLog }) {
|
||||
fill="currentColor"
|
||||
></path>
|
||||
</svg>
|
||||
{/* 会话ID */}
|
||||
|
||||
<span className={`${styles.chatId}`}>
|
||||
{getSubChatId(botLog.message_session_id)}
|
||||
@@ -95,22 +118,25 @@ export function BotLogCard({ botLog }: { botLog: BotLog }) {
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<div>{formatTime(botLog.timestamp)}</div>
|
||||
</div>
|
||||
<div className={`${styles.cardTitleContainer} ${styles.cardText}`}>
|
||||
{botLog.text}
|
||||
</div>
|
||||
<PhotoProvider className={``}>
|
||||
<div className={`w-50 mt-2`}>
|
||||
{botLog.images.map((item) => (
|
||||
<img
|
||||
key={item}
|
||||
src={`${baseURL}/api/v1/files/image/${item}`}
|
||||
alt=""
|
||||
/>
|
||||
))}
|
||||
<div className={`${styles.timestamp}`}>
|
||||
{formatTime(botLog.timestamp)}
|
||||
</div>
|
||||
</PhotoProvider>
|
||||
</div>
|
||||
<div className={`${styles.cardText}`}>{botLog.text}</div>
|
||||
{botLog.images.length > 0 && (
|
||||
<PhotoProvider>
|
||||
<div className={`flex flex-wrap gap-2 mt-3`}>
|
||||
{botLog.images.map((item) => (
|
||||
<img
|
||||
key={item}
|
||||
src={`${baseURL}/api/v1/files/image/${item}`}
|
||||
alt=""
|
||||
className="max-w-xs rounded cursor-pointer hover:opacity-90 transition-opacity"
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</PhotoProvider>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,11 +1,19 @@
|
||||
'use client';
|
||||
|
||||
import { BotLogManager } from '@/app/home/bots/components/bot-log/BotLogManager';
|
||||
import { useCallback, useEffect, useRef, useState } from 'react';
|
||||
import { useCallback, useEffect, useRef, useState, useMemo } from 'react';
|
||||
import { BotLog } from '@/app/infra/http/requestParam/bots/GetBotLogsResponse';
|
||||
import { BotLogCard } from '@/app/home/bots/components/bot-log/view/BotLogCard';
|
||||
import styles from './botLog.module.css';
|
||||
import { Switch } from '@/components/ui/switch';
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from '@/components/ui/popover';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { Checkbox } from '@/components/ui/checkbox';
|
||||
import { ChevronDownIcon } from 'lucide-react';
|
||||
import { debounce } from 'lodash';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
@@ -14,9 +22,21 @@ export function BotLogListComponent({ botId }: { botId: string }) {
|
||||
const manager = useRef(new BotLogManager(botId)).current;
|
||||
const [botLogList, setBotLogList] = useState<BotLog[]>([]);
|
||||
const [autoFlush, setAutoFlush] = useState(true);
|
||||
const [selectedLevels, setSelectedLevels] = useState<string[]>([
|
||||
'info',
|
||||
'warning',
|
||||
'error',
|
||||
]);
|
||||
const listContainerRef = useRef<HTMLDivElement>(null);
|
||||
const botLogListRef = useRef<BotLog[]>(botLogList);
|
||||
|
||||
const logLevels = [
|
||||
{ value: 'error', label: 'ERROR' },
|
||||
{ value: 'warning', label: 'WARNING' },
|
||||
{ value: 'info', label: 'INFO' },
|
||||
{ value: 'debug', label: 'DEBUG' },
|
||||
];
|
||||
|
||||
useEffect(() => {
|
||||
initComponent();
|
||||
return () => {
|
||||
@@ -28,6 +48,42 @@ export function BotLogListComponent({ botId }: { botId: string }) {
|
||||
botLogListRef.current = botLogList;
|
||||
}, [botLogList]);
|
||||
|
||||
// 根据级别过滤日志
|
||||
const filteredLogs = useMemo(() => {
|
||||
if (selectedLevels.length === 0) {
|
||||
return botLogList;
|
||||
}
|
||||
return botLogList.filter((log) => selectedLevels.includes(log.level));
|
||||
}, [botLogList, selectedLevels]);
|
||||
|
||||
const handleLevelToggle = (levelValue: string) => {
|
||||
setSelectedLevels((prev) => {
|
||||
if (prev.includes(levelValue)) {
|
||||
return prev.filter((l) => l !== levelValue);
|
||||
} else {
|
||||
return [...prev, levelValue];
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
const getDisplayText = () => {
|
||||
if (selectedLevels.length === 0) {
|
||||
return t('bots.selectLevel');
|
||||
}
|
||||
if (selectedLevels.length === logLevels.length) {
|
||||
return t('bots.allLevels');
|
||||
}
|
||||
// 如果选中3个或以上,显示数量
|
||||
if (selectedLevels.length >= 3) {
|
||||
return `${selectedLevels.length} ${t('bots.levelsSelected')}`;
|
||||
}
|
||||
// 显示选中级别的标签(大写形式)
|
||||
return logLevels
|
||||
.filter((level) => selectedLevels.includes(level.value))
|
||||
.map((level) => level.label)
|
||||
.join(', ');
|
||||
};
|
||||
|
||||
// 观测自动刷新状态
|
||||
useEffect(() => {
|
||||
if (autoFlush) {
|
||||
@@ -116,9 +172,43 @@ export function BotLogListComponent({ botId }: { botId: string }) {
|
||||
<div className={`${styles.listHeader}`}>
|
||||
<div className={'mr-2'}>{t('bots.enableAutoRefresh')}</div>
|
||||
<Switch checked={autoFlush} onCheckedChange={(e) => setAutoFlush(e)} />
|
||||
<div className={'ml-4 mr-2'}>{t('bots.logLevel')}</div>
|
||||
<Popover>
|
||||
<PopoverTrigger asChild>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
className="w-[180px] flex items-center justify-between"
|
||||
>
|
||||
<span className="text-sm truncate flex-1 text-left">
|
||||
{getDisplayText()}
|
||||
</span>
|
||||
<ChevronDownIcon className="ml-2 h-4 w-4 flex-shrink-0" />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent className="w-[180px] p-2">
|
||||
<div className="flex flex-col gap-2">
|
||||
{logLevels.map((level) => (
|
||||
<div key={level.value} className="flex items-center space-x-2">
|
||||
<Checkbox
|
||||
id={level.value}
|
||||
checked={selectedLevels.includes(level.value)}
|
||||
onCheckedChange={() => handleLevelToggle(level.value)}
|
||||
/>
|
||||
<label
|
||||
htmlFor={level.value}
|
||||
className="text-sm font-medium leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70 cursor-pointer"
|
||||
>
|
||||
{level.label}
|
||||
</label>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
</div>
|
||||
|
||||
{botLogList.map((botLog) => {
|
||||
{filteredLogs.map((botLog) => {
|
||||
return <BotLogCard botLog={botLog} key={botLog.seq_id} />;
|
||||
})}
|
||||
</div>
|
||||
|
||||
@@ -1,21 +1,32 @@
|
||||
.botLogListContainer {
|
||||
width: 100%;
|
||||
max-width: 100%;
|
||||
min-height: 10rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: flex-start;
|
||||
overflow-y: scroll;
|
||||
overflow-y: auto;
|
||||
overflow-x: hidden;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
.botLogCardContainer {
|
||||
width: 100%;
|
||||
max-width: 100%;
|
||||
background-color: #fff;
|
||||
border-radius: 10px;
|
||||
border: 1px solid #cbd5e1;
|
||||
padding: 1.2rem;
|
||||
margin-bottom: 1rem;
|
||||
cursor: pointer;
|
||||
border-radius: 8px;
|
||||
border: 1px solid #e2e8f0;
|
||||
padding: 1rem;
|
||||
margin-bottom: 0.75rem;
|
||||
transition: all 0.2s ease;
|
||||
overflow: hidden;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
.botLogCardContainer:hover {
|
||||
border-color: #cbd5e1;
|
||||
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.05);
|
||||
}
|
||||
|
||||
:global(.dark) .botLogCardContainer {
|
||||
@@ -23,40 +34,74 @@
|
||||
border: 1px solid #2a2a2e;
|
||||
}
|
||||
|
||||
:global(.dark) .botLogCardContainer:hover {
|
||||
border-color: #3a3a3e;
|
||||
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.3);
|
||||
}
|
||||
|
||||
.listHeader {
|
||||
width: 100%;
|
||||
height: 2.5rem;
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
align-items: center;
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
|
||||
.tag {
|
||||
display: flex;
|
||||
display: inline-flex;
|
||||
flex-direction: row;
|
||||
align-items: center;
|
||||
justify-content: flex-start;
|
||||
gap: 0.2rem;
|
||||
height: 1.5rem;
|
||||
padding: 0.5rem;
|
||||
border-radius: 0.4rem;
|
||||
background-color: #a5d8ff;
|
||||
color: #ffffff;
|
||||
justify-content: center;
|
||||
gap: 0.25rem;
|
||||
height: auto;
|
||||
padding: 0.25rem 0.5rem;
|
||||
border-radius: 4px;
|
||||
background-color: #dbeafe;
|
||||
color: #1e40af;
|
||||
font-size: 0.75rem;
|
||||
font-weight: 500;
|
||||
max-width: 16rem;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.025em;
|
||||
}
|
||||
|
||||
:global(.dark) .tag {
|
||||
background-color: #1e3a8a;
|
||||
color: #93c5fd;
|
||||
}
|
||||
|
||||
.chatTag {
|
||||
color: #626262;
|
||||
background-color: #d1d1d1;
|
||||
color: #4b5563;
|
||||
background-color: #f3f4f6;
|
||||
text-transform: none;
|
||||
cursor: pointer;
|
||||
transition: all 0.15s ease;
|
||||
}
|
||||
|
||||
.chatTag:hover {
|
||||
background-color: #e5e7eb;
|
||||
}
|
||||
|
||||
:global(.dark) .chatTag {
|
||||
color: #9ca3af;
|
||||
background-color: #374151;
|
||||
}
|
||||
|
||||
:global(.dark) .chatTag:hover {
|
||||
background-color: #4b5563;
|
||||
}
|
||||
|
||||
.chatId {
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
font-family: 'SF Mono', Monaco, 'Cascadia Code', 'Roboto Mono', Consolas,
|
||||
'Courier New', monospace;
|
||||
font-size: 0.7rem;
|
||||
}
|
||||
|
||||
.cardTitleContainer {
|
||||
@@ -65,9 +110,33 @@
|
||||
flex-direction: row;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
|
||||
.cardText {
|
||||
margin-top: 0.4rem;
|
||||
color: #1e293b;
|
||||
font-size: 0.875rem;
|
||||
line-height: 1.7;
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
word-break: break-all;
|
||||
overflow-wrap: anywhere;
|
||||
hyphens: auto;
|
||||
max-width: 100%;
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen',
|
||||
'Ubuntu', 'Cantarell', sans-serif;
|
||||
}
|
||||
|
||||
:global(.dark) .cardText {
|
||||
color: #e2e8f0;
|
||||
}
|
||||
|
||||
.timestamp {
|
||||
color: #64748b;
|
||||
font-size: 0.75rem;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
:global(.dark) .timestamp {
|
||||
color: #64748b;
|
||||
}
|
||||
|
||||
@@ -509,6 +509,17 @@ export default function ExternalKBForm({
|
||||
</Select>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
<p className="text-sm text-muted-foreground">
|
||||
{t('knowledge.retrieverInstallInfo')}{' '}
|
||||
<a
|
||||
href="https://space.langbot.app/market?category=KnowledgeRetriever"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="text-primary underline hover:no-underline"
|
||||
>
|
||||
{t('knowledge.retrieverMarketLink')}
|
||||
</a>
|
||||
</p>
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
|
||||
@@ -154,7 +154,7 @@ export default function PipelineDialog({
|
||||
// 编辑流水线时的对话框
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
<DialogContent className="overflow-hidden p-0 !max-w-[50rem] h-[75vh] flex">
|
||||
<DialogContent className="overflow-hidden p-0 !max-w-[80vw] h-[75vh] flex">
|
||||
<SidebarProvider className="items-start w-full flex h-full min-h-0">
|
||||
<Sidebar
|
||||
collapsible="none"
|
||||
|
||||
@@ -20,6 +20,13 @@ import { toast } from 'sonner';
|
||||
import AtBadge from './AtBadge';
|
||||
import { WebSocketClient } from '@/app/infra/websocket/WebSocketClient';
|
||||
import ImagePreviewDialog from './ImagePreviewDialog';
|
||||
import ReactMarkdown from 'react-markdown';
|
||||
import remarkGfm from 'remark-gfm';
|
||||
import rehypeHighlight from 'rehype-highlight';
|
||||
import rehypeRaw from 'rehype-raw';
|
||||
import rehypeSlug from 'rehype-slug';
|
||||
import rehypeAutolinkHeadings from 'rehype-autolink-headings';
|
||||
import '@/styles/github-markdown.css';
|
||||
|
||||
interface DebugDialogProps {
|
||||
open: boolean;
|
||||
@@ -50,7 +57,9 @@ export default function DebugDialog({
|
||||
const [previewImageUrl, setPreviewImageUrl] = useState<string>('');
|
||||
const [showImagePreview, setShowImagePreview] = useState(false);
|
||||
const [quotedMessage, setQuotedMessage] = useState<Message | null>(null);
|
||||
const [hoveredMessageId, setHoveredMessageId] = useState<number | null>(null);
|
||||
const [rawModeMessages, setRawModeMessages] = useState<Set<string>>(
|
||||
new Set(),
|
||||
);
|
||||
const messagesEndRef = useRef<HTMLDivElement>(null);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const popoverRef = useRef<HTMLDivElement>(null);
|
||||
@@ -195,6 +204,8 @@ export default function DebugDialog({
|
||||
// 监听 sessionType 和 selectedPipelineId 变化,重新加载消息和连接
|
||||
useEffect(() => {
|
||||
if (open) {
|
||||
// 清空当前消息,避免显示旧的消息
|
||||
setMessages([]);
|
||||
loadMessages(selectedPipelineId);
|
||||
initWebSocket(selectedPipelineId);
|
||||
}
|
||||
@@ -554,7 +565,111 @@ export default function DebugDialog({
|
||||
return t('bots.earlier');
|
||||
};
|
||||
|
||||
// Generate a unique key for a message
|
||||
const getMessageKey = (message: Message): string => {
|
||||
return `${message.id}-${message.timestamp}`;
|
||||
};
|
||||
|
||||
// Toggle raw mode for a message (by default, messages are in markdown mode)
|
||||
const toggleRawMode = (message: Message) => {
|
||||
const key = getMessageKey(message);
|
||||
setRawModeMessages((prev) => {
|
||||
const newSet = new Set(prev);
|
||||
if (newSet.has(key)) {
|
||||
newSet.delete(key);
|
||||
} else {
|
||||
newSet.add(key);
|
||||
}
|
||||
return newSet;
|
||||
});
|
||||
};
|
||||
|
||||
// Check if message has any Plain text content
|
||||
const hasPlainText = (message: Message): boolean => {
|
||||
return message.message_chain.some((c) => c.type === 'Plain');
|
||||
};
|
||||
|
||||
// Extract plain text from message chain
|
||||
const getPlainText = (message: Message): string => {
|
||||
return message.message_chain
|
||||
.filter((c) => c.type === 'Plain')
|
||||
.map((c) => (c as Plain).text)
|
||||
.join('');
|
||||
};
|
||||
|
||||
const renderMessageContent = (message: Message) => {
|
||||
const key = getMessageKey(message);
|
||||
const isRawMode = rawModeMessages.has(key);
|
||||
|
||||
// By default, render with markdown if there's plain text (unless raw mode is enabled)
|
||||
if (!isRawMode && hasPlainText(message)) {
|
||||
const plainText = getPlainText(message);
|
||||
const nonPlainComponents = message.message_chain.filter(
|
||||
(c) => c.type !== 'Plain' && c.type !== 'Source',
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="text-base leading-relaxed align-middle">
|
||||
{/* Render non-Plain components first */}
|
||||
{nonPlainComponents.map((component, index) =>
|
||||
renderMessageComponent(component, index),
|
||||
)}
|
||||
{/* Render Plain text as markdown */}
|
||||
<div className="markdown-body">
|
||||
<ReactMarkdown
|
||||
remarkPlugins={[remarkGfm]}
|
||||
rehypePlugins={[
|
||||
rehypeRaw,
|
||||
rehypeHighlight,
|
||||
rehypeSlug,
|
||||
[
|
||||
rehypeAutolinkHeadings,
|
||||
{
|
||||
behavior: 'wrap',
|
||||
properties: {
|
||||
className: ['anchor'],
|
||||
},
|
||||
},
|
||||
],
|
||||
]}
|
||||
components={{
|
||||
ul: ({ children }) => <ul className="list-disc">{children}</ul>,
|
||||
ol: ({ children }) => (
|
||||
<ol className="list-decimal">{children}</ol>
|
||||
),
|
||||
li: ({ children }) => <li className="ml-4">{children}</li>,
|
||||
img: ({ src, alt, ...props }) => {
|
||||
const imageSrc = src || '';
|
||||
|
||||
if (typeof imageSrc !== 'string') {
|
||||
return (
|
||||
<img
|
||||
src={src}
|
||||
alt={alt || ''}
|
||||
className="max-w-full h-auto rounded-lg my-4"
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<img
|
||||
src={imageSrc}
|
||||
alt={alt || ''}
|
||||
className="max-w-lg h-auto my-4"
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
},
|
||||
}}
|
||||
>
|
||||
{plainText}
|
||||
</ReactMarkdown>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="text-base leading-relaxed align-middle whitespace-pre-wrap">
|
||||
{message.message_chain.map((component, index) =>
|
||||
@@ -619,68 +734,109 @@ export default function DebugDialog({
|
||||
<div
|
||||
key={message.id + message.timestamp}
|
||||
className={cn(
|
||||
'flex group',
|
||||
'flex',
|
||||
message.role === 'user' ? 'justify-end' : 'justify-start',
|
||||
)}
|
||||
onMouseEnter={() => setHoveredMessageId(message.id)}
|
||||
onMouseLeave={() => setHoveredMessageId(null)}
|
||||
>
|
||||
<div
|
||||
className={cn(
|
||||
'relative flex items-end gap-2',
|
||||
message.role === 'user' ? 'flex-row-reverse' : 'flex-row',
|
||||
'max-w-3xl px-5 py-3 rounded-2xl',
|
||||
message.role === 'user'
|
||||
? 'user-message-bubble bg-blue-100 dark:bg-blue-900 text-gray-900 dark:text-gray-100 rounded-br-none'
|
||||
: 'bg-gray-100 dark:bg-gray-800 text-gray-900 dark:text-gray-100 rounded-bl-none',
|
||||
)}
|
||||
>
|
||||
{renderMessageContent(message)}
|
||||
<div
|
||||
className={cn(
|
||||
'max-w-md px-5 py-3 rounded-2xl',
|
||||
'text-xs mt-2 flex items-center justify-between gap-2',
|
||||
message.role === 'user'
|
||||
? 'bg-[#2288ee] text-white rounded-br-none'
|
||||
: 'bg-gray-100 dark:bg-gray-800 text-gray-900 dark:text-gray-100 rounded-bl-none',
|
||||
? 'text-gray-600 dark:text-gray-300'
|
||||
: 'text-gray-500 dark:text-gray-400',
|
||||
)}
|
||||
>
|
||||
{renderMessageContent(message)}
|
||||
<div
|
||||
className={cn(
|
||||
'text-xs mt-2 flex items-center justify-between gap-2',
|
||||
message.role === 'user'
|
||||
? 'text-white/70'
|
||||
: 'text-gray-500 dark:text-gray-400',
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<span>
|
||||
{message.role === 'user'
|
||||
? t('pipelines.debugDialog.userMessage')
|
||||
: t('pipelines.debugDialog.botMessage')}
|
||||
</span>
|
||||
<span className="text-[10px]">
|
||||
{formatTimestamp(getMessageTimestamp(message))}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
{hoveredMessageId === message.id && (
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
className="h-6 px-2 text-xs opacity-0 group-hover:opacity-100 transition-opacity bg-gray-200 dark:bg-gray-700 hover:bg-gray-300 dark:hover:bg-gray-600 whitespace-nowrap"
|
||||
onClick={() => setQuotedMessage(message)}
|
||||
>
|
||||
<svg
|
||||
className="w-3 h-3 mr-1"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
{hasPlainText(message) && (
|
||||
<button
|
||||
onClick={() => toggleRawMode(message)}
|
||||
className={cn(
|
||||
'px-1.5 py-0.5 rounded text-[10px] transition-colors',
|
||||
message.role === 'user'
|
||||
? 'hover:bg-blue-200 dark:hover:bg-blue-800'
|
||||
: 'hover:bg-gray-200 dark:hover:bg-gray-700',
|
||||
)}
|
||||
title={
|
||||
rawModeMessages.has(getMessageKey(message))
|
||||
? t('pipelines.debugDialog.showMarkdown')
|
||||
: t('pipelines.debugDialog.showRaw')
|
||||
}
|
||||
>
|
||||
{rawModeMessages.has(getMessageKey(message)) ? (
|
||||
<span className="flex items-center gap-0.5">
|
||||
<svg
|
||||
className="w-3 h-3"
|
||||
viewBox="0 0 16 16"
|
||||
fill="currentColor"
|
||||
>
|
||||
<path d="M14.85 3H1.15C.52 3 0 3.52 0 4.15v7.69C0 12.48.52 13 1.15 13h13.69c.64 0 1.15-.52 1.15-1.15v-7.7C16 3.52 15.48 3 14.85 3zM9 11H7V8L5.5 9.92 4 8v3H2V5h2l1.5 2L7 5h2v6zm2.99.5L9.5 8H11V5h2v3h1.5l-2.51 3.5z" />
|
||||
</svg>
|
||||
MD
|
||||
</span>
|
||||
) : (
|
||||
<span className="flex items-center gap-0.5">
|
||||
<svg
|
||||
className="w-3 h-3"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
strokeWidth={2}
|
||||
d="M4 6h16M4 12h16M4 18h7"
|
||||
/>
|
||||
</svg>
|
||||
{t('pipelines.debugDialog.showRaw')}
|
||||
</span>
|
||||
)}
|
||||
</button>
|
||||
)}
|
||||
<button
|
||||
onClick={() => setQuotedMessage(message)}
|
||||
className={cn(
|
||||
'px-1.5 py-0.5 rounded text-[10px] transition-colors flex items-center gap-0.5',
|
||||
message.role === 'user'
|
||||
? 'hover:bg-blue-200 dark:hover:bg-blue-800'
|
||||
: 'hover:bg-gray-200 dark:hover:bg-gray-700',
|
||||
)}
|
||||
title={t('pipelines.debugDialog.reply')}
|
||||
>
|
||||
<path
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
strokeWidth={2}
|
||||
d="M3 10h10a8 8 0 018 8v2M3 10l6 6m-6-6l6-6"
|
||||
/>
|
||||
</svg>
|
||||
{t('pipelines.debugDialog.reply')}
|
||||
</Button>
|
||||
)}
|
||||
<svg
|
||||
className="w-3 h-3"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
strokeWidth={2}
|
||||
d="M3 10h10a8 8 0 018 8v2M3 10l6 6m-6-6l6-6"
|
||||
/>
|
||||
</svg>
|
||||
{t('pipelines.debugDialog.reply')}
|
||||
</button>
|
||||
</div>
|
||||
<span className="text-[10px]">
|
||||
{formatTimestamp(getMessageTimestamp(message))}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))
|
||||
|
||||
@@ -8,7 +8,7 @@ import rehypeHighlight from 'rehype-highlight';
|
||||
import rehypeSlug from 'rehype-slug';
|
||||
import rehypeAutolinkHeadings from 'rehype-autolink-headings';
|
||||
import { getAPILanguageCode } from '@/i18n/I18nProvider';
|
||||
import './github-markdown.css';
|
||||
import '@/styles/github-markdown.css';
|
||||
|
||||
export default function PluginReadme({
|
||||
pluginAuthor,
|
||||
|
||||
@@ -24,6 +24,9 @@ import {
|
||||
Power,
|
||||
Github,
|
||||
ChevronLeft,
|
||||
Code,
|
||||
Copy,
|
||||
Bug,
|
||||
} from 'lucide-react';
|
||||
import {
|
||||
DropdownMenu,
|
||||
@@ -38,6 +41,11 @@ import {
|
||||
DialogTitle,
|
||||
DialogFooter,
|
||||
} from '@/components/ui/dialog';
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from '@/components/ui/popover';
|
||||
import { Input } from '@/components/ui/input';
|
||||
import React, { useState, useRef, useCallback, useEffect } from 'react';
|
||||
import { httpClient } from '@/app/infra/http/HttpClient';
|
||||
@@ -105,6 +113,11 @@ export default function PluginConfigPage() {
|
||||
);
|
||||
const [isEditMode, setIsEditMode] = useState(false);
|
||||
const [refreshKey, setRefreshKey] = useState(0);
|
||||
const [debugInfo, setDebugInfo] = useState<{
|
||||
debug_url: string;
|
||||
plugin_debug_key: string;
|
||||
} | null>(null);
|
||||
const [debugPopoverOpen, setDebugPopoverOpen] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
const fetchPluginSystemStatus = async () => {
|
||||
@@ -374,6 +387,22 @@ export default function PluginConfigPage() {
|
||||
[uploadPluginFile, isPluginSystemReady, t],
|
||||
);
|
||||
|
||||
const handleShowDebugInfo = async () => {
|
||||
try {
|
||||
const info = await httpClient.getPluginDebugInfo();
|
||||
setDebugInfo(info);
|
||||
setDebugPopoverOpen(true);
|
||||
} catch (error) {
|
||||
console.error('Failed to fetch debug info:', error);
|
||||
toast.error(t('plugins.failedToGetDebugInfo'));
|
||||
}
|
||||
};
|
||||
|
||||
const handleCopyDebugInfo = (text: string) => {
|
||||
navigator.clipboard.writeText(text);
|
||||
toast.success(t('plugins.copiedToClipboard'));
|
||||
};
|
||||
|
||||
const renderPluginDisabledState = () => (
|
||||
<div className="flex flex-col items-center justify-center h-[60vh] text-center pt-[10vh]">
|
||||
<Power className="w-16 h-16 text-gray-400 mb-4" />
|
||||
@@ -466,7 +495,92 @@ export default function PluginConfigPage() {
|
||||
</TabsTrigger>
|
||||
</TabsList>
|
||||
|
||||
<div className="flex flex-row justify-end items-center">
|
||||
<div className="flex flex-row justify-end items-center gap-2">
|
||||
{activeTab === 'installed' && (
|
||||
<Popover
|
||||
open={debugPopoverOpen}
|
||||
onOpenChange={setDebugPopoverOpen}
|
||||
>
|
||||
<PopoverTrigger asChild>
|
||||
<Button
|
||||
variant="outline"
|
||||
className="px-4 py-5 cursor-pointer"
|
||||
onClick={handleShowDebugInfo}
|
||||
>
|
||||
<Code className="w-4 h-4 mr-2" />
|
||||
{t('plugins.debugInfo')}
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent className="w-[380px]" align="end">
|
||||
<div className="space-y-3">
|
||||
{/* Header with icon and title */}
|
||||
<div className="flex items-center gap-2 pb-2 border-b">
|
||||
<Bug className="w-4 h-4" />
|
||||
<h4 className="font-semibold text-sm">
|
||||
{t('plugins.debugInfoTitle')}
|
||||
</h4>
|
||||
</div>
|
||||
|
||||
{/* Debug URL row */}
|
||||
<div className="flex items-center gap-2">
|
||||
<label className="text-sm font-medium whitespace-nowrap min-w-[50px]">
|
||||
{t('plugins.debugUrl')}:
|
||||
</label>
|
||||
<Input
|
||||
value={debugInfo?.debug_url || ''}
|
||||
readOnly
|
||||
className="w-[220px] font-mono text-xs h-8"
|
||||
/>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className="h-8 w-8 shrink-0"
|
||||
onClick={() =>
|
||||
handleCopyDebugInfo(debugInfo?.debug_url || '')
|
||||
}
|
||||
>
|
||||
<Copy className="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{/* Debug Key row */}
|
||||
<div className="space-y-1">
|
||||
<div className="flex items-center gap-2">
|
||||
<label className="text-sm font-medium whitespace-nowrap min-w-[50px]">
|
||||
{t('plugins.debugKey')}:
|
||||
</label>
|
||||
<Input
|
||||
value={
|
||||
debugInfo?.plugin_debug_key ||
|
||||
t('plugins.noDebugKey')
|
||||
}
|
||||
readOnly
|
||||
className="w-[220px] font-mono text-xs h-8"
|
||||
/>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className="h-8 w-8 shrink-0"
|
||||
onClick={() =>
|
||||
handleCopyDebugInfo(
|
||||
debugInfo?.plugin_debug_key || '',
|
||||
)
|
||||
}
|
||||
disabled={!debugInfo?.plugin_debug_key}
|
||||
>
|
||||
<Copy className="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
</div>
|
||||
{!debugInfo?.plugin_debug_key && (
|
||||
<p className="text-xs text-muted-foreground ml-[58px]">
|
||||
{t('plugins.debugKeyDisabled')}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
)}
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<Button variant="default" className="px-6 py-4 cursor-pointer">
|
||||
|
||||
@@ -63,6 +63,7 @@ export interface KnowledgeBase {
|
||||
description: string;
|
||||
embedding_model_uuid: string;
|
||||
created_at?: string;
|
||||
updated_at?: string;
|
||||
top_k: number;
|
||||
}
|
||||
|
||||
@@ -142,6 +143,7 @@ export interface Bot {
|
||||
use_pipeline_uuid?: string;
|
||||
created_at?: string;
|
||||
updated_at?: string;
|
||||
adapter_runtime_values?: object;
|
||||
}
|
||||
|
||||
export interface ApiRespKnowledgeBases {
|
||||
|
||||
@@ -639,6 +639,13 @@ export class BackendClient extends BaseHttpClient {
|
||||
return this.get('/api/v1/system/status/plugin-system');
|
||||
}
|
||||
|
||||
public getPluginDebugInfo(): Promise<{
|
||||
debug_url: string;
|
||||
plugin_debug_key: string;
|
||||
}> {
|
||||
return this.get('/api/v1/plugins/debug-info');
|
||||
}
|
||||
|
||||
// ============ User API ============
|
||||
public checkIfInited(): Promise<{ initialized: boolean }> {
|
||||
return this.get('/api/v1/user/init');
|
||||
|
||||
@@ -79,8 +79,10 @@ export class WebSocketClient {
|
||||
// 构建WebSocket URL
|
||||
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
||||
// extract host from process.env.NEXT_PUBLIC_API_BASE_URL
|
||||
// 如果环境变量未定义,使用当前页面的 host (适配生产环境)
|
||||
const host =
|
||||
process.env.NEXT_PUBLIC_API_BASE_URL?.split('://')[1] || '';
|
||||
process.env.NEXT_PUBLIC_API_BASE_URL?.split('://')[1] ||
|
||||
window.location.host;
|
||||
const url = `${protocol}//${host}/api/v1/pipelines/${this.pipelineId}/ws/connect?session_type=${this.sessionType}`;
|
||||
|
||||
this.ws = new WebSocket(url);
|
||||
@@ -149,12 +151,28 @@ export class WebSocketClient {
|
||||
break;
|
||||
|
||||
case 'response':
|
||||
// 检查 session_type 是否匹配 - 如果消息没有 session_type 或者不匹配当前session,都忽略
|
||||
if (!data.session_type || data.session_type !== this.sessionType) {
|
||||
// 忽略不匹配的 session_type 消息
|
||||
console.debug(
|
||||
`忽略不匹配的消息: 当前session=${this.sessionType}, 消息session=${data.session_type}`,
|
||||
);
|
||||
break;
|
||||
}
|
||||
if (data.data) {
|
||||
this.onMessageCallback?.(data.data);
|
||||
}
|
||||
break;
|
||||
|
||||
case 'user_message':
|
||||
// 检查 session_type 是否匹配 - 如果消息没有 session_type 或者不匹配当前session,都忽略
|
||||
if (!data.session_type || data.session_type !== this.sessionType) {
|
||||
// 忽略不匹配的 session_type 消息
|
||||
console.debug(
|
||||
`忽略不匹配的用户消息: 当前session=${this.sessionType}, 消息session=${data.session_type}`,
|
||||
);
|
||||
break;
|
||||
}
|
||||
// 用户消息广播(包括自己发送的消息)
|
||||
if (data.data) {
|
||||
this.onMessageCallback?.(data.data);
|
||||
|
||||
@@ -41,6 +41,7 @@ const enUS = {
|
||||
addRound: 'Add Round',
|
||||
copy: 'Copy',
|
||||
copySuccess: 'Copy Successfully',
|
||||
copyFailed: 'Copy Failed',
|
||||
test: 'Test',
|
||||
forgotPassword: 'Forgot Password?',
|
||||
loading: 'Loading...',
|
||||
@@ -187,6 +188,14 @@ const enUS = {
|
||||
log: 'Log',
|
||||
configuration: 'Configuration',
|
||||
logs: 'Logs',
|
||||
webhookUrl: 'Webhook Callback URL',
|
||||
webhookUrlCopied: 'Webhook URL copied',
|
||||
webhookUrlHint:
|
||||
'Click the input to select all, then press Ctrl+C (Mac: Cmd+C) to copy, or click the button',
|
||||
logLevel: 'Log Level',
|
||||
allLevels: 'All Levels',
|
||||
selectLevel: 'Select Level',
|
||||
levelsSelected: 'levels selected',
|
||||
},
|
||||
plugins: {
|
||||
title: 'Extensions',
|
||||
@@ -231,6 +240,15 @@ const enUS = {
|
||||
failedToGetStatus: 'Failed to get plugin system status',
|
||||
pluginSystemNotReady:
|
||||
'Plugin system is not ready, cannot perform this operation',
|
||||
debugInfo: 'Debug Info',
|
||||
debugInfoTitle: 'Plugin Debug Information',
|
||||
debugUrl: 'Debug URL',
|
||||
debugKey: 'Debug Key',
|
||||
noDebugKey: '(Not Set)',
|
||||
debugKeyDisabled:
|
||||
'Debug key is not set, plugin debugging does not require authentication',
|
||||
failedToGetDebugInfo: 'Failed to get debug information',
|
||||
copiedToClipboard: 'Copied to clipboard',
|
||||
deleting: 'Deleting...',
|
||||
deletePlugin: 'Delete Plugin',
|
||||
cancel: 'Cancel',
|
||||
@@ -525,6 +543,8 @@ const enUS = {
|
||||
imageUploadFailed: 'Image upload failed',
|
||||
reply: 'Reply',
|
||||
replyTo: 'Reply to',
|
||||
showMarkdown: 'Show Markdown',
|
||||
showRaw: 'Show Raw',
|
||||
},
|
||||
},
|
||||
knowledge: {
|
||||
@@ -603,6 +623,8 @@ const enUS = {
|
||||
retriever: 'Retriever',
|
||||
selectRetriever: 'Select a retriever...',
|
||||
retrieverConfiguration: 'Retriever Configuration',
|
||||
retrieverInstallInfo: 'You can install Knowledge Retriever plugins from',
|
||||
retrieverMarketLink: 'here',
|
||||
},
|
||||
register: {
|
||||
title: 'Initialize LangBot 👋',
|
||||
|
||||
@@ -42,6 +42,7 @@ const jaJP = {
|
||||
addRound: 'ラウンドを追加',
|
||||
copy: 'コピー',
|
||||
copySuccess: 'コピーに成功しました',
|
||||
copyFailed: 'コピーに失敗しました',
|
||||
test: 'テスト',
|
||||
forgotPassword: 'パスワードを忘れた?',
|
||||
loading: '読み込み中...',
|
||||
@@ -189,6 +190,14 @@ const jaJP = {
|
||||
log: 'ログ',
|
||||
configuration: '設定',
|
||||
logs: 'ログ',
|
||||
webhookUrl: 'Webhook コールバック URL',
|
||||
webhookUrlCopied: 'Webhook URL をコピーしました',
|
||||
webhookUrlHint:
|
||||
'入力ボックスをクリックして全選択し、Ctrl+C (Mac: Cmd+C) でコピーするか、右側のボタンをクリックしてください',
|
||||
logLevel: 'ログレベル',
|
||||
allLevels: 'すべてのレベル',
|
||||
selectLevel: 'レベルを選択',
|
||||
levelsSelected: 'レベル選択済み',
|
||||
},
|
||||
plugins: {
|
||||
title: '拡張機能',
|
||||
@@ -233,6 +242,15 @@ const jaJP = {
|
||||
failedToGetStatus: 'プラグインシステム状態の取得に失敗しました',
|
||||
pluginSystemNotReady:
|
||||
'プラグインシステムが準備されていません。この操作を実行できません',
|
||||
debugInfo: 'デバッグ情報',
|
||||
debugInfoTitle: 'プラグインデバッグ情報',
|
||||
debugUrl: 'デバッグURL',
|
||||
debugKey: 'デバッグキー',
|
||||
noDebugKey: '(未設定)',
|
||||
debugKeyDisabled:
|
||||
'デバッグキーが設定されていません。プラグインデバッグには認証が不要です',
|
||||
failedToGetDebugInfo: 'デバッグ情報の取得に失敗しました',
|
||||
copiedToClipboard: 'クリップボードにコピーしました',
|
||||
deleting: '削除中...',
|
||||
deletePlugin: 'プラグインを削除',
|
||||
cancel: 'キャンセル',
|
||||
@@ -529,6 +547,8 @@ const jaJP = {
|
||||
imageUploadFailed: '画像のアップロードに失敗しました',
|
||||
reply: '返信',
|
||||
replyTo: '返信先',
|
||||
showMarkdown: 'Markdownで表示',
|
||||
showRaw: '原文で表示',
|
||||
},
|
||||
},
|
||||
knowledge: {
|
||||
@@ -608,6 +628,8 @@ const jaJP = {
|
||||
retriever: '検索器',
|
||||
selectRetriever: '検索器を選択...',
|
||||
retrieverConfiguration: '検索器設定',
|
||||
retrieverInstallInfo: 'ナレッジ検索器プラグインは',
|
||||
retrieverMarketLink: 'こちらからインストールできます',
|
||||
},
|
||||
register: {
|
||||
title: 'LangBot を初期化 👋',
|
||||
|
||||
@@ -41,6 +41,7 @@ const zhHans = {
|
||||
addRound: '添加回合',
|
||||
copy: '复制',
|
||||
copySuccess: '复制成功',
|
||||
copyFailed: '复制失败',
|
||||
test: '测试',
|
||||
forgotPassword: '忘记密码?',
|
||||
loading: '加载中...',
|
||||
@@ -182,6 +183,14 @@ const zhHans = {
|
||||
log: '日志',
|
||||
configuration: '配置',
|
||||
logs: '日志',
|
||||
webhookUrl: 'Webhook 回调地址',
|
||||
webhookUrlCopied: 'Webhook 地址已复制',
|
||||
webhookUrlHint:
|
||||
'点击输入框自动全选,然后按 Ctrl+C (Mac: Cmd+C) 复制,或点击右侧按钮',
|
||||
logLevel: '日志级别',
|
||||
allLevels: '全部级别',
|
||||
selectLevel: '选择级别',
|
||||
levelsSelected: '个级别已选',
|
||||
},
|
||||
plugins: {
|
||||
title: '插件扩展',
|
||||
@@ -222,6 +231,14 @@ const zhHans = {
|
||||
loadingStatus: '正在检查插件系统状态...',
|
||||
failedToGetStatus: '获取插件系统状态失败',
|
||||
pluginSystemNotReady: '插件系统未就绪,无法执行此操作',
|
||||
debugInfo: '调试信息',
|
||||
debugInfoTitle: '插件调试信息',
|
||||
debugUrl: '调试地址',
|
||||
debugKey: '调试密钥',
|
||||
noDebugKey: '(未设置)',
|
||||
debugKeyDisabled: '未设置调试密钥,插件调试无需认证',
|
||||
failedToGetDebugInfo: '获取调试信息失败',
|
||||
copiedToClipboard: '已复制到剪贴板',
|
||||
deleting: '删除中...',
|
||||
deletePlugin: '删除插件',
|
||||
cancel: '取消',
|
||||
@@ -508,6 +525,8 @@ const zhHans = {
|
||||
imageUploadFailed: '图片上传失败',
|
||||
reply: '回复',
|
||||
replyTo: '回复给',
|
||||
showMarkdown: '渲染',
|
||||
showRaw: '原文',
|
||||
},
|
||||
},
|
||||
knowledge: {
|
||||
@@ -581,6 +600,8 @@ const zhHans = {
|
||||
retriever: '检索器',
|
||||
selectRetriever: '选择一个检索器...',
|
||||
retrieverConfiguration: '检索器配置',
|
||||
retrieverInstallInfo: '您可以从',
|
||||
retrieverMarketLink: '此处安装知识检索器插件',
|
||||
},
|
||||
register: {
|
||||
title: '初始化 LangBot 👋',
|
||||
|
||||
@@ -41,6 +41,7 @@ const zhHant = {
|
||||
addRound: '新增回合',
|
||||
copy: '複製',
|
||||
copySuccess: '複製成功',
|
||||
copyFailed: '複製失敗',
|
||||
test: '測試',
|
||||
forgotPassword: '忘記密碼?',
|
||||
loading: '載入中...',
|
||||
@@ -182,6 +183,14 @@ const zhHant = {
|
||||
log: '日誌',
|
||||
configuration: '設定',
|
||||
logs: '日誌',
|
||||
webhookUrl: 'Webhook 回調位址',
|
||||
webhookUrlCopied: 'Webhook 位址已複製',
|
||||
webhookUrlHint:
|
||||
'點擊輸入框自動全選,然後按 Ctrl+C (Mac: Cmd+C) 複製,或點擊右側按鈕',
|
||||
logLevel: '日誌級別',
|
||||
allLevels: '全部級別',
|
||||
selectLevel: '選擇級別',
|
||||
levelsSelected: '個級別已選',
|
||||
},
|
||||
plugins: {
|
||||
title: '外掛擴展',
|
||||
@@ -222,6 +231,14 @@ const zhHant = {
|
||||
loadingStatus: '正在檢查外掛系統狀態...',
|
||||
failedToGetStatus: '取得外掛系統狀態失敗',
|
||||
pluginSystemNotReady: '外掛系統未就緒,無法執行此操作',
|
||||
debugInfo: '偵錯資訊',
|
||||
debugInfoTitle: '外掛偵錯資訊',
|
||||
debugUrl: '偵錯位址',
|
||||
debugKey: '偵錯金鑰',
|
||||
noDebugKey: '(未設定)',
|
||||
debugKeyDisabled: '未設定偵錯金鑰,外掛偵錯無需認證',
|
||||
failedToGetDebugInfo: '取得偵錯資訊失敗',
|
||||
copiedToClipboard: '已複製到剪貼簿',
|
||||
deleting: '刪除中...',
|
||||
deletePlugin: '刪除外掛',
|
||||
cancel: '取消',
|
||||
@@ -506,6 +523,8 @@ const zhHant = {
|
||||
imageUploadFailed: '圖片上傳失敗',
|
||||
reply: '回覆',
|
||||
replyTo: '回覆給',
|
||||
showMarkdown: '渲染',
|
||||
showRaw: '原文',
|
||||
},
|
||||
},
|
||||
knowledge: {
|
||||
@@ -579,6 +598,8 @@ const zhHant = {
|
||||
retriever: '檢索器',
|
||||
selectRetriever: '選擇一個檢索器...',
|
||||
retrieverConfiguration: '檢索器配置',
|
||||
retrieverInstallInfo: '您可以從',
|
||||
retrieverMarketLink: '此處安裝知識檢索器插件',
|
||||
},
|
||||
register: {
|
||||
title: '初始化 LangBot 👋',
|
||||
|
||||
Reference in New Issue
Block a user