mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 20:14:36 +00:00
Compare commits
95 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
59d55b382d | ||
|
|
8c17e55913 | ||
|
|
af509fe61f | ||
|
|
87e2a2099a | ||
|
|
3f22f62332 | ||
|
|
d1ee5f931a | ||
|
|
35506dd2bb | ||
|
|
2f06321ebf | ||
|
|
023281ae56 | ||
|
|
50dff55217 | ||
|
|
3204292360 | ||
|
|
e0d72969e3 | ||
|
|
a65b7ad413 | ||
|
|
45df44e01b | ||
|
|
d8addb105a | ||
|
|
f17ccad665 | ||
|
|
120ceb0b55 | ||
|
|
8a6f80a181 | ||
|
|
b19e468668 | ||
|
|
aeac79e1b3 | ||
|
|
b89a240250 | ||
|
|
13f42857f5 | ||
|
|
61f3f31edc | ||
|
|
3663d9dc10 | ||
|
|
89ec86c530 | ||
|
|
d9ba2a17ff | ||
|
|
c4ea6188f9 | ||
|
|
5d9f6ec763 | ||
|
|
b73847f1a6 | ||
|
|
d6e1e79f07 | ||
|
|
525008b8b2 | ||
|
|
bbf77bac4c | ||
|
|
fc6e414be4 | ||
|
|
e60cb6ad0e | ||
|
|
c90f2d6a12 | ||
|
|
fe8a738cd7 | ||
|
|
604cc53973 | ||
|
|
195b694ecc | ||
|
|
d21f23beee | ||
|
|
558587883b | ||
|
|
2e6a1daf4f | ||
|
|
1fc5e75f93 | ||
|
|
a332206ba3 | ||
|
|
8e620dc635 | ||
|
|
c9a21ebace | ||
|
|
a05cdcac50 | ||
|
|
ecfb2bfb34 | ||
|
|
e17dba0a98 | ||
|
|
6b138943ce | ||
|
|
eb0e6aff68 | ||
|
|
4d0095626a | ||
|
|
aa0a501ade | ||
|
|
68ef7bd2c4 | ||
|
|
61dc5de085 | ||
|
|
63bdd71e22 | ||
|
|
9ea5b50802 | ||
|
|
1cd586634d | ||
|
|
45bedbe70e | ||
|
|
f7f1dde7b5 | ||
|
|
ba06555078 | ||
|
|
840fa39979 | ||
|
|
b295416e6c | ||
|
|
914f77ff37 | ||
|
|
b0b7b914d8 | ||
|
|
12713aad45 | ||
|
|
02e12cc1e4 | ||
|
|
61f08f3218 | ||
|
|
75c2a063cc | ||
|
|
b4773c4e48 | ||
|
|
fb73da8735 | ||
|
|
679e549b1d | ||
|
|
898144e9f4 | ||
|
|
b99c5561fc | ||
|
|
b2f4b91979 | ||
|
|
4528000fc4 | ||
|
|
96e40eaf25 | ||
|
|
197258ae91 | ||
|
|
19f417174c | ||
|
|
9c82eeddeb | ||
|
|
f11e01b549 | ||
|
|
863b26c3fa | ||
|
|
b788858f9e | ||
|
|
de8a7df6c2 | ||
|
|
ba5b481617 | ||
|
|
07ad846e96 | ||
|
|
30945aafdd | ||
|
|
24c15b4479 | ||
|
|
1d4c5bbdf1 | ||
|
|
57fcec011d | ||
|
|
455e3db28d | ||
|
|
8caab43b00 | ||
|
|
7479545339 | ||
|
|
10ee30695a | ||
|
|
a9a262eaae | ||
|
|
a8594b76cd |
60
.github/workflows/lint.yml
vendored
Normal file
60
.github/workflows/lint.yml
vendored
Normal file
@@ -0,0 +1,60 @@
|
||||
name: Lint
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- master
|
||||
- dev
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
|
||||
jobs:
|
||||
ruff:
|
||||
name: Ruff Lint & Format
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --dev
|
||||
|
||||
- name: Run ruff check
|
||||
run: uv run ruff check src
|
||||
|
||||
- name: Run ruff format
|
||||
run: uv run ruff format src --check
|
||||
|
||||
frontend:
|
||||
name: Frontend Lint
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '25'
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
version: 9
|
||||
|
||||
- name: Install dependencies
|
||||
working-directory: web
|
||||
run: pnpm install
|
||||
|
||||
- name: Run lint
|
||||
working-directory: web
|
||||
run: pnpm lint
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -42,7 +42,6 @@ botpy.log*
|
||||
test.py
|
||||
/web_ui
|
||||
.venv/
|
||||
uv.lock
|
||||
/test
|
||||
plugins.bak
|
||||
coverage.xml
|
||||
|
||||
@@ -70,6 +70,7 @@ Plugin Runtime automatically starts each installed plugin and interacts through
|
||||
- type: must be a specific type, such as feat (new feature), fix (bug fix), docs (documentation), style (code style), refactor (refactoring), perf (performance optimization), etc.
|
||||
- scope: the scope of the commit, such as the package name, the file name, the function name, the class name, the module name, etc.
|
||||
- subject: the subject of the commit, such as the description of the commit, the reason for the commit, the impact of the commit, etc.
|
||||
- If you changed the definition of database entities, please update the migration file in `src/langbot/pkg/persistence/migrations/` and update the constants.py file in `src/langbot/pkg/utils/constants.py` with the new migration number.
|
||||
|
||||
## Some Principles
|
||||
|
||||
|
||||
15
README.md
15
README.md
@@ -13,16 +13,18 @@
|
||||
[English](README_EN.md) / 简体中文 / [繁體中文](README_TW.md) / [日本語](README_JP.md) / [Español](README_ES.md) / [Français](README_FR.md) / [한국어](README_KO.md) / [Русский](README_RU.md) / [Tiếng Việt](README_VI.md)
|
||||
|
||||
[](https://discord.gg/wdNEHETs87)
|
||||
[](https://qm.qq.com/q/JLi38whHum)
|
||||
[](https://qm.qq.com/q/DxZZcNxM1W)
|
||||
[](https://deepwiki.com/langbot-app/LangBot)
|
||||
[](https://github.com/langbot-app/LangBot/releases/latest)
|
||||
<img src="https://img.shields.io/badge/python-3.10 ~ 3.13 -blue.svg" alt="python">
|
||||
[](https://gitcode.com/RockChinQ/LangBot)
|
||||
|
||||
<a href="https://langbot.app">项目主页</a> |
|
||||
<a href="https://docs.langbot.app/zh/insight/features.html">规格特性</a> |
|
||||
<a href="https://docs.langbot.app/zh/insight/guide.html">部署文档</a> |
|
||||
<a href="https://docs.langbot.app/zh/plugin/plugin-intro.html">插件介绍</a> |
|
||||
<a href="https://github.com/langbot-app/LangBot/issues/new?assignees=&labels=%E7%8B%AC%E7%AB%8B%E6%8F%92%E4%BB%B6&projects=&template=submit-plugin.yml&title=%5BPlugin%5D%3A+%E8%AF%B7%E6%B1%82%E7%99%BB%E8%AE%B0%E6%96%B0%E6%8F%92%E4%BB%B6">提交插件</a>
|
||||
<a href="https://docs.langbot.app/zh/tags/readme.html">API 集成</a> |
|
||||
<a href="https://space.langbot.app">插件市场</a> |
|
||||
<a href="https://langbot.featurebase.app/roadmap">路线图</a>
|
||||
|
||||
</div>
|
||||
|
||||
@@ -86,11 +88,12 @@ docker compose up -d
|
||||
<img width="500" src="https://docs.langbot.app/ui/bot-page-zh-rounded.png" />
|
||||
|
||||
|
||||
- 💬 大模型对话、Agent:支持多种大模型,适配群聊和私聊;具有多轮对话、工具调用、多模态、流式输出能力,自带 RAG(知识库)实现,并深度适配 [Dify](https://dify.ai)、[Coze](https://coze.com)、[n8n](https://n8n.io)等 LLMOps 平台。
|
||||
- 💬 大模型对话、Agent:支持多种大模型,适配群聊和私聊;具有多轮对话、工具调用、多模态、流式输出能力,自带 RAG(知识库)实现,并深度适配 [Dify](https://dify.ai)、[Coze](https://coze.com)、[n8n](https://n8n.io)、[Langflow](https://langflow.org)等 LLMOps 平台。
|
||||
- 🤖 多平台支持:目前支持 QQ、QQ频道、企业微信、个人微信、飞书、Discord、Telegram、KOOK、Slack、LINE 等平台。
|
||||
- 🛠️ 高稳定性、功能完备:原生支持访问控制、限速、敏感词过滤等机制;配置简单,支持多种部署方式。支持多流水线配置,不同机器人用于不同应用场景。
|
||||
- 🛠️ 高稳定性、功能完备:原生支持访问控制、限速、敏感词过滤等机制;配置简单,支持多种部署方式。
|
||||
- 🧩 插件扩展、活跃社区:高稳定性、高安全性的生产级插件系统,支持事件驱动、组件扩展等插件机制;适配 Anthropic [MCP 协议](https://modelcontextprotocol.io/);目前已有数百个插件。
|
||||
- 😻 Web 管理面板:支持通过浏览器管理 LangBot 实例,不再需要手动编写配置文件。
|
||||
- 😻 Web 管理面板:提供先进的 WebUI 管理面板,用最直观的方式配置、管理、监控机器人。
|
||||
- 📊 生产级特性:支持多流水线配置,不同机器人用于不同应用场景。具有全面的监控和异常处理能力。已被多家企业采用。
|
||||
|
||||
详细规格特性请访问[文档](https://docs.langbot.app/zh/insight/features.html)。
|
||||
|
||||
|
||||
11
README_EN.md
11
README_EN.md
@@ -17,9 +17,11 @@ English / [简体中文](README.md) / [繁體中文](README_TW.md) / [日本語]
|
||||
<img src="https://img.shields.io/badge/python-3.10 ~ 3.13 -blue.svg" alt="python">
|
||||
|
||||
<a href="https://langbot.app">Home</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/features.html">Features</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/guide.html">Deployment</a> |
|
||||
<a href="https://docs.langbot.app/en/plugin/plugin-intro.html">Plugin</a> |
|
||||
<a href="https://github.com/langbot-app/LangBot/issues/new?assignees=&labels=%E7%8B%AC%E7%AB%8B%E6%8F%92%E4%BB%B6&projects=&template=submit-plugin.yml&title=%5BPlugin%5D%3A+%E8%AF%B7%E6%B1%82%E7%99%BB%E8%AE%B0%E6%96%B0%E6%8F%92%E4%BB%B6">Submit Plugin</a>
|
||||
<a href="https://docs.langbot.app/en/tags/readme.html">API Integration</a> |
|
||||
<a href="https://space.langbot.app">Plugin Market</a> |
|
||||
<a href="https://langbot.featurebase.app/roadmap">Roadmap</a>
|
||||
|
||||
</div>
|
||||
|
||||
@@ -83,11 +85,12 @@ Click the Star and Watch button in the upper right corner of the repository to g
|
||||
<img width="500" src="https://docs.langbot.app/ui/bot-page-en-rounded.png" />
|
||||
|
||||
|
||||
- 💬 Chat with LLM / Agent: Supports multiple LLMs, adapt to group chats and private chats; Supports multi-round conversations, tool calls, multi-modal, and streaming output capabilities. Built-in RAG (knowledge base) implementation, and deeply integrates with [Dify](https://dify.ai), [Coze](https://coze.com), [n8n](https://n8n.io) etc. LLMOps platforms.
|
||||
- 💬 Chat with LLM / Agent: Supports multiple LLMs, adapt to group chats and private chats; Supports multi-round conversations, tool calls, multi-modal, and streaming output capabilities. Built-in RAG (knowledge base) implementation, and deeply integrates with [Dify](https://dify.ai), [Coze](https://coze.com), [n8n](https://n8n.io), [Langflow](https://langflow.org) etc. LLMOps platforms.
|
||||
- 🤖 Multi-platform Support: Currently supports QQ, QQ Channel, WeCom, personal WeChat, Lark, DingTalk, Discord, Telegram, KOOK, Slack, LINE, etc.
|
||||
- 🛠️ High Stability, Feature-rich: Native access control, rate limiting, sensitive word filtering, etc. mechanisms; Easy to use, supports multiple deployment methods. Supports multiple pipeline configurations, different bots can be used for different scenarios.
|
||||
- 🛠️ High Stability, Feature-rich: Native access control, rate limiting, sensitive word filtering, etc. mechanisms; Easy to use, supports multiple deployment methods.
|
||||
- 🧩 Plugin Extension, Active Community: High stability, high security production-level plugin system; Support event-driven, component extension, etc. plugin mechanisms; Integrate Anthropic [MCP protocol](https://modelcontextprotocol.io/); Currently has hundreds of plugins.
|
||||
- 😻 Web UI: Support management LangBot instance through the browser. No need to manually write configuration files.
|
||||
- 📊 Production-grade Features: Supports multiple pipeline configurations, different bots can be used for different scenarios. Has comprehensive monitoring and exception handling capabilities.
|
||||
|
||||
For more detailed specifications, please refer to the [documentation](https://docs.langbot.app/en/insight/features.html).
|
||||
|
||||
|
||||
11
README_ES.md
11
README_ES.md
@@ -17,9 +17,11 @@
|
||||
<img src="https://img.shields.io/badge/python-3.10 ~ 3.13 -blue.svg" alt="python">
|
||||
|
||||
<a href="https://langbot.app">Inicio</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/features.html">Características</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/guide.html">Despliegue</a> |
|
||||
<a href="https://docs.langbot.app/en/plugin/plugin-intro.html">Plugin</a> |
|
||||
<a href="https://github.com/langbot-app/LangBot/issues/new?assignees=&labels=%E7%8B%AC%E7%AB%8B%E6%8F%92%E4%BB%B6&projects=&template=submit-plugin.yml&title=%5BPlugin%5D%3A+%E8%AF%B7%E6%B1%82%E7%99%BB%E8%AE%B0%E6%96%B0%E6%8F%92%E4%BB%B6">Enviar Plugin</a>
|
||||
<a href="https://docs.langbot.app/en/tags/readme.html">Integración API</a> |
|
||||
<a href="https://space.langbot.app">Mercado de Plugins</a> |
|
||||
<a href="https://langbot.featurebase.app/roadmap">Hoja de Ruta</a>
|
||||
|
||||
</div>
|
||||
|
||||
@@ -83,11 +85,12 @@ Haga clic en los botones Star y Watch en la esquina superior derecha del reposit
|
||||
<img width="500" src="https://docs.langbot.app/ui/bot-page-en-rounded.png" />
|
||||
|
||||
|
||||
- 💬 Chat con LLM / Agent: Compatible con múltiples LLMs, adaptado para chats grupales y privados; Admite conversaciones de múltiples rondas, llamadas a herramientas, capacidades multimodales y de salida en streaming. Implementación RAG (base de conocimientos) incorporada, e integración profunda con [Dify](https://dify.ai), [Coze](https://coze.com), [n8n](https://n8n.io) etc. LLMOps platforms.
|
||||
- 💬 Chat con LLM / Agent: Compatible con múltiples LLMs, adaptado para chats grupales y privados; Admite conversaciones de múltiples rondas, llamadas a herramientas, capacidades multimodales y de salida en streaming. Implementación RAG (base de conocimientos) incorporada, e integración profunda con [Dify](https://dify.ai), [Coze](https://coze.com), [n8n](https://n8n.io), [Langflow](https://langflow.org) etc. LLMOps platforms.
|
||||
- 🤖 Soporte Multiplataforma: Actualmente compatible con QQ, QQ Channel, WeCom, WeChat personal, Lark, DingTalk, Discord, Telegram, KOOK, Slack, LINE, etc.
|
||||
- 🛠️ Alta Estabilidad, Rico en Funciones: Control de acceso nativo, limitación de velocidad, filtrado de palabras sensibles, etc.; Fácil de usar, admite múltiples métodos de despliegue. Compatible con múltiples configuraciones de pipeline, diferentes bots para diferentes escenarios.
|
||||
- 🛠️ Alta Estabilidad, Rico en Funciones: Control de acceso nativo, limitación de velocidad, filtrado de palabras sensibles, etc.; Fácil de usar, admite múltiples métodos de despliegue.
|
||||
- 🧩 Extensión de Plugin, Comunidad Activa: Sistema de plugin de alta estabilidad, alta seguridad de nivel de producción; Compatible con mecanismos de plugin impulsados por eventos, extensión de componentes, etc.; Integración del protocolo [MCP](https://modelcontextprotocol.io/) de Anthropic; Actualmente cuenta con cientos de plugins.
|
||||
- 😻 Interfaz Web: Admite la gestión de instancias de LangBot a través del navegador. No es necesario escribir archivos de configuración manualmente.
|
||||
- 📊 Características de Nivel de Producción: Compatible con múltiples configuraciones de pipeline, diferentes bots para diferentes escenarios. Cuenta con capacidades completas de monitoreo y manejo de excepciones.
|
||||
|
||||
Para especificaciones más detalladas, consulte la [documentación](https://docs.langbot.app/en/insight/features.html).
|
||||
|
||||
|
||||
11
README_FR.md
11
README_FR.md
@@ -17,9 +17,11 @@
|
||||
<img src="https://img.shields.io/badge/python-3.10 ~ 3.13 -blue.svg" alt="python">
|
||||
|
||||
<a href="https://langbot.app">Accueil</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/features.html">Fonctionnalités</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/guide.html">Déploiement</a> |
|
||||
<a href="https://docs.langbot.app/en/plugin/plugin-intro.html">Plugin</a> |
|
||||
<a href="https://github.com/langbot-app/LangBot/issues/new?assignees=&labels=%E7%8B%AC%E7%AB%8B%E6%8F%92%E4%BB%B6&projects=&template=submit-plugin.yml&title=%5BPlugin%5D%3A+%E8%AF%B7%E6%B1%82%E7%99%BB%E8%AE%B0%E6%96%B0%E6%8F%92%E4%BB%B6">Soumettre un Plugin</a>
|
||||
<a href="https://docs.langbot.app/en/tags/readme.html">Intégration API</a> |
|
||||
<a href="https://space.langbot.app">Marché des Plugins</a> |
|
||||
<a href="https://langbot.featurebase.app/roadmap">Feuille de Route</a>
|
||||
|
||||
</div>
|
||||
|
||||
@@ -82,11 +84,12 @@ Cliquez sur les boutons Star et Watch dans le coin supérieur droit du dépôt p
|
||||
<img width="500" src="https://docs.langbot.app/ui/bot-page-en-rounded.png" />
|
||||
|
||||
|
||||
- 💬 Chat avec LLM / Agent : Prend en charge plusieurs LLM, adapté aux chats de groupe et privés ; Prend en charge les conversations multi-tours, les appels d'outils, les capacités multimodales et de sortie en streaming. Implémentation RAG (base de connaissances) intégrée, et intégration profonde avec [Dify](https://dify.ai), [Coze](https://coze.com), [n8n](https://n8n.io) etc. LLMOps platforms.
|
||||
- 💬 Chat avec LLM / Agent : Prend en charge plusieurs LLM, adapté aux chats de groupe et privés ; Prend en charge les conversations multi-tours, les appels d'outils, les capacités multimodales et de sortie en streaming. Implémentation RAG (base de connaissances) intégrée, et intégration profonde avec [Dify](https://dify.ai), [Coze](https://coze.com), [n8n](https://n8n.io), [Langflow](https://langflow.org) etc. LLMOps platforms.
|
||||
- 🤖 Support Multi-plateforme : Actuellement compatible avec QQ, QQ Channel, WeCom, WeChat personnel, Lark, DingTalk, Discord, Telegram, KOOK, Slack, LINE, etc.
|
||||
- 🛠️ Haute Stabilité, Riche en Fonctionnalités : Contrôle d'accès natif, limitation de débit, filtrage de mots sensibles, etc. ; Facile à utiliser, prend en charge plusieurs méthodes de déploiement. Prend en charge plusieurs configurations de pipeline, différents bots pour différents scénarios.
|
||||
- 🛠️ Haute Stabilité, Riche en Fonctionnalités : Contrôle d'accès natif, limitation de débit, filtrage de mots sensibles, etc. ; Facile à utiliser, prend en charge plusieurs méthodes de déploiement.
|
||||
- 🧩 Extension de Plugin, Communauté Active : Système de plugin de haute stabilité, haute sécurité de niveau production; Prend en charge les mécanismes de plugin pilotés par événements, l'extension de composants, etc. ; Intégration du protocole [MCP](https://modelcontextprotocol.io/) d'Anthropic ; Dispose actuellement de centaines de plugins.
|
||||
- 😻 Interface Web : Prend en charge la gestion des instances LangBot via le navigateur. Pas besoin d'écrire manuellement les fichiers de configuration.
|
||||
- 📊 Fonctionnalités de Niveau Production : Prend en charge plusieurs configurations de pipeline, différents bots pour différents scénarios. Dispose de capacités complètes de surveillance et de gestion des exceptions.
|
||||
|
||||
Pour des spécifications plus détaillées, veuillez consulter la [documentation](https://docs.langbot.app/en/insight/features.html).
|
||||
|
||||
|
||||
13
README_JP.md
13
README_JP.md
@@ -17,9 +17,11 @@
|
||||
<img src="https://img.shields.io/badge/python-3.10 ~ 3.13 -blue.svg" alt="python">
|
||||
|
||||
<a href="https://langbot.app">ホーム</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/guide.html">デプロイ</a> |
|
||||
<a href="https://docs.langbot.app/en/plugin/plugin-intro.html">プラグイン</a> |
|
||||
<a href="https://github.com/langbot-app/LangBot/issues/new?assignees=&labels=%E7%8B%AC%E7%AB%8B%E6%8F%92%E4%BB%B6&projects=&template=submit-plugin.yml&title=%5BPlugin%5D%3A+%E8%AF%B7%E6%B1%82%E7%99%BB%E8%AE%B0%E6%96%B0%E6%8F%92%E4%BB%B6">プラグインの提出</a>
|
||||
<a href="https://docs.langbot.app/ja/insight/features.html">機能仕様</a> |
|
||||
<a href="https://docs.langbot.app/ja/insight/guide.html">デプロイ</a> |
|
||||
<a href="https://docs.langbot.app/ja/tags/readme.html">API統合</a> |
|
||||
<a href="https://space.langbot.app">プラグインマーケット</a> |
|
||||
<a href="https://langbot.featurebase.app/roadmap">ロードマップ</a>
|
||||
|
||||
</div>
|
||||
|
||||
@@ -82,11 +84,12 @@ LangBotはBTPanelにリストされています。BTPanelをインストール
|
||||
<img width="500" src="https://docs.langbot.app/ui/bot-page-en-rounded.png" />
|
||||
|
||||
|
||||
- 💬 LLM / エージェントとのチャット: 複数のLLMをサポートし、グループチャットとプライベートチャットに対応。マルチラウンドの会話、ツールの呼び出し、マルチモーダル、ストリーミング出力機能をサポート、RAG(知識ベース)を組み込み、[Dify](https://dify.ai)、[Coze](https://coze.com)、[n8n](https://n8n.io) などの LLMOps プラットフォームと深く統合。
|
||||
- 💬 LLM / エージェントとのチャット: 複数のLLMをサポートし、グループチャットとプライベートチャットに対応。マルチラウンドの会話、ツールの呼び出し、マルチモーダル、ストリーミング出力機能をサポート、RAG(知識ベース)を組み込み、[Dify](https://dify.ai)、[Coze](https://coze.com)、[n8n](https://n8n.io)、[Langflow](https://langflow.org)などの LLMOps プラットフォームと深く統合。
|
||||
- 🤖 多プラットフォーム対応: 現在、QQ、QQ チャンネル、WeChat、個人 WeChat、Lark、DingTalk、Discord、Telegram、KOOK、Slack、LINE など、複数のプラットフォームをサポートしています。
|
||||
- 🛠️ 高い安定性、豊富な機能: ネイティブのアクセス制御、レート制限、敏感な単語のフィルタリングなどのメカニズムをサポート。使いやすく、複数のデプロイ方法をサポート。複数のパイプライン設定をサポートし、異なるボットを異なる用途に使用できます。
|
||||
- 🛠️ 高い安定性、豊富な機能: ネイティブのアクセス制御、レート制限、敏感な単語のフィルタリングなどのメカニズムをサポート。使いやすく、複数のデプロイ方法をサポート。
|
||||
- 🧩 プラグイン拡張、活発なコミュニティ: 高い安定性、高いセキュリティの生産レベルのプラグインシステム;イベント駆動、コンポーネント拡張などのプラグインメカニズムをサポート。適配 Anthropic [MCP プロトコル](https://modelcontextprotocol.io/);豊富なエコシステム、現在数百のプラグインが存在。
|
||||
- 😻 Web UI: ブラウザを通じてLangBotインスタンスを管理することをサポート。
|
||||
- 📊 生産レベルの機能: 複数のパイプライン設定をサポートし、異なるボットを異なる用途に使用できます。包括的な監視と例外処理機能を備えています。
|
||||
|
||||
詳細な仕様については、[ドキュメント](https://docs.langbot.app/en/insight/features.html)を参照してください。
|
||||
|
||||
|
||||
11
README_KO.md
11
README_KO.md
@@ -17,9 +17,11 @@
|
||||
<img src="https://img.shields.io/badge/python-3.10 ~ 3.13 -blue.svg" alt="python">
|
||||
|
||||
<a href="https://langbot.app">홈</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/features.html">기능 사양</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/guide.html">배포</a> |
|
||||
<a href="https://docs.langbot.app/en/plugin/plugin-intro.html">플러그인</a> |
|
||||
<a href="https://github.com/langbot-app/LangBot/issues/new?assignees=&labels=%E7%8B%AC%E7%AB%8B%E6%8F%92%E4%BB%B6&projects=&template=submit-plugin.yml&title=%5BPlugin%5D%3A+%E8%AF%B7%E6%B1%82%E7%99%BB%E8%AE%B0%E6%96%B0%E6%8F%92%E4%BB%B6">플러그인 제출</a>
|
||||
<a href="https://docs.langbot.app/en/tags/readme.html">API 통합</a> |
|
||||
<a href="https://space.langbot.app">플러그인 마켓</a> |
|
||||
<a href="https://langbot.featurebase.app/roadmap">로드맵</a>
|
||||
|
||||
</div>
|
||||
|
||||
@@ -82,11 +84,12 @@ LangBot은 BTPanel에 등록되어 있습니다. BTPanel을 설치한 경우 [
|
||||
<img width="500" src="https://docs.langbot.app/ui/bot-page-en-rounded.png" />
|
||||
|
||||
|
||||
- 💬 LLM / Agent와 채팅: 여러 LLM을 지원하며 그룹 채팅 및 개인 채팅에 적응; 멀티 라운드 대화, 도구 호출, 멀티모달, 스트리밍 출력 기능을 지원합니다. 내장된 RAG(지식 베이스) 구현 및 [Dify](https://dify.ai)、[Coze](https://coze.com)、[n8n](https://n8n.io) 등의 LLMOps 플랫폼과 깊이 통합됩니다.
|
||||
- 💬 LLM / Agent와 채팅: 여러 LLM을 지원하며 그룹 채팅 및 개인 채팅에 적응; 멀티 라운드 대화, 도구 호출, 멀티모달, 스트리밍 출력 기능을 지원합니다. 내장된 RAG(지식 베이스) 구현 및 [Dify](https://dify.ai)、[Coze](https://coze.com)、[n8n](https://n8n.io)、[Langflow](https://langflow.org)등의 LLMOps 플랫폼과 깊이 통합됩니다.
|
||||
- 🤖 다중 플랫폼 지원: 현재 QQ, QQ Channel, WeCom, 개인 WeChat, Lark, DingTalk, Discord, Telegram, KOOK, Slack, LINE 등을 지원합니다.
|
||||
- 🛠️ 높은 안정성, 풍부한 기능: 네이티브 액세스 제어, 속도 제한, 민감한 단어 필터링 등의 메커니즘; 사용하기 쉽고 여러 배포 방법을 지원합니다. 여러 파이프라인 구성을 지원하며 다양한 시나리오에 대해 다른 봇을 사용할 수 있습니다.
|
||||
- 🛠️ 높은 안정성, 풍부한 기능: 네이티브 액세스 제어, 속도 제한, 민감한 단어 필터링 등의 메커니즘; 사용하기 쉽고 여러 배포 방법을 지원합니다.
|
||||
- 🧩 플러그인 확장, 활발한 커뮤니티: 고안정성, 고보안 생산 수준의 플러그인 시스템; 이벤트 기반, 컴포넌트 확장 등의 플러그인 메커니즘을 지원; Anthropic [MCP 프로토콜](https://modelcontextprotocol.io/) 통합; 현재 수백 개의 플러그인이 있습니다.
|
||||
- 😻 웹 UI: 브라우저를 통해 LangBot 인스턴스 관리를 지원합니다. 구성 파일을 수동으로 작성할 필요가 없습니다.
|
||||
- 📊 생산 수준의 기능: 여러 파이프라인 구성을 지원하며 다양한 시나리오에 대해 다른 봇을 사용할 수 있습니다. 포괄적인 모니터링 및 예외 처리 기능을 갖추고 있습니다.
|
||||
|
||||
더 자세한 사양은 [문서](https://docs.langbot.app/en/insight/features.html)를 참조하세요.
|
||||
|
||||
|
||||
11
README_RU.md
11
README_RU.md
@@ -17,9 +17,11 @@
|
||||
<img src="https://img.shields.io/badge/python-3.10 ~ 3.13 -blue.svg" alt="python">
|
||||
|
||||
<a href="https://langbot.app">Главная</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/features.html">Характеристики</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/guide.html">Развертывание</a> |
|
||||
<a href="https://docs.langbot.app/en/plugin/plugin-intro.html">Плагин</a> |
|
||||
<a href="https://github.com/langbot-app/LangBot/issues/new?assignees=&labels=%E7%8B%AC%E7%AB%8B%E6%8F%92%E4%BB%B6&projects=&template=submit-plugin.yml&title=%5BPlugin%5D%3A+%E8%AF%B7%E6%B1%82%E7%99%BB%E8%AE%B0%E6%96%B0%E6%8F%92%E4%BB%B6">Отправить плагин</a>
|
||||
<a href="https://docs.langbot.app/en/tags/readme.html">Интеграция API</a> |
|
||||
<a href="https://space.langbot.app">Магазин плагинов</a> |
|
||||
<a href="https://langbot.featurebase.app/roadmap">Дорожная карта</a>
|
||||
|
||||
</div>
|
||||
|
||||
@@ -82,11 +84,12 @@ LangBot добавлен в BTPanel. Если у вас установлен BTP
|
||||
<img width="500" src="https://docs.langbot.app/ui/bot-page-en-rounded.png" />
|
||||
|
||||
|
||||
- 💬 Чат с LLM / Agent: Поддержка нескольких LLM, адаптация к групповым и личным чатам; Поддержка многораундовых разговоров, вызовов инструментов, мультимодальных возможностей и потоковой передачи. Встроенная реализация RAG (база знаний) и глубокая интеграция с [Dify](https://dify.ai), [Coze](https://coze.com), [n8n](https://n8n.io) 등의 LLMOps 플랫포트폼과 깊이 통합됩니다.
|
||||
- 💬 Чат с LLM / Agent: Поддержка нескольких LLM, адаптация к групповым и личным чатам; Поддержка многораундовых разговоров, вызовов инструментов, мультимодальных возможностей и потоковой передачи. Встроенная реализация RAG (база знаний) и глубокая интеграция с [Dify](https://dify.ai), [Coze](https://coze.com), [n8n](https://n8n.io), [Langflow](https://langflow.org) и др. LLMOps платформами.
|
||||
- 🤖 Многоплатформенная поддержка: В настоящее время поддерживает QQ, QQ Channel, WeCom, личный WeChat, Lark, DingTalk, Discord, Telegram, KOOK, Slack, LINE и т.д.
|
||||
- 🛠️ Высокая стабильность, богатство функций: Нативный контроль доступа, ограничение скорости, фильтрация чувствительных слов и т.д.; Простота в использовании, поддержка нескольких методов развертывания. Поддержка нескольких конфигураций конвейера, разные боты для разных сценариев.
|
||||
- 🛠️ Высокая стабильность, богатство функций: Нативный контроль доступа, ограничение скорости, фильтрация чувствительных слов и т.д.; Простота в использовании, поддержка нескольких методов развертывания.
|
||||
- 🧩 Расширение плагинов, активное сообщество: Высокая стабильность, высокая безопасность уровня производства; Поддержка механизмов плагинов, управляемых событиями, расширения компонентов и т.д.; Интеграция протокола [MCP](https://modelcontextprotocol.io/) от Anthropic; В настоящее время сотни плагинов.
|
||||
- 😻 Веб-интерфейс: Поддержка управления экземплярами LangBot через браузер. Нет необходимости вручную писать конфигурационные файлы.
|
||||
- 📊 Функции уровня производства: Поддержка нескольких конфигураций конвейера, разные боты для разных сценариев. Имеет комплексные возможности мониторинга и обработки исключений.
|
||||
|
||||
Для более подробных спецификаций обратитесь к [документации](https://docs.langbot.app/en/insight/features.html).
|
||||
|
||||
|
||||
13
README_TW.md
13
README_TW.md
@@ -17,9 +17,11 @@
|
||||
[](https://gitcode.com/RockChinQ/LangBot)
|
||||
|
||||
<a href="https://langbot.app">主頁</a> |
|
||||
<a href="https://docs.langbot.app/zh/insight/features.html">規格特性</a> |
|
||||
<a href="https://docs.langbot.app/zh/insight/guide.html">部署文件</a> |
|
||||
<a href="https://docs.langbot.app/zh/plugin/plugin-intro.html">外掛介紹</a> |
|
||||
<a href="https://github.com/langbot-app/LangBot/issues/new?assignees=&labels=%E7%8B%AC%E7%AB%8B%E6%8F%92%E4%BB%B6&projects=&template=submit-plugin.yml&title=%5BPlugin%5D%3A+%E8%AF%B7%E6%B1%82%E7%99%BB%E8%AE%B0%E6%96%B0%E6%8F%92%E4%BB%B6">提交外掛</a>
|
||||
<a href="https://docs.langbot.app/zh/tags/readme.html">API 整合</a> |
|
||||
<a href="https://space.langbot.app">外掛市場</a> |
|
||||
<a href="https://langbot.featurebase.app/roadmap">路線圖</a>
|
||||
|
||||
</div>
|
||||
|
||||
@@ -82,11 +84,12 @@ docker compose up -d
|
||||
<img width="500" src="https://docs.langbot.app/ui/bot-page-en-rounded.png" />
|
||||
|
||||
|
||||
- 💬 大模型對話、Agent:支援多種大模型,適配群聊和私聊;具有多輪對話、工具調用、多模態、流式輸出能力,自帶 RAG(知識庫)實現,並深度適配 [Dify](https://dify.ai)、[Coze](https://coze.com)、[n8n](https://n8n.io) 等 LLMOps 平台。
|
||||
- 💬 大模型對話、Agent:支援多種大模型,適配群聊和私聊;具有多輪對話、工具調用、多模態、流式輸出能力,自帶 RAG(知識庫)實現,並深度適配 [Dify](https://dify.ai)、[Coze](https://coze.com)、[n8n](https://n8n.io)、[Langflow](https://langflow.org)等 LLMOps 平台。
|
||||
- 🤖 多平台支援:目前支援 QQ、QQ頻道、企業微信、個人微信、飛書、Discord、Telegram、KOOK、Slack、LINE 等平台。
|
||||
- 🛠️ 高穩定性、功能完備:原生支援訪問控制、限速、敏感詞過濾等機制;配置簡單,支援多種部署方式。支援多流水線配置,不同機器人用於不同應用場景。
|
||||
- 🛠️ 高穩定性、功能完備:原生支援訪問控制、限速、敏感詞過濾等機制;配置簡單,支援多種部署方式。
|
||||
- 🧩 外掛擴展、活躍社群:高穩定性、高安全性的生產級外掛系統;支援事件驅動、組件擴展等外掛機制;適配 Anthropic [MCP 協議](https://modelcontextprotocol.io/);目前已有數百個外掛。
|
||||
- 😻 Web 管理面板:支援通過瀏覽器管理 LangBot 實例,不再需要手動編寫配置文件。
|
||||
- 😻 Web 管理面板:提供先進的 WebUI 管理面板,用最直觀的方式配置、管理、監控機器人。
|
||||
- 📊 生產級特性:支援多流水線配置,不同機器人用於不同應用場景。具有全面的監控和異常處理能力。
|
||||
|
||||
詳細規格特性請訪問[文件](https://docs.langbot.app/zh/insight/features.html)。
|
||||
|
||||
|
||||
11
README_VI.md
11
README_VI.md
@@ -17,9 +17,11 @@
|
||||
<img src="https://img.shields.io/badge/python-3.10 ~ 3.13 -blue.svg" alt="python">
|
||||
|
||||
<a href="https://langbot.app">Trang chủ</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/features.html">Tính năng</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/guide.html">Triển khai</a> |
|
||||
<a href="https://docs.langbot.app/en/plugin/plugin-intro.html">Plugin</a> |
|
||||
<a href="https://github.com/langbot-app/LangBot/issues/new?assignees=&labels=%E7%8B%AC%E7%AB%8B%E6%8F%92%E4%BB%B6&projects=&template=submit-plugin.yml&title=%5BPlugin%5D%3A+%E8%AF%B7%E6%B1%82%E7%99%BB%E8%AE%B0%E6%96%B0%E6%8F%92%E4%BB%B6">Gửi Plugin</a>
|
||||
<a href="https://docs.langbot.app/en/tags/readme.html">Tích hợp API</a> |
|
||||
<a href="https://space.langbot.app">Chợ Plugin</a> |
|
||||
<a href="https://langbot.featurebase.app/roadmap">Lộ trình</a>
|
||||
|
||||
</div>
|
||||
|
||||
@@ -82,11 +84,12 @@ Nhấp vào các nút Star và Watch ở góc trên bên phải của kho lưu t
|
||||
<img width="500" src="https://docs.langbot.app/ui/bot-page-en-rounded.png" />
|
||||
|
||||
|
||||
- 💬 Chat với LLM / Agent: Hỗ trợ nhiều LLM, thích ứng với chat nhóm và chat riêng tư; Hỗ trợ các cuộc trò chuyện nhiều vòng, gọi công cụ, khả năng đa phương thức và đầu ra streaming. Triển khai RAG (cơ sở kiến thức) tích hợp sẵn và tích hợp sâu với [Dify](https://dify.ai), [Coze](https://coze.com), [n8n](https://n8n.io) v.v. LLMOps platforms.
|
||||
- 💬 Chat với LLM / Agent: Hỗ trợ nhiều LLM, thích ứng với chat nhóm và chat riêng tư; Hỗ trợ các cuộc trò chuyện nhiều vòng, gọi công cụ, khả năng đa phương thức và đầu ra streaming. Triển khai RAG (cơ sở kiến thức) tích hợp sẵn và tích hợp sâu với [Dify](https://dify.ai), [Coze](https://coze.com), [n8n](https://n8n.io), [Langflow](https://langflow.org) v.v. LLMOps platforms.
|
||||
- 🤖 Hỗ trợ Đa nền tảng: Hiện hỗ trợ QQ, QQ Channel, WeCom, WeChat cá nhân, Lark, DingTalk, Discord, Telegram, KOOK, Slack, LINE, v.v.
|
||||
- 🛠️ Độ ổn định Cao, Tính năng Phong phú: Kiểm soát truy cập gốc, giới hạn tốc độ, lọc từ nhạy cảm, v.v.; Dễ sử dụng, hỗ trợ nhiều phương pháp triển khai. Hỗ trợ nhiều cấu hình pipeline, các bot khác nhau cho các kịch bản khác nhau.
|
||||
- 🛠️ Độ ổn định Cao, Tính năng Phong phú: Kiểm soát truy cập gốc, giới hạn tốc độ, lọc từ nhạy cảm, v.v.; Dễ sử dụng, hỗ trợ nhiều phương pháp triển khai.
|
||||
- 🧩 Mở rộng Plugin, Cộng đồng Hoạt động: Hỗ trợ các cơ chế plugin hướng sự kiện, mở rộng thành phần, v.v.; Tích hợp giao thức [MCP](https://modelcontextprotocol.io/) của Anthropic; Hiện có hàng trăng plugin.
|
||||
- 😻 Giao diện Web: Hỗ trợ quản lý các phiên bản LangBot thông qua trình duyệt. Không cần viết tệp cấu hình thủ công.
|
||||
- 📊 Tính năng Cấp sản xuất: Hỗ trợ nhiều cấu hình pipeline, các bot khác nhau cho các kịch bản khác nhau. Có khả năng giám sát toàn diện và xử lý ngoại lệ.
|
||||
|
||||
Để biết thêm thông số kỹ thuật chi tiết, vui lòng tham khảo [tài liệu](https://docs.langbot.app/en/insight/features.html).
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ services:
|
||||
restart: on-failure
|
||||
environment:
|
||||
- TZ=Asia/Shanghai
|
||||
command: ["uv", "run", "-m", "langbot_plugin.cli.__init__", "rt"]
|
||||
command: ["uv", "run", "--no-sync", "-m", "langbot_plugin.cli.__init__", "rt"]
|
||||
networks:
|
||||
- langbot_network
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[project]
|
||||
name = "langbot"
|
||||
version = "4.6.5"
|
||||
description = "Easy-to-use global IM bot platform designed for LLM era"
|
||||
version = "4.8.3"
|
||||
description = "Production-grade platform for building agentic IM bots"
|
||||
readme = "README.md"
|
||||
license-files = ["LICENSE"]
|
||||
requires-python = ">=3.11,<4.0"
|
||||
@@ -17,13 +17,13 @@ dependencies = [
|
||||
"certifi>=2025.4.26",
|
||||
"colorlog~=6.6.0",
|
||||
"cryptography>=44.0.3",
|
||||
"dashscope>=1.23.2",
|
||||
"dashscope>=1.25.10",
|
||||
"dingtalk-stream>=0.24.0",
|
||||
"discord-py>=2.5.2",
|
||||
"pynacl>=1.5.0", # Required for Discord voice support
|
||||
"gewechat-client>=0.1.5",
|
||||
"lark-oapi>=1.4.15",
|
||||
"mcp>=1.8.1",
|
||||
"mcp>=1.25.0",
|
||||
"nakuru-project-idk>=0.0.2.1",
|
||||
"ollama>=0.4.8",
|
||||
"openai>1.0.0",
|
||||
@@ -63,8 +63,8 @@ dependencies = [
|
||||
"langchain-text-splitters>=0.0.1",
|
||||
"chromadb>=0.4.24",
|
||||
"qdrant-client (>=1.15.1,<2.0.0)",
|
||||
"pyseekdb>=0.1.0",
|
||||
"langbot-plugin==0.2.4",
|
||||
"pyseekdb==1.0.0b7",
|
||||
"langbot-plugin==0.2.5",
|
||||
"asyncpg>=0.30.0",
|
||||
"line-bot-sdk>=3.19.0",
|
||||
"tboxsdk>=0.0.10",
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
"""LangBot - Easy-to-use global IM bot platform designed for LLM era"""
|
||||
"""LangBot - Production-grade platform for building agentic IM bots"""
|
||||
|
||||
__version__ = '4.6.5'
|
||||
__version__ = '4.8.3'
|
||||
|
||||
@@ -347,10 +347,15 @@ class DingTalkClient:
|
||||
raise Exception(f'failed to send proactive massage to group: {traceback.format_exc()}')
|
||||
|
||||
async def create_and_card(
|
||||
self, temp_card_id: str, incoming_message: dingtalk_stream.ChatbotMessage, quote_origin: bool = False
|
||||
self,
|
||||
temp_card_id: str,
|
||||
incoming_message: dingtalk_stream.ChatbotMessage,
|
||||
quote_origin: bool = False,
|
||||
card_auto_layout: bool = False,
|
||||
):
|
||||
content_key = 'content'
|
||||
card_data = {content_key: ''}
|
||||
card_data = {}
|
||||
card_data['config'] = json.dumps({'autoLayout': card_auto_layout})
|
||||
card_data['content'] = ''
|
||||
|
||||
card_instance = dingtalk_stream.AICardReplier(self.client, incoming_message)
|
||||
# print(card_instance)
|
||||
|
||||
@@ -23,12 +23,21 @@ xml_template = """
|
||||
|
||||
|
||||
class OAClient:
|
||||
def __init__(self, token: str, EncodingAESKey: str, AppID: str, Appsecret: str, logger: None, unified_mode: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
token: str,
|
||||
EncodingAESKey: str,
|
||||
AppID: str,
|
||||
Appsecret: str,
|
||||
logger: None,
|
||||
unified_mode: bool = False,
|
||||
api_base_url: str = 'https://api.weixin.qq.com',
|
||||
):
|
||||
self.token = token
|
||||
self.aes = EncodingAESKey
|
||||
self.appid = AppID
|
||||
self.appsecret = Appsecret
|
||||
self.base_url = 'https://api.weixin.qq.com'
|
||||
self.base_url = api_base_url
|
||||
self.access_token = ''
|
||||
self.unified_mode = unified_mode
|
||||
self.app = Quart(__name__)
|
||||
@@ -208,12 +217,13 @@ class OAClientForLongerResponse:
|
||||
LoadingMessage: str,
|
||||
logger: None,
|
||||
unified_mode: bool = False,
|
||||
api_base_url: str = 'https://api.weixin.qq.com',
|
||||
):
|
||||
self.token = token
|
||||
self.aes = EncodingAESKey
|
||||
self.appid = AppID
|
||||
self.appsecret = Appsecret
|
||||
self.base_url = 'https://api.weixin.qq.com'
|
||||
self.base_url = api_base_url
|
||||
self.access_token = ''
|
||||
self.unified_mode = unified_mode
|
||||
self.app = Quart(__name__)
|
||||
|
||||
@@ -85,7 +85,6 @@ class QQOfficialClient:
|
||||
req: Quart Request 对象
|
||||
"""
|
||||
try:
|
||||
|
||||
body = await req.get_data()
|
||||
|
||||
print(f'[QQ Official] Received request, body length: {len(body)}')
|
||||
@@ -96,7 +95,6 @@ class QQOfficialClient:
|
||||
|
||||
payload = json.loads(body)
|
||||
|
||||
|
||||
if payload.get('op') == 13:
|
||||
validation_data = payload.get('d')
|
||||
if not validation_data:
|
||||
@@ -276,21 +274,21 @@ class QQOfficialClient:
|
||||
seed = bot_secret
|
||||
while len(seed) < target_size:
|
||||
seed *= 2
|
||||
return seed[:target_size].encode("utf-8")
|
||||
return seed[:target_size].encode('utf-8')
|
||||
|
||||
async def verify(self, validation_payload: dict):
|
||||
seed = await self.repeat_seed(self.secret)
|
||||
private_key = ed25519.Ed25519PrivateKey.from_private_bytes(seed)
|
||||
|
||||
event_ts = validation_payload.get("event_ts", "")
|
||||
plain_token = validation_payload.get("plain_token", "")
|
||||
event_ts = validation_payload.get('event_ts', '')
|
||||
plain_token = validation_payload.get('plain_token', '')
|
||||
msg = event_ts + plain_token
|
||||
|
||||
# sign
|
||||
signature = private_key.sign(msg.encode()).hex()
|
||||
|
||||
response = {
|
||||
"plain_token": plain_token,
|
||||
"signature": signature,
|
||||
'plain_token': plain_token,
|
||||
'signature': signature,
|
||||
}
|
||||
return response
|
||||
|
||||
@@ -36,7 +36,12 @@ class WecomBotEvent(dict):
|
||||
"""
|
||||
用户名称
|
||||
"""
|
||||
return self.get('username', '') or self.get('from', {}).get('alias', '') or self.get('from', {}).get('name', '') or self.userid
|
||||
return (
|
||||
self.get('username', '')
|
||||
or self.get('from', {}).get('alias', '')
|
||||
or self.get('from', {}).get('name', '')
|
||||
or self.userid
|
||||
)
|
||||
|
||||
@property
|
||||
def chatname(self) -> str:
|
||||
@@ -121,7 +126,7 @@ class WecomBotEvent(dict):
|
||||
消息id
|
||||
"""
|
||||
return self.get('msgid', '')
|
||||
|
||||
|
||||
@property
|
||||
def ai_bot_id(self) -> str:
|
||||
"""
|
||||
|
||||
@@ -22,13 +22,14 @@ class WecomClient:
|
||||
contacts_secret: str,
|
||||
logger: None,
|
||||
unified_mode: bool = False,
|
||||
api_base_url: str = 'https://qyapi.weixin.qq.com/cgi-bin',
|
||||
):
|
||||
self.corpid = corpid
|
||||
self.secret = secret
|
||||
self.access_token_for_contacts = ''
|
||||
self.token = token
|
||||
self.aes = EncodingAESKey
|
||||
self.base_url = 'https://qyapi.weixin.qq.com/cgi-bin'
|
||||
self.base_url = api_base_url
|
||||
self.access_token = ''
|
||||
self.secret_for_contacts = contacts_secret
|
||||
self.logger = logger
|
||||
@@ -56,7 +57,7 @@ class WecomClient:
|
||||
return bool(self.access_token_for_contacts and self.access_token_for_contacts.strip())
|
||||
|
||||
async def get_access_token(self, secret):
|
||||
url = f'https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid={self.corpid}&corpsecret={secret}'
|
||||
url = f'{self.base_url}/gettoken?corpid={self.corpid}&corpsecret={secret}'
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url)
|
||||
data = response.json()
|
||||
@@ -196,7 +197,7 @@ class WecomClient:
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
|
||||
url = self.base_url + '/message/send?access_token=' + self.access_token
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with httpx.AsyncClient(timeout=None) as client:
|
||||
params = {
|
||||
'touser': user_id,
|
||||
'msgtype': 'text',
|
||||
|
||||
@@ -13,13 +13,22 @@ import aiofiles
|
||||
|
||||
|
||||
class WecomCSClient:
|
||||
def __init__(self, corpid: str, secret: str, token: str, EncodingAESKey: str, logger: None, unified_mode: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
corpid: str,
|
||||
secret: str,
|
||||
token: str,
|
||||
EncodingAESKey: str,
|
||||
logger: None,
|
||||
unified_mode: bool = False,
|
||||
api_base_url: str = 'https://qyapi.weixin.qq.com/cgi-bin',
|
||||
):
|
||||
self.corpid = corpid
|
||||
self.secret = secret
|
||||
self.access_token_for_contacts = ''
|
||||
self.token = token
|
||||
self.aes = EncodingAESKey
|
||||
self.base_url = 'https://qyapi.weixin.qq.com/cgi-bin'
|
||||
self.base_url = api_base_url
|
||||
self.access_token = ''
|
||||
self.logger = logger
|
||||
self.unified_mode = unified_mode
|
||||
@@ -66,7 +75,7 @@ class WecomCSClient:
|
||||
return bool(self.access_token_for_contacts and self.access_token_for_contacts.strip())
|
||||
|
||||
async def get_access_token(self, secret):
|
||||
url = f'https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid={self.corpid}&corpsecret={secret}'
|
||||
url = f'{self.base_url}/gettoken?corpid={self.corpid}&corpsecret={secret}'
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url)
|
||||
data = response.json()
|
||||
@@ -172,7 +181,7 @@ class WecomCSClient:
|
||||
if not await self.check_access_token():
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
|
||||
url = f'https://qyapi.weixin.qq.com/cgi-bin/kf/send_msg?access_token={self.access_token}'
|
||||
url = f'{self.base_url}/kf/send_msg?access_token={self.access_token}'
|
||||
|
||||
payload = {
|
||||
'touser': external_userid,
|
||||
|
||||
325
src/langbot/pkg/api/http/controller/groups/monitoring.py
Normal file
325
src/langbot/pkg/api/http/controller/groups/monitoring.py
Normal file
@@ -0,0 +1,325 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import quart
|
||||
|
||||
from .. import group
|
||||
|
||||
|
||||
def parse_iso_datetime(datetime_str: str | None) -> datetime.datetime | None:
|
||||
"""Parse ISO 8601 datetime string, handling 'Z' suffix for UTC timezone"""
|
||||
if not datetime_str:
|
||||
return None
|
||||
# Replace 'Z' with '+00:00' for Python 3.10 compatibility
|
||||
if datetime_str.endswith('Z'):
|
||||
datetime_str = datetime_str[:-1] + '+00:00'
|
||||
dt = datetime.datetime.fromisoformat(datetime_str)
|
||||
# Convert to UTC and remove timezone info to match database storage (which stores UTC as naive datetime)
|
||||
if dt.tzinfo is not None:
|
||||
# Convert to UTC and remove timezone info
|
||||
dt = dt.astimezone(datetime.timezone.utc).replace(tzinfo=None)
|
||||
return dt
|
||||
|
||||
|
||||
@group.group_class('monitoring', '/api/v1/monitoring')
|
||||
class MonitoringRouterGroup(group.RouterGroup):
|
||||
async def initialize(self) -> None:
|
||||
@self.route('/overview', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def get_overview() -> str:
|
||||
"""Get overview metrics"""
|
||||
# Parse query parameters
|
||||
bot_ids = quart.request.args.getlist('botId')
|
||||
pipeline_ids = quart.request.args.getlist('pipelineId')
|
||||
start_time_str = quart.request.args.get('startTime')
|
||||
end_time_str = quart.request.args.get('endTime')
|
||||
|
||||
# Parse datetime
|
||||
start_time = parse_iso_datetime(start_time_str)
|
||||
end_time = parse_iso_datetime(end_time_str)
|
||||
|
||||
metrics = await self.ap.monitoring_service.get_overview_metrics(
|
||||
bot_ids=bot_ids if bot_ids else None,
|
||||
pipeline_ids=pipeline_ids if pipeline_ids else None,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
return self.success(data=metrics)
|
||||
|
||||
@self.route('/messages', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def get_messages() -> str:
|
||||
"""Get message logs"""
|
||||
# Parse query parameters
|
||||
bot_ids = quart.request.args.getlist('botId')
|
||||
pipeline_ids = quart.request.args.getlist('pipelineId')
|
||||
start_time_str = quart.request.args.get('startTime')
|
||||
end_time_str = quart.request.args.get('endTime')
|
||||
limit = int(quart.request.args.get('limit', 100))
|
||||
offset = int(quart.request.args.get('offset', 0))
|
||||
|
||||
# Parse datetime
|
||||
start_time = parse_iso_datetime(start_time_str)
|
||||
end_time = parse_iso_datetime(end_time_str)
|
||||
|
||||
messages, total = await self.ap.monitoring_service.get_messages(
|
||||
bot_ids=bot_ids if bot_ids else None,
|
||||
pipeline_ids=pipeline_ids if pipeline_ids else None,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'messages': messages,
|
||||
'total': total,
|
||||
'limit': limit,
|
||||
'offset': offset,
|
||||
}
|
||||
)
|
||||
|
||||
@self.route('/llm-calls', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def get_llm_calls() -> str:
|
||||
"""Get LLM call records"""
|
||||
# Parse query parameters
|
||||
bot_ids = quart.request.args.getlist('botId')
|
||||
pipeline_ids = quart.request.args.getlist('pipelineId')
|
||||
start_time_str = quart.request.args.get('startTime')
|
||||
end_time_str = quart.request.args.get('endTime')
|
||||
limit = int(quart.request.args.get('limit', 100))
|
||||
offset = int(quart.request.args.get('offset', 0))
|
||||
|
||||
# Parse datetime
|
||||
start_time = parse_iso_datetime(start_time_str)
|
||||
end_time = parse_iso_datetime(end_time_str)
|
||||
|
||||
llm_calls, total = await self.ap.monitoring_service.get_llm_calls(
|
||||
bot_ids=bot_ids if bot_ids else None,
|
||||
pipeline_ids=pipeline_ids if pipeline_ids else None,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'llm_calls': llm_calls,
|
||||
'total': total,
|
||||
'limit': limit,
|
||||
'offset': offset,
|
||||
}
|
||||
)
|
||||
|
||||
@self.route('/embedding-calls', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def get_embedding_calls() -> str:
|
||||
"""Get embedding call records"""
|
||||
# Parse query parameters
|
||||
start_time_str = quart.request.args.get('startTime')
|
||||
end_time_str = quart.request.args.get('endTime')
|
||||
knowledge_base_id = quart.request.args.get('knowledgeBaseId')
|
||||
limit = int(quart.request.args.get('limit', 100))
|
||||
offset = int(quart.request.args.get('offset', 0))
|
||||
|
||||
# Parse datetime
|
||||
start_time = parse_iso_datetime(start_time_str)
|
||||
end_time = parse_iso_datetime(end_time_str)
|
||||
|
||||
embedding_calls, total = await self.ap.monitoring_service.get_embedding_calls(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
knowledge_base_id=knowledge_base_id if knowledge_base_id else None,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'embedding_calls': embedding_calls,
|
||||
'total': total,
|
||||
'limit': limit,
|
||||
'offset': offset,
|
||||
}
|
||||
)
|
||||
|
||||
@self.route('/sessions', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def get_sessions() -> str:
|
||||
"""Get session information"""
|
||||
# Parse query parameters
|
||||
bot_ids = quart.request.args.getlist('botId')
|
||||
pipeline_ids = quart.request.args.getlist('pipelineId')
|
||||
start_time_str = quart.request.args.get('startTime')
|
||||
end_time_str = quart.request.args.get('endTime')
|
||||
is_active_str = quart.request.args.get('isActive')
|
||||
limit = int(quart.request.args.get('limit', 100))
|
||||
offset = int(quart.request.args.get('offset', 0))
|
||||
|
||||
# Parse datetime
|
||||
start_time = parse_iso_datetime(start_time_str)
|
||||
end_time = parse_iso_datetime(end_time_str)
|
||||
|
||||
# Parse is_active
|
||||
is_active = None
|
||||
if is_active_str:
|
||||
is_active = is_active_str.lower() == 'true'
|
||||
|
||||
sessions, total = await self.ap.monitoring_service.get_sessions(
|
||||
bot_ids=bot_ids if bot_ids else None,
|
||||
pipeline_ids=pipeline_ids if pipeline_ids else None,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
is_active=is_active,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'sessions': sessions,
|
||||
'total': total,
|
||||
'limit': limit,
|
||||
'offset': offset,
|
||||
}
|
||||
)
|
||||
|
||||
@self.route('/errors', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def get_errors() -> str:
|
||||
"""Get error logs"""
|
||||
# Parse query parameters
|
||||
bot_ids = quart.request.args.getlist('botId')
|
||||
pipeline_ids = quart.request.args.getlist('pipelineId')
|
||||
start_time_str = quart.request.args.get('startTime')
|
||||
end_time_str = quart.request.args.get('endTime')
|
||||
limit = int(quart.request.args.get('limit', 100))
|
||||
offset = int(quart.request.args.get('offset', 0))
|
||||
|
||||
# Parse datetime
|
||||
start_time = parse_iso_datetime(start_time_str)
|
||||
end_time = parse_iso_datetime(end_time_str)
|
||||
|
||||
errors, total = await self.ap.monitoring_service.get_errors(
|
||||
bot_ids=bot_ids if bot_ids else None,
|
||||
pipeline_ids=pipeline_ids if pipeline_ids else None,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'errors': errors,
|
||||
'total': total,
|
||||
'limit': limit,
|
||||
'offset': offset,
|
||||
}
|
||||
)
|
||||
|
||||
@self.route('/data', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def get_all_data() -> str:
|
||||
"""Get all monitoring data in a single request"""
|
||||
# Parse query parameters
|
||||
bot_ids = quart.request.args.getlist('botId')
|
||||
pipeline_ids = quart.request.args.getlist('pipelineId')
|
||||
start_time_str = quart.request.args.get('startTime')
|
||||
end_time_str = quart.request.args.get('endTime')
|
||||
limit = int(quart.request.args.get('limit', 50))
|
||||
|
||||
# Parse datetime
|
||||
start_time = parse_iso_datetime(start_time_str)
|
||||
end_time = parse_iso_datetime(end_time_str)
|
||||
|
||||
# Get overview metrics
|
||||
overview = await self.ap.monitoring_service.get_overview_metrics(
|
||||
bot_ids=bot_ids if bot_ids else None,
|
||||
pipeline_ids=pipeline_ids if pipeline_ids else None,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
# Get messages
|
||||
messages, messages_total = await self.ap.monitoring_service.get_messages(
|
||||
bot_ids=bot_ids if bot_ids else None,
|
||||
pipeline_ids=pipeline_ids if pipeline_ids else None,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
limit=limit,
|
||||
offset=0,
|
||||
)
|
||||
|
||||
# Get LLM calls
|
||||
llm_calls, llm_calls_total = await self.ap.monitoring_service.get_llm_calls(
|
||||
bot_ids=bot_ids if bot_ids else None,
|
||||
pipeline_ids=pipeline_ids if pipeline_ids else None,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
limit=limit,
|
||||
offset=0,
|
||||
)
|
||||
|
||||
# Get sessions
|
||||
sessions, sessions_total = await self.ap.monitoring_service.get_sessions(
|
||||
bot_ids=bot_ids if bot_ids else None,
|
||||
pipeline_ids=pipeline_ids if pipeline_ids else None,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
is_active=None,
|
||||
limit=limit,
|
||||
offset=0,
|
||||
)
|
||||
|
||||
# Get errors
|
||||
errors, errors_total = await self.ap.monitoring_service.get_errors(
|
||||
bot_ids=bot_ids if bot_ids else None,
|
||||
pipeline_ids=pipeline_ids if pipeline_ids else None,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
limit=limit,
|
||||
offset=0,
|
||||
)
|
||||
|
||||
# Get embedding calls
|
||||
embedding_calls, embedding_calls_total = await self.ap.monitoring_service.get_embedding_calls(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
limit=limit,
|
||||
offset=0,
|
||||
)
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'overview': overview,
|
||||
'messages': messages,
|
||||
'llmCalls': llm_calls,
|
||||
'embeddingCalls': embedding_calls,
|
||||
'sessions': sessions,
|
||||
'errors': errors,
|
||||
'totalCount': {
|
||||
'messages': messages_total,
|
||||
'llmCalls': llm_calls_total,
|
||||
'embeddingCalls': embedding_calls_total,
|
||||
'sessions': sessions_total,
|
||||
'errors': errors_total,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@self.route('/sessions/<session_id>/analysis', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def get_session_analysis(session_id: str) -> str:
|
||||
"""Get detailed analysis for a specific session"""
|
||||
analysis = await self.ap.monitoring_service.get_session_analysis(session_id)
|
||||
|
||||
# Always return success with the analysis data
|
||||
# The frontend will handle the 'found: false' case
|
||||
return self.success(data=analysis)
|
||||
|
||||
@self.route('/messages/<message_id>/details', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def get_message_details(message_id: str) -> str:
|
||||
"""Get detailed information for a specific message"""
|
||||
details = await self.ap.monitoring_service.get_message_details(message_id)
|
||||
|
||||
if not details.get('found'):
|
||||
return self.error(message=f'Message {message_id} not found', code=404)
|
||||
|
||||
return self.success(data=details)
|
||||
@@ -9,12 +9,15 @@ class LLMModelsRouterGroup(group.RouterGroup):
|
||||
@self.route('', methods=['GET', 'POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _() -> str:
|
||||
if quart.request.method == 'GET':
|
||||
provider_uuid = quart.request.args.get('provider_uuid')
|
||||
if provider_uuid:
|
||||
return self.success(
|
||||
data={'models': await self.ap.llm_model_service.get_llm_models_by_provider(provider_uuid)}
|
||||
)
|
||||
return self.success(data={'models': await self.ap.llm_model_service.get_llm_models()})
|
||||
elif quart.request.method == 'POST':
|
||||
json_data = await quart.request.json
|
||||
|
||||
model_uuid = await self.ap.llm_model_service.create_llm_model(json_data)
|
||||
|
||||
return self.success(data={'uuid': model_uuid})
|
||||
|
||||
@self.route('/<model_uuid>', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
@@ -52,12 +55,19 @@ class EmbeddingModelsRouterGroup(group.RouterGroup):
|
||||
@self.route('', methods=['GET', 'POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _() -> str:
|
||||
if quart.request.method == 'GET':
|
||||
provider_uuid = quart.request.args.get('provider_uuid')
|
||||
if provider_uuid:
|
||||
return self.success(
|
||||
data={
|
||||
'models': await self.ap.embedding_models_service.get_embedding_models_by_provider(
|
||||
provider_uuid
|
||||
)
|
||||
}
|
||||
)
|
||||
return self.success(data={'models': await self.ap.embedding_models_service.get_embedding_models()})
|
||||
elif quart.request.method == 'POST':
|
||||
json_data = await quart.request.json
|
||||
|
||||
model_uuid = await self.ap.embedding_models_service.create_embedding_model(json_data)
|
||||
|
||||
return self.success(data={'uuid': model_uuid})
|
||||
|
||||
@self.route('/<model_uuid>', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
import quart
|
||||
|
||||
from ... import group
|
||||
|
||||
|
||||
@group.group_class('models/providers', '/api/v1/provider/providers')
|
||||
class ModelProvidersRouterGroup(group.RouterGroup):
|
||||
async def initialize(self) -> None:
|
||||
@self.route('', methods=['GET', 'POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def _() -> str:
|
||||
if quart.request.method == 'GET':
|
||||
providers = await self.ap.provider_service.get_providers()
|
||||
# Add model counts
|
||||
for provider in providers:
|
||||
counts = await self.ap.provider_service.get_provider_model_counts(provider['uuid'])
|
||||
provider['llm_count'] = counts['llm_count']
|
||||
provider['embedding_count'] = counts['embedding_count']
|
||||
return self.success(data={'providers': providers})
|
||||
elif quart.request.method == 'POST':
|
||||
json_data = await quart.request.json
|
||||
provider_uuid = await self.ap.provider_service.create_provider(json_data)
|
||||
return self.success(data={'uuid': provider_uuid})
|
||||
|
||||
@self.route(
|
||||
'/<provider_uuid>', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY
|
||||
)
|
||||
async def _(provider_uuid: str) -> str:
|
||||
if quart.request.method == 'GET':
|
||||
provider = await self.ap.provider_service.get_provider(provider_uuid)
|
||||
if provider is None:
|
||||
return self.http_status(404, -1, 'provider not found')
|
||||
counts = await self.ap.provider_service.get_provider_model_counts(provider_uuid)
|
||||
provider['llm_count'] = counts['llm_count']
|
||||
provider['embedding_count'] = counts['embedding_count']
|
||||
return self.success(data={'provider': provider})
|
||||
elif quart.request.method == 'PUT':
|
||||
json_data = await quart.request.json
|
||||
await self.ap.provider_service.update_provider(provider_uuid, json_data)
|
||||
return self.success()
|
||||
elif quart.request.method == 'DELETE':
|
||||
try:
|
||||
await self.ap.provider_service.delete_provider(provider_uuid)
|
||||
return self.success()
|
||||
except ValueError as e:
|
||||
return self.http_status(400, -1, str(e))
|
||||
@@ -17,14 +17,13 @@ class SystemRouterGroup(group.RouterGroup):
|
||||
'enable_marketplace', True
|
||||
),
|
||||
'cloud_service_url': (
|
||||
self.ap.instance_config.data.get('plugin', {}).get(
|
||||
'cloud_service_url', 'https://space.langbot.app'
|
||||
)
|
||||
if 'cloud_service_url' in self.ap.instance_config.data.get('plugin', {})
|
||||
else 'https://space.langbot.app'
|
||||
self.ap.instance_config.data.get('space', {}).get('url', 'https://space.langbot.app')
|
||||
),
|
||||
'allow_change_password': self.ap.instance_config.data.get('system', {}).get(
|
||||
'allow_change_password', True
|
||||
'allow_modify_login_info': self.ap.instance_config.data.get('system', {}).get(
|
||||
'allow_modify_login_info', True
|
||||
),
|
||||
'disable_models_service': self.ap.instance_config.data.get('space', {}).get(
|
||||
'disable_models_service', False
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import quart
|
||||
import argon2
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
from .. import group
|
||||
from .....entity.errors import account as account_errors
|
||||
|
||||
|
||||
@group.group_class('user', '/api/v1/user')
|
||||
@@ -33,6 +35,8 @@ class UserRouterGroup(group.RouterGroup):
|
||||
token = await self.ap.user_service.authenticate(json_data['user'], json_data['password'])
|
||||
except argon2.exceptions.VerifyMismatchError:
|
||||
return self.fail(1, 'Invalid username or password')
|
||||
except ValueError as e:
|
||||
return self.fail(1, str(e))
|
||||
|
||||
return self.success(data={'token': token})
|
||||
|
||||
@@ -71,11 +75,11 @@ class UserRouterGroup(group.RouterGroup):
|
||||
@self.route('/change-password', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _(user_email: str) -> str:
|
||||
# Check if password change is allowed
|
||||
allow_change_password = self.ap.instance_config.data.get('system', {}).get(
|
||||
'allow_change_password', True
|
||||
allow_modify_login_info = self.ap.instance_config.data.get('system', {}).get(
|
||||
'allow_modify_login_info', True
|
||||
)
|
||||
if not allow_change_password:
|
||||
return self.http_status(403, -1, 'Password change is disabled')
|
||||
if not allow_modify_login_info:
|
||||
return self.http_status(403, -1, 'Modifying login info is disabled')
|
||||
|
||||
json_data = await quart.request.json
|
||||
|
||||
@@ -90,3 +94,169 @@ class UserRouterGroup(group.RouterGroup):
|
||||
return self.http_status(400, -1, str(e))
|
||||
|
||||
return self.success(data={'user': user_email})
|
||||
|
||||
# Space OAuth endpoints (redirect flow)
|
||||
|
||||
@self.route('/space/authorize-url', methods=['GET'], auth_type=group.AuthType.NONE)
|
||||
async def _() -> str:
|
||||
"""Get Space OAuth authorization URL for redirect"""
|
||||
redirect_uri = quart.request.args.get('redirect_uri', '')
|
||||
state = quart.request.args.get('state', '')
|
||||
|
||||
if not redirect_uri:
|
||||
return self.fail(1, 'Missing redirect_uri parameter')
|
||||
|
||||
try:
|
||||
authorize_url = self.ap.space_service.get_oauth_authorize_url(redirect_uri, state)
|
||||
return self.success(data={'authorize_url': authorize_url})
|
||||
except Exception as e:
|
||||
return self.fail(1, str(e))
|
||||
|
||||
@self.route('/space/callback', methods=['POST'], auth_type=group.AuthType.NONE)
|
||||
async def _() -> str:
|
||||
"""Handle OAuth callback - exchange code for tokens and authenticate"""
|
||||
json_data = await quart.request.json
|
||||
code = json_data.get('code')
|
||||
|
||||
if not code:
|
||||
return self.fail(1, 'Missing authorization code')
|
||||
|
||||
try:
|
||||
# Exchange code for tokens
|
||||
token_data = await self.ap.space_service.exchange_oauth_code(code)
|
||||
access_token = token_data.get('access_token')
|
||||
refresh_token = token_data.get('refresh_token')
|
||||
expires_in = token_data.get('expires_in', 0)
|
||||
|
||||
if not access_token:
|
||||
return self.fail(1, 'Failed to get access token from Space')
|
||||
|
||||
# Authenticate and create/update local user
|
||||
jwt_token, user_obj = await self.ap.user_service.authenticate_space_user(
|
||||
access_token, refresh_token, expires_in
|
||||
)
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'token': jwt_token,
|
||||
'user': user_obj.user,
|
||||
}
|
||||
)
|
||||
except account_errors.AccountEmailMismatchError as e:
|
||||
return self.fail(3, str(e))
|
||||
except ValueError as e:
|
||||
traceback.print_exc()
|
||||
return self.fail(1, str(e))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
return self.fail(2, f'OAuth callback failed: {str(e)}')
|
||||
|
||||
@self.route('/info', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _(user_email: str) -> str:
|
||||
"""Get current user information including account type"""
|
||||
user_obj = await self.ap.user_service.get_user_by_email(user_email)
|
||||
|
||||
if user_obj is None:
|
||||
return self.http_status(404, -1, 'User not found')
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'user': user_obj.user,
|
||||
'account_type': user_obj.account_type,
|
||||
'has_password': bool(user_obj.password and user_obj.password.strip()),
|
||||
}
|
||||
)
|
||||
|
||||
@self.route('/space-credits', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _(user_email: str) -> str:
|
||||
"""Get Space credits balance for current user"""
|
||||
credits = await self.ap.space_service.get_credits(user_email)
|
||||
return self.success(data={'credits': credits})
|
||||
|
||||
@self.route('/account-info', methods=['GET'], auth_type=group.AuthType.NONE)
|
||||
async def _() -> str:
|
||||
"""Get account info for login page (account type and has_password)"""
|
||||
if not await self.ap.user_service.is_initialized():
|
||||
return self.success(data={'initialized': False})
|
||||
|
||||
user_obj = await self.ap.user_service.get_first_user()
|
||||
if user_obj is None:
|
||||
return self.success(data={'initialized': False})
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'initialized': True,
|
||||
'account_type': user_obj.account_type,
|
||||
'has_password': bool(user_obj.password and user_obj.password.strip()),
|
||||
}
|
||||
)
|
||||
|
||||
@self.route('/set-password', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _(user_email: str) -> str:
|
||||
"""Set password for Space account (first time) or change password"""
|
||||
json_data = await quart.request.json
|
||||
new_password = json_data.get('new_password')
|
||||
current_password = json_data.get('current_password')
|
||||
|
||||
if not new_password:
|
||||
return self.http_status(400, -1, 'New password is required')
|
||||
|
||||
user_obj = await self.ap.user_service.get_user_by_email(user_email)
|
||||
if user_obj is None:
|
||||
return self.http_status(404, -1, 'User not found')
|
||||
|
||||
try:
|
||||
await self.ap.user_service.set_password(user_email, new_password, current_password)
|
||||
return self.success(data={'user': user_email})
|
||||
except ValueError as e:
|
||||
return self.http_status(400, -1, str(e))
|
||||
except argon2.exceptions.VerifyMismatchError:
|
||||
return self.http_status(400, -1, 'Current password is incorrect')
|
||||
|
||||
@self.route('/bind-space', methods=['POST'], auth_type=group.AuthType.NONE)
|
||||
async def _() -> str:
|
||||
"""Bind Space account to existing local account"""
|
||||
# Check if modifying login info is allowed
|
||||
allow_modify_login_info = self.ap.instance_config.data.get('system', {}).get(
|
||||
'allow_modify_login_info', True
|
||||
)
|
||||
if not allow_modify_login_info:
|
||||
return self.http_status(403, -1, 'Modifying login info is disabled')
|
||||
|
||||
json_data = await quart.request.json
|
||||
code = json_data.get('code')
|
||||
state = json_data.get('state') # JWT token passed as state
|
||||
|
||||
if not code:
|
||||
return self.http_status(400, -1, 'Missing authorization code')
|
||||
|
||||
if not state:
|
||||
return self.http_status(400, -1, 'Missing state parameter')
|
||||
|
||||
# Verify state is a valid JWT token
|
||||
try:
|
||||
user_email = await self.ap.user_service.verify_jwt_token(state)
|
||||
except Exception:
|
||||
return self.http_status(401, -1, 'Invalid or expired state')
|
||||
|
||||
user_obj = await self.ap.user_service.get_user_by_email(user_email)
|
||||
if user_obj is None:
|
||||
return self.http_status(404, -1, 'User not found')
|
||||
|
||||
if user_obj.account_type != 'local':
|
||||
return self.http_status(400, -1, 'Only local accounts can bind to Space')
|
||||
|
||||
try:
|
||||
updated_user = await self.ap.user_service.bind_space_account(user_email, code)
|
||||
jwt_token = await self.ap.user_service.generate_jwt_token(updated_user.user)
|
||||
return self.success(
|
||||
data={
|
||||
'token': jwt_token,
|
||||
'user': updated_user.user,
|
||||
'account_type': updated_user.account_type,
|
||||
}
|
||||
)
|
||||
except ValueError as e:
|
||||
return self.http_status(400, -1, str(e))
|
||||
except Exception as e:
|
||||
return self.http_status(500, -1, f'Failed to bind Space account: {str(e)}')
|
||||
|
||||
@@ -30,7 +30,6 @@ class WebhookRouterGroup(group.RouterGroup):
|
||||
适配器返回的响应
|
||||
"""
|
||||
try:
|
||||
|
||||
runtime_bot = await self.ap.platform_mgr.get_bot_by_uuid(bot_uuid)
|
||||
|
||||
if not runtime_bot:
|
||||
@@ -39,11 +38,9 @@ class WebhookRouterGroup(group.RouterGroup):
|
||||
if not runtime_bot.enable:
|
||||
return quart.jsonify({'error': 'Bot is disabled'}), 403
|
||||
|
||||
|
||||
if not hasattr(runtime_bot.adapter, 'handle_unified_webhook'):
|
||||
return quart.jsonify({'error': 'Adapter does not support unified webhook'}), 501
|
||||
|
||||
|
||||
response = await runtime_bot.adapter.handle_unified_webhook(
|
||||
bot_uuid=bot_uuid,
|
||||
path=path,
|
||||
|
||||
@@ -59,7 +59,16 @@ class BotService:
|
||||
adapter_runtime_values['bot_account_id'] = runtime_bot.adapter.bot_account_id
|
||||
|
||||
# Webhook URL for unified webhook adapters (independent of bot running state)
|
||||
if persistence_bot['adapter'] in ['wecom', 'wecombot', 'officialaccount', 'qqofficial', 'slack', 'wecomcs', 'LINE', 'lark']:
|
||||
if persistence_bot['adapter'] in [
|
||||
'wecom',
|
||||
'wecombot',
|
||||
'officialaccount',
|
||||
'qqofficial',
|
||||
'slack',
|
||||
'wecomcs',
|
||||
'LINE',
|
||||
'lark',
|
||||
]:
|
||||
webhook_prefix = self.ap.instance_config.data['api'].get('webhook_prefix', 'http://127.0.0.1:5300')
|
||||
webhook_url = f'/bots/{bot_uuid}'
|
||||
adapter_runtime_values['webhook_url'] = webhook_url
|
||||
|
||||
@@ -11,6 +11,18 @@ from ....entity.persistence import pipeline as persistence_pipeline
|
||||
from ....provider.modelmgr import requester as model_requester
|
||||
|
||||
|
||||
def _parse_provider_api_keys(provider_dict: dict) -> dict:
|
||||
"""Parse api_keys if it's a JSON string"""
|
||||
if isinstance(provider_dict.get('api_keys'), str):
|
||||
import json
|
||||
|
||||
try:
|
||||
provider_dict['api_keys'] = json.loads(provider_dict['api_keys'])
|
||||
except Exception:
|
||||
provider_dict['api_keys'] = []
|
||||
return provider_dict
|
||||
|
||||
|
||||
class LLMModelsService:
|
||||
ap: app.Application
|
||||
|
||||
@@ -18,59 +30,131 @@ class LLMModelsService:
|
||||
self.ap = ap
|
||||
|
||||
async def get_llm_models(self, include_secret: bool = True) -> list[dict]:
|
||||
"""Get all LLM models with provider info"""
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel))
|
||||
|
||||
models = result.all()
|
||||
|
||||
masked_columns = []
|
||||
if not include_secret:
|
||||
masked_columns = ['api_keys']
|
||||
# Get all providers for lookup
|
||||
providers_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.ModelProvider)
|
||||
)
|
||||
providers = {p.uuid: p for p in providers_result.all()}
|
||||
|
||||
return [
|
||||
self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model, masked_columns)
|
||||
for model in models
|
||||
]
|
||||
models_list = []
|
||||
for model in models:
|
||||
model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model)
|
||||
provider = providers.get(model.provider_uuid)
|
||||
if provider:
|
||||
provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
|
||||
provider_dict = _parse_provider_api_keys(provider_dict)
|
||||
if not include_secret:
|
||||
provider_dict['api_keys'] = ['***'] * len(provider_dict.get('api_keys', []))
|
||||
model_dict['provider'] = provider_dict
|
||||
models_list.append(model_dict)
|
||||
|
||||
async def create_llm_model(self, model_data: dict) -> str:
|
||||
model_data['uuid'] = str(uuid.uuid4())
|
||||
return models_list
|
||||
|
||||
async def get_llm_models_by_provider(self, provider_uuid: str) -> list[dict]:
|
||||
"""Get LLM models by provider UUID"""
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.LLMModel).where(
|
||||
persistence_model.LLMModel.provider_uuid == provider_uuid
|
||||
)
|
||||
)
|
||||
models = result.all()
|
||||
return [self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, m) for m in models]
|
||||
|
||||
async def create_llm_model(
|
||||
self, model_data: dict, preserve_uuid: bool = False, auto_set_to_default_pipeline: bool = True
|
||||
) -> str:
|
||||
"""Create a new LLM model"""
|
||||
if not preserve_uuid:
|
||||
model_data['uuid'] = str(uuid.uuid4())
|
||||
|
||||
# Handle provider creation if needed
|
||||
if 'provider' in model_data:
|
||||
provider_data = model_data.pop('provider')
|
||||
if provider_data.get('uuid'):
|
||||
model_data['provider_uuid'] = provider_data['uuid']
|
||||
else:
|
||||
# Create new provider
|
||||
provider_uuid = await self.ap.provider_service.find_or_create_provider(
|
||||
requester=provider_data.get('requester', ''),
|
||||
base_url=provider_data.get('base_url', ''),
|
||||
api_keys=provider_data.get('api_keys', []),
|
||||
)
|
||||
model_data['provider_uuid'] = provider_uuid
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_model.LLMModel).values(**model_data))
|
||||
|
||||
llm_model = await self.get_llm_model(model_data['uuid'])
|
||||
runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid'])
|
||||
if runtime_provider is None:
|
||||
raise Exception('provider not found')
|
||||
|
||||
await self.ap.model_mgr.load_llm_model(llm_model)
|
||||
|
||||
# check if default pipeline has no model bound
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
|
||||
persistence_pipeline.LegacyPipeline.is_default == True
|
||||
)
|
||||
runtime_llm_model = await self.ap.model_mgr.load_llm_model_with_provider(
|
||||
persistence_model.LLMModel(**model_data),
|
||||
runtime_provider,
|
||||
)
|
||||
pipeline = result.first()
|
||||
if pipeline is not None and pipeline.config['ai']['local-agent']['model'] == '':
|
||||
pipeline_config = pipeline.config
|
||||
pipeline_config['ai']['local-agent']['model'] = model_data['uuid']
|
||||
pipeline_data = {'config': pipeline_config}
|
||||
await self.ap.pipeline_service.update_pipeline(pipeline.uuid, pipeline_data)
|
||||
self.ap.model_mgr.llm_models.append(runtime_llm_model)
|
||||
|
||||
if auto_set_to_default_pipeline:
|
||||
# set the default pipeline model to this model
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
|
||||
persistence_pipeline.LegacyPipeline.is_default == True
|
||||
)
|
||||
)
|
||||
pipeline = result.first()
|
||||
if pipeline is not None and pipeline.config['ai']['local-agent']['model'] == '':
|
||||
pipeline_config = pipeline.config
|
||||
pipeline_config['ai']['local-agent']['model'] = model_data['uuid']
|
||||
pipeline_data = {'config': pipeline_config}
|
||||
await self.ap.pipeline_service.update_pipeline(pipeline.uuid, pipeline_data)
|
||||
|
||||
return model_data['uuid']
|
||||
|
||||
async def get_llm_model(self, model_uuid: str) -> dict | None:
|
||||
"""Get a single LLM model with provider info"""
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid)
|
||||
)
|
||||
|
||||
model = result.first()
|
||||
|
||||
if model is None:
|
||||
return None
|
||||
|
||||
return self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model)
|
||||
model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model)
|
||||
|
||||
# Get provider
|
||||
provider_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.ModelProvider).where(
|
||||
persistence_model.ModelProvider.uuid == model.provider_uuid
|
||||
)
|
||||
)
|
||||
provider = provider_result.first()
|
||||
if provider:
|
||||
provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
|
||||
model_dict['provider'] = _parse_provider_api_keys(provider_dict)
|
||||
|
||||
return model_dict
|
||||
|
||||
async def update_llm_model(self, model_uuid: str, model_data: dict) -> None:
|
||||
"""Update an existing LLM model"""
|
||||
if 'uuid' in model_data:
|
||||
del model_data['uuid']
|
||||
|
||||
# Handle provider update if needed
|
||||
if 'provider' in model_data:
|
||||
provider_data = model_data.pop('provider')
|
||||
if provider_data.get('uuid'):
|
||||
model_data['provider_uuid'] = provider_data['uuid']
|
||||
else:
|
||||
provider_uuid = await self.ap.provider_service.find_or_create_provider(
|
||||
requester=provider_data.get('requester', ''),
|
||||
base_url=provider_data.get('base_url', ''),
|
||||
api_keys=provider_data.get('api_keys', []),
|
||||
)
|
||||
model_data['provider_uuid'] = provider_uuid
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_model.LLMModel)
|
||||
.where(persistence_model.LLMModel.uuid == model_uuid)
|
||||
@@ -79,18 +163,25 @@ class LLMModelsService:
|
||||
|
||||
await self.ap.model_mgr.remove_llm_model(model_uuid)
|
||||
|
||||
llm_model = await self.get_llm_model(model_uuid)
|
||||
runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid'])
|
||||
if runtime_provider is None:
|
||||
raise Exception('provider not found')
|
||||
|
||||
await self.ap.model_mgr.load_llm_model(llm_model)
|
||||
runtime_llm_model = await self.ap.model_mgr.load_llm_model_with_provider(
|
||||
persistence_model.LLMModel(**model_data),
|
||||
runtime_provider,
|
||||
)
|
||||
self.ap.model_mgr.llm_models.append(runtime_llm_model)
|
||||
|
||||
async def delete_llm_model(self, model_uuid: str) -> None:
|
||||
"""Delete an LLM model"""
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid)
|
||||
)
|
||||
|
||||
await self.ap.model_mgr.remove_llm_model(model_uuid)
|
||||
|
||||
async def test_llm_model(self, model_uuid: str, model_data: dict) -> None:
|
||||
"""Test an LLM model"""
|
||||
runtime_llm_model: model_requester.RuntimeLLMModel | None = None
|
||||
|
||||
if model_uuid != '_':
|
||||
@@ -98,25 +189,18 @@ class LLMModelsService:
|
||||
if model.model_entity.uuid == model_uuid:
|
||||
runtime_llm_model = model
|
||||
break
|
||||
|
||||
if runtime_llm_model is None:
|
||||
raise Exception('model not found')
|
||||
|
||||
else:
|
||||
runtime_llm_model = await self.ap.model_mgr.init_runtime_llm_model(model_data)
|
||||
runtime_llm_model = await self.ap.model_mgr.init_temporary_runtime_llm_model(model_data)
|
||||
|
||||
# Mon Nov 10 2025: Commented for some providers may not support thinking parameter
|
||||
# # 有些模型厂商默认开启了思考功能,测试容易延迟
|
||||
# extra_args = model_data.get('extra_args', {})
|
||||
# if not extra_args or 'thinking' not in extra_args:
|
||||
# extra_args['thinking'] = {'type': 'disabled'}
|
||||
|
||||
await runtime_llm_model.requester.invoke_llm(
|
||||
extra_args = model_data.get('extra_args', {})
|
||||
await runtime_llm_model.provider.invoke_llm(
|
||||
query=None,
|
||||
model=runtime_llm_model,
|
||||
messages=[provider_message.Message(role='user', content='Hello, world! Please just reply a "Hello".')],
|
||||
funcs=[],
|
||||
# extra_args=extra_args,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
|
||||
|
||||
@@ -127,42 +211,111 @@ class EmbeddingModelsService:
|
||||
self.ap = ap
|
||||
|
||||
async def get_embedding_models(self) -> list[dict]:
|
||||
"""Get all embedding models with provider info"""
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.EmbeddingModel))
|
||||
|
||||
models = result.all()
|
||||
return [self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model) for model in models]
|
||||
|
||||
async def create_embedding_model(self, model_data: dict) -> str:
|
||||
model_data['uuid'] = str(uuid.uuid4())
|
||||
providers_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.ModelProvider)
|
||||
)
|
||||
providers = {p.uuid: p for p in providers_result.all()}
|
||||
|
||||
models_list = []
|
||||
for model in models:
|
||||
model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model)
|
||||
provider = providers.get(model.provider_uuid)
|
||||
if provider:
|
||||
provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
|
||||
model_dict['provider'] = _parse_provider_api_keys(provider_dict)
|
||||
models_list.append(model_dict)
|
||||
|
||||
return models_list
|
||||
|
||||
async def get_embedding_models_by_provider(self, provider_uuid: str) -> list[dict]:
|
||||
"""Get embedding models by provider UUID"""
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.EmbeddingModel).where(
|
||||
persistence_model.EmbeddingModel.provider_uuid == provider_uuid
|
||||
)
|
||||
)
|
||||
models = result.all()
|
||||
return [self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, m) for m in models]
|
||||
|
||||
async def create_embedding_model(self, model_data: dict, preserve_uuid: bool = False) -> str:
|
||||
"""Create a new embedding model"""
|
||||
if not preserve_uuid:
|
||||
model_data['uuid'] = str(uuid.uuid4())
|
||||
|
||||
if 'provider' in model_data:
|
||||
provider_data = model_data.pop('provider')
|
||||
if provider_data.get('uuid'):
|
||||
model_data['provider_uuid'] = provider_data['uuid']
|
||||
else:
|
||||
provider_uuid = await self.ap.provider_service.find_or_create_provider(
|
||||
requester=provider_data.get('requester', ''),
|
||||
base_url=provider_data.get('base_url', ''),
|
||||
api_keys=provider_data.get('api_keys', []),
|
||||
)
|
||||
model_data['provider_uuid'] = provider_uuid
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(persistence_model.EmbeddingModel).values(**model_data)
|
||||
)
|
||||
|
||||
embedding_model = await self.get_embedding_model(model_data['uuid'])
|
||||
runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid'])
|
||||
if runtime_provider is None:
|
||||
raise Exception('provider not found')
|
||||
|
||||
await self.ap.model_mgr.load_embedding_model(embedding_model)
|
||||
runtime_embedding_model = await self.ap.model_mgr.load_embedding_model_with_provider(
|
||||
persistence_model.EmbeddingModel(**model_data),
|
||||
runtime_provider,
|
||||
)
|
||||
self.ap.model_mgr.embedding_models.append(runtime_embedding_model)
|
||||
|
||||
return model_data['uuid']
|
||||
|
||||
async def get_embedding_model(self, model_uuid: str) -> dict | None:
|
||||
"""Get a single embedding model with provider info"""
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.EmbeddingModel).where(
|
||||
persistence_model.EmbeddingModel.uuid == model_uuid
|
||||
)
|
||||
)
|
||||
|
||||
model = result.first()
|
||||
|
||||
if model is None:
|
||||
return None
|
||||
|
||||
return self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model)
|
||||
model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model)
|
||||
|
||||
provider_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.ModelProvider).where(
|
||||
persistence_model.ModelProvider.uuid == model.provider_uuid
|
||||
)
|
||||
)
|
||||
provider = provider_result.first()
|
||||
if provider:
|
||||
provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
|
||||
model_dict['provider'] = _parse_provider_api_keys(provider_dict)
|
||||
|
||||
return model_dict
|
||||
|
||||
async def update_embedding_model(self, model_uuid: str, model_data: dict) -> None:
|
||||
"""Update an existing embedding model"""
|
||||
if 'uuid' in model_data:
|
||||
del model_data['uuid']
|
||||
|
||||
if 'provider' in model_data:
|
||||
provider_data = model_data.pop('provider')
|
||||
if provider_data.get('uuid'):
|
||||
model_data['provider_uuid'] = provider_data['uuid']
|
||||
else:
|
||||
provider_uuid = await self.ap.provider_service.find_or_create_provider(
|
||||
requester=provider_data.get('requester', ''),
|
||||
base_url=provider_data.get('base_url', ''),
|
||||
api_keys=provider_data.get('api_keys', []),
|
||||
)
|
||||
model_data['provider_uuid'] = provider_uuid
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_model.EmbeddingModel)
|
||||
.where(persistence_model.EmbeddingModel.uuid == model_uuid)
|
||||
@@ -171,20 +324,27 @@ class EmbeddingModelsService:
|
||||
|
||||
await self.ap.model_mgr.remove_embedding_model(model_uuid)
|
||||
|
||||
embedding_model = await self.get_embedding_model(model_uuid)
|
||||
runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid'])
|
||||
if runtime_provider is None:
|
||||
raise Exception('provider not found')
|
||||
|
||||
await self.ap.model_mgr.load_embedding_model(embedding_model)
|
||||
runtime_embedding_model = await self.ap.model_mgr.load_embedding_model_with_provider(
|
||||
persistence_model.EmbeddingModel(**model_data),
|
||||
runtime_provider,
|
||||
)
|
||||
self.ap.model_mgr.embedding_models.append(runtime_embedding_model)
|
||||
|
||||
async def delete_embedding_model(self, model_uuid: str) -> None:
|
||||
"""Delete an embedding model"""
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_model.EmbeddingModel).where(
|
||||
persistence_model.EmbeddingModel.uuid == model_uuid
|
||||
)
|
||||
)
|
||||
|
||||
await self.ap.model_mgr.remove_embedding_model(model_uuid)
|
||||
|
||||
async def test_embedding_model(self, model_uuid: str, model_data: dict) -> None:
|
||||
"""Test an embedding model"""
|
||||
runtime_embedding_model: model_requester.RuntimeEmbeddingModel | None = None
|
||||
|
||||
if model_uuid != '_':
|
||||
@@ -192,14 +352,12 @@ class EmbeddingModelsService:
|
||||
if model.model_entity.uuid == model_uuid:
|
||||
runtime_embedding_model = model
|
||||
break
|
||||
|
||||
if runtime_embedding_model is None:
|
||||
raise Exception('model not found')
|
||||
|
||||
else:
|
||||
runtime_embedding_model = await self.ap.model_mgr.init_runtime_embedding_model(model_data)
|
||||
runtime_embedding_model = await self.ap.model_mgr.init_temporary_runtime_embedding_model(model_data)
|
||||
|
||||
await runtime_embedding_model.requester.invoke_embedding(
|
||||
await runtime_embedding_model.provider.invoke_embedding(
|
||||
model=runtime_embedding_model,
|
||||
input_text=['Hello, world!'],
|
||||
extra_args={},
|
||||
|
||||
796
src/langbot/pkg/api/http/service/monitoring.py
Normal file
796
src/langbot/pkg/api/http/service/monitoring.py
Normal file
@@ -0,0 +1,796 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
import datetime
|
||||
import sqlalchemy
|
||||
|
||||
from ....core import app
|
||||
from ....entity.persistence import monitoring as persistence_monitoring
|
||||
|
||||
|
||||
class MonitoringService:
|
||||
"""Monitoring service"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application) -> None:
|
||||
self.ap = ap
|
||||
|
||||
# ========== Recording Methods ==========
|
||||
|
||||
async def record_message(
|
||||
self,
|
||||
bot_id: str,
|
||||
bot_name: str,
|
||||
pipeline_id: str,
|
||||
pipeline_name: str,
|
||||
message_content: str,
|
||||
session_id: str,
|
||||
status: str = 'success',
|
||||
level: str = 'info',
|
||||
platform: str | None = None,
|
||||
user_id: str | None = None,
|
||||
runner_name: str | None = None,
|
||||
variables: str | None = None,
|
||||
) -> str:
|
||||
"""Record a message"""
|
||||
message_id = str(uuid.uuid4())
|
||||
message_data = {
|
||||
'id': message_id,
|
||||
'timestamp': datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
|
||||
'bot_id': bot_id,
|
||||
'bot_name': bot_name,
|
||||
'pipeline_id': pipeline_id,
|
||||
'pipeline_name': pipeline_name,
|
||||
'message_content': message_content,
|
||||
'session_id': session_id,
|
||||
'status': status,
|
||||
'level': level,
|
||||
'platform': platform,
|
||||
'user_id': user_id,
|
||||
'runner_name': runner_name,
|
||||
'variables': variables,
|
||||
}
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(persistence_monitoring.MonitoringMessage).values(message_data)
|
||||
)
|
||||
|
||||
return message_id
|
||||
|
||||
async def record_llm_call(
|
||||
self,
|
||||
bot_id: str,
|
||||
bot_name: str,
|
||||
pipeline_id: str,
|
||||
pipeline_name: str,
|
||||
session_id: str,
|
||||
model_name: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
duration: int,
|
||||
status: str = 'success',
|
||||
cost: float | None = None,
|
||||
error_message: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> str:
|
||||
"""Record an LLM call"""
|
||||
call_id = str(uuid.uuid4())
|
||||
call_data = {
|
||||
'id': call_id,
|
||||
'timestamp': datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
|
||||
'model_name': model_name,
|
||||
'input_tokens': input_tokens,
|
||||
'output_tokens': output_tokens,
|
||||
'total_tokens': input_tokens + output_tokens,
|
||||
'duration': duration,
|
||||
'cost': cost,
|
||||
'status': status,
|
||||
'bot_id': bot_id,
|
||||
'bot_name': bot_name,
|
||||
'pipeline_id': pipeline_id,
|
||||
'pipeline_name': pipeline_name,
|
||||
'session_id': session_id,
|
||||
'error_message': error_message,
|
||||
'message_id': message_id,
|
||||
}
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(persistence_monitoring.MonitoringLLMCall).values(call_data)
|
||||
)
|
||||
|
||||
return call_id
|
||||
|
||||
async def record_embedding_call(
|
||||
self,
|
||||
model_name: str,
|
||||
prompt_tokens: int,
|
||||
total_tokens: int,
|
||||
duration: int,
|
||||
input_count: int,
|
||||
status: str = 'success',
|
||||
error_message: str | None = None,
|
||||
knowledge_base_id: str | None = None,
|
||||
query_text: str | None = None,
|
||||
session_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
call_type: str | None = None,
|
||||
) -> str:
|
||||
"""Record an embedding call"""
|
||||
call_id = str(uuid.uuid4())
|
||||
call_data = {
|
||||
'id': call_id,
|
||||
'timestamp': datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
|
||||
'model_name': model_name,
|
||||
'prompt_tokens': prompt_tokens,
|
||||
'total_tokens': total_tokens,
|
||||
'duration': duration,
|
||||
'input_count': input_count,
|
||||
'status': status,
|
||||
'error_message': error_message,
|
||||
'knowledge_base_id': knowledge_base_id,
|
||||
'query_text': query_text,
|
||||
'session_id': session_id,
|
||||
'message_id': message_id,
|
||||
'call_type': call_type,
|
||||
}
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(persistence_monitoring.MonitoringEmbeddingCall).values(call_data)
|
||||
)
|
||||
|
||||
return call_id
|
||||
|
||||
async def record_session_start(
|
||||
self,
|
||||
session_id: str,
|
||||
bot_id: str,
|
||||
bot_name: str,
|
||||
pipeline_id: str,
|
||||
pipeline_name: str,
|
||||
platform: str | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> None:
|
||||
"""Record a new session"""
|
||||
session_data = {
|
||||
'session_id': session_id,
|
||||
'bot_id': bot_id,
|
||||
'bot_name': bot_name,
|
||||
'pipeline_id': pipeline_id,
|
||||
'pipeline_name': pipeline_name,
|
||||
'message_count': 0,
|
||||
'start_time': datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
|
||||
'last_activity': datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
|
||||
'is_active': True,
|
||||
'platform': platform,
|
||||
'user_id': user_id,
|
||||
}
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(persistence_monitoring.MonitoringSession).values(session_data)
|
||||
)
|
||||
|
||||
async def update_session_activity(
|
||||
self,
|
||||
session_id: str,
|
||||
pipeline_id: str | None = None,
|
||||
pipeline_name: str | None = None,
|
||||
) -> bool:
|
||||
"""Update session last activity time and increment message count.
|
||||
|
||||
Also updates pipeline info if the bot's pipeline has changed.
|
||||
|
||||
Returns:
|
||||
True if session was found and updated, False if session doesn't exist.
|
||||
"""
|
||||
update_values = {
|
||||
'last_activity': datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
|
||||
'message_count': persistence_monitoring.MonitoringSession.message_count + 1,
|
||||
}
|
||||
|
||||
# Update pipeline info if provided (handles pipeline switch)
|
||||
if pipeline_id is not None:
|
||||
update_values['pipeline_id'] = pipeline_id
|
||||
if pipeline_name is not None:
|
||||
update_values['pipeline_name'] = pipeline_name
|
||||
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_monitoring.MonitoringSession)
|
||||
.where(persistence_monitoring.MonitoringSession.session_id == session_id)
|
||||
.values(update_values)
|
||||
)
|
||||
# Check if any rows were updated
|
||||
return result.rowcount > 0
|
||||
|
||||
async def record_error(
|
||||
self,
|
||||
bot_id: str,
|
||||
bot_name: str,
|
||||
pipeline_id: str,
|
||||
pipeline_name: str,
|
||||
error_type: str,
|
||||
error_message: str,
|
||||
session_id: str | None = None,
|
||||
stack_trace: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> str:
|
||||
"""Record an error"""
|
||||
error_id = str(uuid.uuid4())
|
||||
error_data = {
|
||||
'id': error_id,
|
||||
'timestamp': datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
|
||||
'error_type': error_type,
|
||||
'error_message': error_message,
|
||||
'bot_id': bot_id,
|
||||
'bot_name': bot_name,
|
||||
'pipeline_id': pipeline_id,
|
||||
'pipeline_name': pipeline_name,
|
||||
'session_id': session_id,
|
||||
'stack_trace': stack_trace,
|
||||
'message_id': message_id,
|
||||
}
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(persistence_monitoring.MonitoringError).values(error_data)
|
||||
)
|
||||
|
||||
return error_id
|
||||
|
||||
async def update_message_status(
|
||||
self,
|
||||
message_id: str,
|
||||
status: str,
|
||||
level: str | None = None,
|
||||
variables: str | None = None,
|
||||
) -> None:
|
||||
"""Update message status and optionally variables"""
|
||||
update_values = {'status': status}
|
||||
if level is not None:
|
||||
update_values['level'] = level
|
||||
if variables is not None:
|
||||
update_values['variables'] = variables
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_monitoring.MonitoringMessage)
|
||||
.where(persistence_monitoring.MonitoringMessage.id == message_id)
|
||||
.values(update_values)
|
||||
)
|
||||
|
||||
# ========== Query Methods ==========
|
||||
|
||||
async def get_overview_metrics(
|
||||
self,
|
||||
bot_ids: list[str] | None = None,
|
||||
pipeline_ids: list[str] | None = None,
|
||||
start_time: datetime.datetime | None = None,
|
||||
end_time: datetime.datetime | None = None,
|
||||
) -> dict:
|
||||
"""Get overview metrics"""
|
||||
# Build base query conditions
|
||||
message_conditions = []
|
||||
llm_conditions = []
|
||||
embedding_conditions = []
|
||||
session_conditions = []
|
||||
|
||||
if bot_ids:
|
||||
message_conditions.append(persistence_monitoring.MonitoringMessage.bot_id.in_(bot_ids))
|
||||
llm_conditions.append(persistence_monitoring.MonitoringLLMCall.bot_id.in_(bot_ids))
|
||||
session_conditions.append(persistence_monitoring.MonitoringSession.bot_id.in_(bot_ids))
|
||||
|
||||
if pipeline_ids:
|
||||
message_conditions.append(persistence_monitoring.MonitoringMessage.pipeline_id.in_(pipeline_ids))
|
||||
llm_conditions.append(persistence_monitoring.MonitoringLLMCall.pipeline_id.in_(pipeline_ids))
|
||||
session_conditions.append(persistence_monitoring.MonitoringSession.pipeline_id.in_(pipeline_ids))
|
||||
|
||||
if start_time:
|
||||
message_conditions.append(persistence_monitoring.MonitoringMessage.timestamp >= start_time)
|
||||
llm_conditions.append(persistence_monitoring.MonitoringLLMCall.timestamp >= start_time)
|
||||
embedding_conditions.append(persistence_monitoring.MonitoringEmbeddingCall.timestamp >= start_time)
|
||||
session_conditions.append(persistence_monitoring.MonitoringSession.start_time >= start_time)
|
||||
|
||||
if end_time:
|
||||
message_conditions.append(persistence_monitoring.MonitoringMessage.timestamp <= end_time)
|
||||
llm_conditions.append(persistence_monitoring.MonitoringLLMCall.timestamp <= end_time)
|
||||
embedding_conditions.append(persistence_monitoring.MonitoringEmbeddingCall.timestamp <= end_time)
|
||||
session_conditions.append(persistence_monitoring.MonitoringSession.start_time <= end_time)
|
||||
|
||||
# Total messages
|
||||
message_query = sqlalchemy.select(sqlalchemy.func.count(persistence_monitoring.MonitoringMessage.id))
|
||||
if message_conditions:
|
||||
message_query = message_query.where(sqlalchemy.and_(*message_conditions))
|
||||
|
||||
total_messages_result = await self.ap.persistence_mgr.execute_async(message_query)
|
||||
total_messages = total_messages_result.scalar() or 0
|
||||
|
||||
# Total LLM calls
|
||||
llm_query = sqlalchemy.select(sqlalchemy.func.count(persistence_monitoring.MonitoringLLMCall.id))
|
||||
if llm_conditions:
|
||||
llm_query = llm_query.where(sqlalchemy.and_(*llm_conditions))
|
||||
|
||||
llm_calls_result = await self.ap.persistence_mgr.execute_async(llm_query)
|
||||
llm_calls = llm_calls_result.scalar() or 0
|
||||
|
||||
# Total Embedding calls
|
||||
embedding_query = sqlalchemy.select(sqlalchemy.func.count(persistence_monitoring.MonitoringEmbeddingCall.id))
|
||||
if embedding_conditions:
|
||||
embedding_query = embedding_query.where(sqlalchemy.and_(*embedding_conditions))
|
||||
|
||||
embedding_calls_result = await self.ap.persistence_mgr.execute_async(embedding_query)
|
||||
embedding_calls = embedding_calls_result.scalar() or 0
|
||||
|
||||
# Total model calls (LLM + Embedding)
|
||||
model_calls = llm_calls + embedding_calls
|
||||
|
||||
# Success rate (based on messages)
|
||||
success_query = sqlalchemy.select(sqlalchemy.func.count(persistence_monitoring.MonitoringMessage.id)).where(
|
||||
persistence_monitoring.MonitoringMessage.status == 'success'
|
||||
)
|
||||
if message_conditions:
|
||||
success_query = success_query.where(sqlalchemy.and_(*message_conditions))
|
||||
|
||||
success_result = await self.ap.persistence_mgr.execute_async(success_query)
|
||||
success_count = success_result.scalar() or 0
|
||||
success_rate = (success_count / total_messages * 100) if total_messages > 0 else 100
|
||||
|
||||
# Active sessions
|
||||
active_session_query = sqlalchemy.select(
|
||||
sqlalchemy.func.count(persistence_monitoring.MonitoringSession.session_id)
|
||||
).where(persistence_monitoring.MonitoringSession.is_active == True)
|
||||
if session_conditions:
|
||||
active_session_query = active_session_query.where(sqlalchemy.and_(*session_conditions))
|
||||
|
||||
active_sessions_result = await self.ap.persistence_mgr.execute_async(active_session_query)
|
||||
active_sessions = active_sessions_result.scalar() or 0
|
||||
|
||||
return {
|
||||
'total_messages': total_messages,
|
||||
'llm_calls': llm_calls,
|
||||
'embedding_calls': embedding_calls,
|
||||
'model_calls': model_calls,
|
||||
'success_rate': round(success_rate, 2),
|
||||
'active_sessions': active_sessions,
|
||||
}
|
||||
|
||||
async def get_messages(
|
||||
self,
|
||||
bot_ids: list[str] | None = None,
|
||||
pipeline_ids: list[str] | None = None,
|
||||
start_time: datetime.datetime | None = None,
|
||||
end_time: datetime.datetime | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[dict], int]:
|
||||
"""Get messages with filters"""
|
||||
conditions = []
|
||||
|
||||
if bot_ids:
|
||||
conditions.append(persistence_monitoring.MonitoringMessage.bot_id.in_(bot_ids))
|
||||
if pipeline_ids:
|
||||
conditions.append(persistence_monitoring.MonitoringMessage.pipeline_id.in_(pipeline_ids))
|
||||
if start_time:
|
||||
conditions.append(persistence_monitoring.MonitoringMessage.timestamp >= start_time)
|
||||
if end_time:
|
||||
conditions.append(persistence_monitoring.MonitoringMessage.timestamp <= end_time)
|
||||
|
||||
# Get total count
|
||||
count_query = sqlalchemy.select(sqlalchemy.func.count(persistence_monitoring.MonitoringMessage.id))
|
||||
if conditions:
|
||||
count_query = count_query.where(sqlalchemy.and_(*conditions))
|
||||
|
||||
count_result = await self.ap.persistence_mgr.execute_async(count_query)
|
||||
total = count_result.scalar() or 0
|
||||
|
||||
# Get messages
|
||||
query = sqlalchemy.select(persistence_monitoring.MonitoringMessage).order_by(
|
||||
persistence_monitoring.MonitoringMessage.timestamp.desc()
|
||||
)
|
||||
if conditions:
|
||||
query = query.where(sqlalchemy.and_(*conditions))
|
||||
|
||||
query = query.limit(limit).offset(offset)
|
||||
|
||||
result = await self.ap.persistence_mgr.execute_async(query)
|
||||
messages_rows = result.all()
|
||||
|
||||
serialized = []
|
||||
for row in messages_rows:
|
||||
# Extract model instance from Row (SQLAlchemy returns Row objects)
|
||||
msg = row[0] if isinstance(row, tuple) else row
|
||||
serialized_msg = self.ap.persistence_mgr.serialize_model(persistence_monitoring.MonitoringMessage, msg)
|
||||
serialized.append(serialized_msg)
|
||||
|
||||
return (serialized, total)
|
||||
|
||||
async def get_llm_calls(
|
||||
self,
|
||||
bot_ids: list[str] | None = None,
|
||||
pipeline_ids: list[str] | None = None,
|
||||
start_time: datetime.datetime | None = None,
|
||||
end_time: datetime.datetime | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[dict], int]:
|
||||
"""Get LLM calls with filters"""
|
||||
conditions = []
|
||||
|
||||
if bot_ids:
|
||||
conditions.append(persistence_monitoring.MonitoringLLMCall.bot_id.in_(bot_ids))
|
||||
if pipeline_ids:
|
||||
conditions.append(persistence_monitoring.MonitoringLLMCall.pipeline_id.in_(pipeline_ids))
|
||||
if start_time:
|
||||
conditions.append(persistence_monitoring.MonitoringLLMCall.timestamp >= start_time)
|
||||
if end_time:
|
||||
conditions.append(persistence_monitoring.MonitoringLLMCall.timestamp <= end_time)
|
||||
|
||||
# Get total count
|
||||
count_query = sqlalchemy.select(sqlalchemy.func.count(persistence_monitoring.MonitoringLLMCall.id))
|
||||
if conditions:
|
||||
count_query = count_query.where(sqlalchemy.and_(*conditions))
|
||||
|
||||
count_result = await self.ap.persistence_mgr.execute_async(count_query)
|
||||
total = count_result.scalar() or 0
|
||||
|
||||
# Get LLM calls
|
||||
query = sqlalchemy.select(persistence_monitoring.MonitoringLLMCall).order_by(
|
||||
persistence_monitoring.MonitoringLLMCall.timestamp.desc()
|
||||
)
|
||||
if conditions:
|
||||
query = query.where(sqlalchemy.and_(*conditions))
|
||||
|
||||
query = query.limit(limit).offset(offset)
|
||||
|
||||
result = await self.ap.persistence_mgr.execute_async(query)
|
||||
llm_calls_rows = result.all()
|
||||
|
||||
return (
|
||||
[
|
||||
self.ap.persistence_mgr.serialize_model(
|
||||
persistence_monitoring.MonitoringLLMCall, row[0] if isinstance(row, tuple) else row
|
||||
)
|
||||
for row in llm_calls_rows
|
||||
],
|
||||
total,
|
||||
)
|
||||
|
||||
async def get_embedding_calls(
|
||||
self,
|
||||
start_time: datetime.datetime | None = None,
|
||||
end_time: datetime.datetime | None = None,
|
||||
knowledge_base_id: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[dict], int]:
|
||||
"""Get embedding calls with filters"""
|
||||
conditions = []
|
||||
|
||||
if start_time:
|
||||
conditions.append(persistence_monitoring.MonitoringEmbeddingCall.timestamp >= start_time)
|
||||
if end_time:
|
||||
conditions.append(persistence_monitoring.MonitoringEmbeddingCall.timestamp <= end_time)
|
||||
if knowledge_base_id:
|
||||
conditions.append(persistence_monitoring.MonitoringEmbeddingCall.knowledge_base_id == knowledge_base_id)
|
||||
|
||||
# Get total count
|
||||
count_query = sqlalchemy.select(sqlalchemy.func.count(persistence_monitoring.MonitoringEmbeddingCall.id))
|
||||
if conditions:
|
||||
count_query = count_query.where(sqlalchemy.and_(*conditions))
|
||||
|
||||
count_result = await self.ap.persistence_mgr.execute_async(count_query)
|
||||
total = count_result.scalar() or 0
|
||||
|
||||
# Get embedding calls
|
||||
query = sqlalchemy.select(persistence_monitoring.MonitoringEmbeddingCall).order_by(
|
||||
persistence_monitoring.MonitoringEmbeddingCall.timestamp.desc()
|
||||
)
|
||||
if conditions:
|
||||
query = query.where(sqlalchemy.and_(*conditions))
|
||||
|
||||
query = query.limit(limit).offset(offset)
|
||||
|
||||
result = await self.ap.persistence_mgr.execute_async(query)
|
||||
embedding_calls_rows = result.all()
|
||||
|
||||
return (
|
||||
[
|
||||
self.ap.persistence_mgr.serialize_model(
|
||||
persistence_monitoring.MonitoringEmbeddingCall, row[0] if isinstance(row, tuple) else row
|
||||
)
|
||||
for row in embedding_calls_rows
|
||||
],
|
||||
total,
|
||||
)
|
||||
|
||||
async def get_sessions(
|
||||
self,
|
||||
bot_ids: list[str] | None = None,
|
||||
pipeline_ids: list[str] | None = None,
|
||||
start_time: datetime.datetime | None = None,
|
||||
end_time: datetime.datetime | None = None,
|
||||
is_active: bool | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[dict], int]:
|
||||
"""Get sessions with filters"""
|
||||
conditions = []
|
||||
|
||||
if bot_ids:
|
||||
conditions.append(persistence_monitoring.MonitoringSession.bot_id.in_(bot_ids))
|
||||
if pipeline_ids:
|
||||
conditions.append(persistence_monitoring.MonitoringSession.pipeline_id.in_(pipeline_ids))
|
||||
if start_time:
|
||||
conditions.append(persistence_monitoring.MonitoringSession.start_time >= start_time)
|
||||
if end_time:
|
||||
conditions.append(persistence_monitoring.MonitoringSession.start_time <= end_time)
|
||||
if is_active is not None:
|
||||
conditions.append(persistence_monitoring.MonitoringSession.is_active == is_active)
|
||||
|
||||
# Get total count
|
||||
count_query = sqlalchemy.select(sqlalchemy.func.count(persistence_monitoring.MonitoringSession.session_id))
|
||||
if conditions:
|
||||
count_query = count_query.where(sqlalchemy.and_(*conditions))
|
||||
|
||||
count_result = await self.ap.persistence_mgr.execute_async(count_query)
|
||||
total = count_result.scalar() or 0
|
||||
|
||||
# Get sessions
|
||||
query = sqlalchemy.select(persistence_monitoring.MonitoringSession).order_by(
|
||||
persistence_monitoring.MonitoringSession.last_activity.desc()
|
||||
)
|
||||
if conditions:
|
||||
query = query.where(sqlalchemy.and_(*conditions))
|
||||
|
||||
query = query.limit(limit).offset(offset)
|
||||
|
||||
result = await self.ap.persistence_mgr.execute_async(query)
|
||||
sessions_rows = result.all()
|
||||
|
||||
return (
|
||||
[
|
||||
self.ap.persistence_mgr.serialize_model(
|
||||
persistence_monitoring.MonitoringSession, row[0] if isinstance(row, tuple) else row
|
||||
)
|
||||
for row in sessions_rows
|
||||
],
|
||||
total,
|
||||
)
|
||||
|
||||
async def get_errors(
|
||||
self,
|
||||
bot_ids: list[str] | None = None,
|
||||
pipeline_ids: list[str] | None = None,
|
||||
start_time: datetime.datetime | None = None,
|
||||
end_time: datetime.datetime | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[dict], int]:
|
||||
"""Get errors with filters"""
|
||||
conditions = []
|
||||
|
||||
if bot_ids:
|
||||
conditions.append(persistence_monitoring.MonitoringError.bot_id.in_(bot_ids))
|
||||
if pipeline_ids:
|
||||
conditions.append(persistence_monitoring.MonitoringError.pipeline_id.in_(pipeline_ids))
|
||||
if start_time:
|
||||
conditions.append(persistence_monitoring.MonitoringError.timestamp >= start_time)
|
||||
if end_time:
|
||||
conditions.append(persistence_monitoring.MonitoringError.timestamp <= end_time)
|
||||
|
||||
# Get total count
|
||||
count_query = sqlalchemy.select(sqlalchemy.func.count(persistence_monitoring.MonitoringError.id))
|
||||
if conditions:
|
||||
count_query = count_query.where(sqlalchemy.and_(*conditions))
|
||||
|
||||
count_result = await self.ap.persistence_mgr.execute_async(count_query)
|
||||
total = count_result.scalar() or 0
|
||||
|
||||
# Get errors
|
||||
query = sqlalchemy.select(persistence_monitoring.MonitoringError).order_by(
|
||||
persistence_monitoring.MonitoringError.timestamp.desc()
|
||||
)
|
||||
if conditions:
|
||||
query = query.where(sqlalchemy.and_(*conditions))
|
||||
|
||||
query = query.limit(limit).offset(offset)
|
||||
|
||||
result = await self.ap.persistence_mgr.execute_async(query)
|
||||
errors_rows = result.all()
|
||||
|
||||
return (
|
||||
[
|
||||
self.ap.persistence_mgr.serialize_model(
|
||||
persistence_monitoring.MonitoringError, row[0] if isinstance(row, tuple) else row
|
||||
)
|
||||
for row in errors_rows
|
||||
],
|
||||
total,
|
||||
)
|
||||
|
||||
async def get_session_analysis(
|
||||
self,
|
||||
session_id: str,
|
||||
) -> dict:
|
||||
"""Get detailed analysis for a specific session"""
|
||||
# Get session info
|
||||
session_query = sqlalchemy.select(persistence_monitoring.MonitoringSession).where(
|
||||
persistence_monitoring.MonitoringSession.session_id == session_id
|
||||
)
|
||||
session_result = await self.ap.persistence_mgr.execute_async(session_query)
|
||||
session_row = session_result.first()
|
||||
|
||||
if not session_row:
|
||||
return {
|
||||
'session_id': session_id,
|
||||
'found': False,
|
||||
}
|
||||
|
||||
session = session_row[0] if isinstance(session_row, tuple) else session_row
|
||||
|
||||
# Get messages for this session
|
||||
messages_query = (
|
||||
sqlalchemy.select(persistence_monitoring.MonitoringMessage)
|
||||
.where(persistence_monitoring.MonitoringMessage.session_id == session_id)
|
||||
.order_by(persistence_monitoring.MonitoringMessage.timestamp.asc())
|
||||
)
|
||||
messages_result = await self.ap.persistence_mgr.execute_async(messages_query)
|
||||
messages_rows = messages_result.all()
|
||||
|
||||
# Count messages by status
|
||||
success_messages = 0
|
||||
error_messages = 0
|
||||
pending_messages = 0
|
||||
for row in messages_rows:
|
||||
msg = row[0] if isinstance(row, tuple) else row
|
||||
if msg.status == 'success':
|
||||
success_messages += 1
|
||||
elif msg.status == 'error':
|
||||
error_messages += 1
|
||||
elif msg.status == 'pending':
|
||||
pending_messages += 1
|
||||
|
||||
# Get LLM calls for this session
|
||||
llm_query = sqlalchemy.select(persistence_monitoring.MonitoringLLMCall).where(
|
||||
persistence_monitoring.MonitoringLLMCall.session_id == session_id
|
||||
)
|
||||
llm_result = await self.ap.persistence_mgr.execute_async(llm_query)
|
||||
llm_rows = llm_result.all()
|
||||
|
||||
# Calculate LLM statistics
|
||||
total_llm_calls = len(llm_rows)
|
||||
total_input_tokens = 0
|
||||
total_output_tokens = 0
|
||||
total_tokens = 0
|
||||
total_duration = 0
|
||||
success_llm_calls = 0
|
||||
error_llm_calls = 0
|
||||
|
||||
for row in llm_rows:
|
||||
llm_call = row[0] if isinstance(row, tuple) else row
|
||||
total_input_tokens += llm_call.input_tokens
|
||||
total_output_tokens += llm_call.output_tokens
|
||||
total_tokens += llm_call.total_tokens
|
||||
total_duration += llm_call.duration
|
||||
if llm_call.status == 'success':
|
||||
success_llm_calls += 1
|
||||
else:
|
||||
error_llm_calls += 1
|
||||
|
||||
# Get errors for this session
|
||||
error_query = (
|
||||
sqlalchemy.select(persistence_monitoring.MonitoringError)
|
||||
.where(persistence_monitoring.MonitoringError.session_id == session_id)
|
||||
.order_by(persistence_monitoring.MonitoringError.timestamp.desc())
|
||||
)
|
||||
error_result = await self.ap.persistence_mgr.execute_async(error_query)
|
||||
error_rows = error_result.all()
|
||||
|
||||
errors = [
|
||||
self.ap.persistence_mgr.serialize_model(
|
||||
persistence_monitoring.MonitoringError, row[0] if isinstance(row, tuple) else row
|
||||
)
|
||||
for row in error_rows
|
||||
]
|
||||
|
||||
# Calculate session duration
|
||||
if messages_rows:
|
||||
first_msg = messages_rows[0][0] if isinstance(messages_rows[0], tuple) else messages_rows[0]
|
||||
last_msg = messages_rows[-1][0] if isinstance(messages_rows[-1], tuple) else messages_rows[-1]
|
||||
session_duration_seconds = int((last_msg.timestamp - first_msg.timestamp).total_seconds())
|
||||
else:
|
||||
session_duration_seconds = 0
|
||||
|
||||
return {
|
||||
'session_id': session_id,
|
||||
'found': True,
|
||||
'session': self.ap.persistence_mgr.serialize_model(persistence_monitoring.MonitoringSession, session),
|
||||
'message_stats': {
|
||||
'total': len(messages_rows),
|
||||
'success': success_messages,
|
||||
'error': error_messages,
|
||||
'pending': pending_messages,
|
||||
},
|
||||
'llm_stats': {
|
||||
'total_calls': total_llm_calls,
|
||||
'success_calls': success_llm_calls,
|
||||
'error_calls': error_llm_calls,
|
||||
'total_input_tokens': total_input_tokens,
|
||||
'total_output_tokens': total_output_tokens,
|
||||
'total_tokens': total_tokens,
|
||||
'average_duration_ms': int(total_duration / total_llm_calls) if total_llm_calls > 0 else 0,
|
||||
},
|
||||
'errors': errors,
|
||||
'session_duration_seconds': session_duration_seconds,
|
||||
}
|
||||
|
||||
async def get_message_details(
|
||||
self,
|
||||
message_id: str,
|
||||
) -> dict:
|
||||
"""Get detailed information for a specific message including associated LLM calls and errors"""
|
||||
# Get message info
|
||||
message_query = sqlalchemy.select(persistence_monitoring.MonitoringMessage).where(
|
||||
persistence_monitoring.MonitoringMessage.id == message_id
|
||||
)
|
||||
message_result = await self.ap.persistence_mgr.execute_async(message_query)
|
||||
message_row = message_result.first()
|
||||
|
||||
if not message_row:
|
||||
return {
|
||||
'message_id': message_id,
|
||||
'found': False,
|
||||
}
|
||||
|
||||
message = message_row[0] if isinstance(message_row, tuple) else message_row
|
||||
|
||||
# Get LLM calls for this message
|
||||
llm_query = (
|
||||
sqlalchemy.select(persistence_monitoring.MonitoringLLMCall)
|
||||
.where(persistence_monitoring.MonitoringLLMCall.message_id == message_id)
|
||||
.order_by(persistence_monitoring.MonitoringLLMCall.timestamp.asc())
|
||||
)
|
||||
llm_result = await self.ap.persistence_mgr.execute_async(llm_query)
|
||||
llm_rows = llm_result.all()
|
||||
|
||||
llm_calls = [
|
||||
self.ap.persistence_mgr.serialize_model(
|
||||
persistence_monitoring.MonitoringLLMCall, row[0] if isinstance(row, tuple) else row
|
||||
)
|
||||
for row in llm_rows
|
||||
]
|
||||
|
||||
# Calculate LLM statistics
|
||||
total_input_tokens = sum(call.input_tokens for call in llm_rows)
|
||||
total_output_tokens = sum(call.output_tokens for call in llm_rows)
|
||||
total_tokens = sum(call.total_tokens for call in llm_rows)
|
||||
total_duration = sum(call.duration for call in llm_rows)
|
||||
|
||||
# Get errors for this message
|
||||
error_query = (
|
||||
sqlalchemy.select(persistence_monitoring.MonitoringError)
|
||||
.where(persistence_monitoring.MonitoringError.message_id == message_id)
|
||||
.order_by(persistence_monitoring.MonitoringError.timestamp.asc())
|
||||
)
|
||||
error_result = await self.ap.persistence_mgr.execute_async(error_query)
|
||||
error_rows = error_result.all()
|
||||
|
||||
errors = [
|
||||
self.ap.persistence_mgr.serialize_model(
|
||||
persistence_monitoring.MonitoringError, row[0] if isinstance(row, tuple) else row
|
||||
)
|
||||
for row in error_rows
|
||||
]
|
||||
|
||||
return {
|
||||
'message_id': message_id,
|
||||
'found': True,
|
||||
'message': self.ap.persistence_mgr.serialize_model(persistence_monitoring.MonitoringMessage, message),
|
||||
'llm_calls': llm_calls,
|
||||
'llm_stats': {
|
||||
'total_calls': len(llm_rows),
|
||||
'total_input_tokens': total_input_tokens,
|
||||
'total_output_tokens': total_output_tokens,
|
||||
'total_tokens': total_tokens,
|
||||
'total_duration_ms': total_duration,
|
||||
'average_duration_ms': int(total_duration / len(llm_rows)) if len(llm_rows) > 0 else 0,
|
||||
},
|
||||
'errors': errors,
|
||||
}
|
||||
166
src/langbot/pkg/api/http/service/provider.py
Normal file
166
src/langbot/pkg/api/http/service/provider.py
Normal file
@@ -0,0 +1,166 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
from ....core import app
|
||||
from ....entity.persistence import model as persistence_model
|
||||
|
||||
|
||||
class ModelProviderService:
|
||||
"""Service for managing model providers"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application) -> None:
|
||||
self.ap = ap
|
||||
|
||||
async def get_providers(self) -> list[dict]:
|
||||
"""Get all providers"""
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.ModelProvider))
|
||||
providers = result.all()
|
||||
providers_list = []
|
||||
for p in providers:
|
||||
provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, p)
|
||||
# Parse api_keys if it's a JSON string
|
||||
if isinstance(provider_dict.get('api_keys'), str):
|
||||
import json
|
||||
|
||||
try:
|
||||
provider_dict['api_keys'] = json.loads(provider_dict['api_keys'])
|
||||
except Exception:
|
||||
provider_dict['api_keys'] = []
|
||||
providers_list.append(provider_dict)
|
||||
return providers_list
|
||||
|
||||
async def get_provider(self, provider_uuid: str) -> dict | None:
|
||||
"""Get a single provider by UUID"""
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.ModelProvider).where(
|
||||
persistence_model.ModelProvider.uuid == provider_uuid
|
||||
)
|
||||
)
|
||||
provider = result.first()
|
||||
if provider is None:
|
||||
return None
|
||||
provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
|
||||
# Parse api_keys if it's a JSON string
|
||||
if isinstance(provider_dict.get('api_keys'), str):
|
||||
import json
|
||||
|
||||
try:
|
||||
provider_dict['api_keys'] = json.loads(provider_dict['api_keys'])
|
||||
except Exception:
|
||||
provider_dict['api_keys'] = []
|
||||
return provider_dict
|
||||
|
||||
async def create_provider(self, provider_data: dict) -> str:
|
||||
"""Create a new provider"""
|
||||
provider_data['uuid'] = str(uuid.uuid4())
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(persistence_model.ModelProvider).values(**provider_data)
|
||||
)
|
||||
|
||||
# load to runtime
|
||||
runtime_provider = await self.ap.model_mgr.load_provider(provider_data)
|
||||
self.ap.model_mgr.provider_dict[runtime_provider.provider_entity.uuid] = runtime_provider
|
||||
return provider_data['uuid']
|
||||
|
||||
async def update_provider(self, provider_uuid: str, provider_data: dict) -> None:
|
||||
"""Update an existing provider"""
|
||||
if 'uuid' in provider_data:
|
||||
del provider_data['uuid']
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_model.ModelProvider)
|
||||
.where(persistence_model.ModelProvider.uuid == provider_uuid)
|
||||
.values(**provider_data)
|
||||
)
|
||||
await self.ap.model_mgr.reload_provider(provider_uuid)
|
||||
|
||||
async def delete_provider(self, provider_uuid: str) -> None:
|
||||
"""Delete a provider (only if no models reference it)"""
|
||||
# Check if any models use this provider
|
||||
llm_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.LLMModel).where(
|
||||
persistence_model.LLMModel.provider_uuid == provider_uuid
|
||||
)
|
||||
)
|
||||
if llm_result.first() is not None:
|
||||
raise ValueError('Cannot delete provider: LLM models still reference it')
|
||||
|
||||
embedding_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.EmbeddingModel).where(
|
||||
persistence_model.EmbeddingModel.provider_uuid == provider_uuid
|
||||
)
|
||||
)
|
||||
if embedding_result.first() is not None:
|
||||
raise ValueError('Cannot delete provider: Embedding models still reference it')
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_model.ModelProvider).where(
|
||||
persistence_model.ModelProvider.uuid == provider_uuid
|
||||
)
|
||||
)
|
||||
|
||||
await self.ap.model_mgr.remove_provider(provider_uuid)
|
||||
|
||||
async def get_provider_model_counts(self, provider_uuid: str) -> dict:
|
||||
"""Get count of models using this provider"""
|
||||
llm_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(sqlalchemy.func.count())
|
||||
.select_from(persistence_model.LLMModel)
|
||||
.where(persistence_model.LLMModel.provider_uuid == provider_uuid)
|
||||
)
|
||||
llm_count = llm_result.scalar() or 0
|
||||
|
||||
embedding_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(sqlalchemy.func.count())
|
||||
.select_from(persistence_model.EmbeddingModel)
|
||||
.where(persistence_model.EmbeddingModel.provider_uuid == provider_uuid)
|
||||
)
|
||||
embedding_count = embedding_result.scalar() or 0
|
||||
|
||||
return {'llm_count': llm_count, 'embedding_count': embedding_count}
|
||||
|
||||
async def find_or_create_provider(self, requester: str, base_url: str, api_keys: list) -> str:
|
||||
"""Find existing provider or create new one"""
|
||||
# Try to find existing provider with same config
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.ModelProvider).where(
|
||||
persistence_model.ModelProvider.requester == requester,
|
||||
persistence_model.ModelProvider.base_url == base_url,
|
||||
)
|
||||
)
|
||||
for provider in result.all():
|
||||
if sorted(provider.api_keys or []) == sorted(api_keys or []):
|
||||
return provider.uuid
|
||||
|
||||
# Create new provider
|
||||
provider_name = requester
|
||||
if base_url:
|
||||
try:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(base_url)
|
||||
provider_name = parsed.netloc or requester
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return await self.create_provider(
|
||||
{
|
||||
'name': provider_name,
|
||||
'requester': requester,
|
||||
'base_url': base_url,
|
||||
'api_keys': api_keys or [],
|
||||
}
|
||||
)
|
||||
|
||||
async def update_space_model_provider_api_keys(self, api_key: str) -> None:
|
||||
"""Update Space model provider API keys"""
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_model.ModelProvider)
|
||||
.where(persistence_model.ModelProvider.uuid == '00000000-0000-0000-0000-000000000000')
|
||||
.values(api_keys=[api_key])
|
||||
)
|
||||
await self.ap.model_mgr.reload_provider('00000000-0000-0000-0000-000000000000')
|
||||
189
src/langbot/pkg/api/http/service/space.py
Normal file
189
src/langbot/pkg/api/http/service/space.py
Normal file
@@ -0,0 +1,189 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import aiohttp
|
||||
import typing
|
||||
import datetime
|
||||
import time
|
||||
import sqlalchemy
|
||||
|
||||
from ....core import app
|
||||
from ....entity.persistence import user
|
||||
from ....entity.dto.space_model import SpaceModel
|
||||
|
||||
|
||||
class SpaceService:
|
||||
"""Service for interacting with LangBot Space API"""
|
||||
|
||||
ap: app.Application
|
||||
_credits_cache: typing.Dict[str, typing.Tuple[int, float]] # {user_email: (credits, timestamp)}
|
||||
|
||||
def __init__(self, ap: app.Application) -> None:
|
||||
self.ap = ap
|
||||
self._credits_cache = {}
|
||||
|
||||
def _get_space_config(self) -> typing.Dict[str, str]:
|
||||
"""Get Space configuration from config file"""
|
||||
space_config = self.ap.instance_config.data.get('space', {})
|
||||
return {
|
||||
'url': space_config.get('url', 'https://space.langbot.app'),
|
||||
'oauth_authorize_url': space_config.get('oauth_authorize_url', 'https://space.langbot.app/auth/authorize'),
|
||||
}
|
||||
|
||||
async def _get_user_by_email(self, user_email: str) -> user.User | None:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(user.User).where(user.User.user == user_email)
|
||||
)
|
||||
result_list = result.all()
|
||||
return result_list[0] if result_list else None
|
||||
|
||||
async def _ensure_valid_token(self, user_email: str) -> str | None:
|
||||
"""Ensure access token is valid, refresh if expired. Returns valid access_token or None."""
|
||||
user_obj = await self._get_user_by_email(user_email)
|
||||
if not user_obj or user_obj.account_type != 'space':
|
||||
return None
|
||||
|
||||
if not user_obj.space_access_token:
|
||||
return None
|
||||
|
||||
# Check if token is expired (with 60s buffer)
|
||||
if user_obj.space_access_token_expires_at:
|
||||
if datetime.datetime.now() >= user_obj.space_access_token_expires_at - datetime.timedelta(seconds=60):
|
||||
# Token expired, try to refresh
|
||||
if user_obj.space_refresh_token:
|
||||
try:
|
||||
new_token = await self._refresh_and_save_token(user_obj)
|
||||
return new_token
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
return user_obj.space_access_token
|
||||
|
||||
async def _refresh_and_save_token(self, user_obj: user.User) -> str:
|
||||
"""Refresh token and save to database"""
|
||||
token_data = await self.refresh_token(user_obj.space_refresh_token)
|
||||
access_token = token_data.get('access_token')
|
||||
expires_in = token_data.get('expires_in', 0)
|
||||
|
||||
if not access_token:
|
||||
raise ValueError('Failed to refresh token')
|
||||
|
||||
expires_at = datetime.datetime.now() + datetime.timedelta(seconds=expires_in) if expires_in > 0 else None
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(user.User)
|
||||
.where(user.User.user == user_obj.user)
|
||||
.values(
|
||||
space_access_token=access_token,
|
||||
space_access_token_expires_at=expires_at,
|
||||
)
|
||||
)
|
||||
|
||||
return access_token
|
||||
|
||||
# === Raw API calls (no token validation) ===
|
||||
|
||||
def get_oauth_authorize_url(self, redirect_uri: str, state: str = '') -> str:
|
||||
"""Get the Space OAuth authorization URL for redirect"""
|
||||
space_config = self._get_space_config()
|
||||
authorize_url = space_config['oauth_authorize_url']
|
||||
params = f'redirect_uri={redirect_uri}'
|
||||
if state:
|
||||
params += f'&state={state}'
|
||||
return f'{authorize_url}?{params}'
|
||||
|
||||
async def exchange_oauth_code(self, code: str) -> typing.Dict:
|
||||
"""Exchange OAuth authorization code for tokens"""
|
||||
from langbot.pkg.utils import constants
|
||||
|
||||
space_config = self._get_space_config()
|
||||
space_url = space_config['url']
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f'{space_url}/api/v1/accounts/oauth/token',
|
||||
json={'code': code, 'instance_id': constants.instance_id},
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise ValueError(f'Failed to exchange OAuth code: {await response.text()}')
|
||||
data = await response.json()
|
||||
if data.get('code') != 0:
|
||||
raise ValueError(f'Failed to exchange OAuth code: {data.get("msg")}')
|
||||
return data.get('data', {})
|
||||
|
||||
async def refresh_token(self, refresh_token: str) -> typing.Dict:
|
||||
"""Refresh Space access token"""
|
||||
space_config = self._get_space_config()
|
||||
space_url = space_config['url']
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f'{space_url}/api/v1/accounts/token/refresh', json={'refresh_token': refresh_token}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise ValueError(f'Failed to refresh token: {await response.text()}')
|
||||
data = await response.json()
|
||||
if data.get('code') != 0:
|
||||
raise ValueError(f'Failed to refresh token: {data.get("msg")}')
|
||||
return data.get('data', {})
|
||||
|
||||
async def get_user_info_raw(self, access_token: str) -> typing.Dict:
|
||||
"""Get user info from Space using access token (no validation)"""
|
||||
space_config = self._get_space_config()
|
||||
space_url = space_config['url']
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f'{space_url}/api/v1/accounts/me', headers={'Authorization': f'Bearer {access_token}'}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise ValueError(f'Failed to get user info: {await response.text()}')
|
||||
data = await response.json()
|
||||
if data.get('code') != 0:
|
||||
raise ValueError(f'Failed to get user info: {data.get("msg")}')
|
||||
return data.get('data', {})
|
||||
|
||||
# === API calls with token validation ===
|
||||
|
||||
async def get_user_info(self, user_email: str) -> typing.Dict | None:
|
||||
"""Get user info from Space (with token validation)"""
|
||||
access_token = await self._ensure_valid_token(user_email)
|
||||
if not access_token:
|
||||
return None
|
||||
return await self.get_user_info_raw(access_token)
|
||||
|
||||
async def get_credits(self, user_email: str, force_refresh: bool = False) -> int | None:
|
||||
"""Get Space credits for user with caching (60s TTL)"""
|
||||
cache_ttl = 60
|
||||
|
||||
if not force_refresh and user_email in self._credits_cache:
|
||||
credits, ts = self._credits_cache[user_email]
|
||||
if time.time() - ts < cache_ttl:
|
||||
return credits
|
||||
|
||||
try:
|
||||
info = await self.get_user_info(user_email)
|
||||
if info is None:
|
||||
return None
|
||||
credits = info.get('credits')
|
||||
if credits is not None:
|
||||
self._credits_cache[user_email] = (credits, time.time())
|
||||
return credits
|
||||
except Exception:
|
||||
return self._credits_cache.get(user_email, (None, 0))[0]
|
||||
|
||||
async def get_models(self) -> typing.List[SpaceModel]:
|
||||
"""Get models from Space"""
|
||||
|
||||
space_config = self._get_space_config()
|
||||
space_url = space_config['url']
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(f'{space_url}/api/v1/models') as response:
|
||||
if response.status != 200:
|
||||
raise ValueError(f'Failed to get models: {await response.text()}')
|
||||
data = await response.json()
|
||||
if data.get('code') != 0:
|
||||
raise ValueError(f'Failed to get models: {data.get("msg")}')
|
||||
models_data = data.get('data', {}).get('models', [])
|
||||
return [SpaceModel.model_validate(model_dict) for model_dict in models_data]
|
||||
@@ -4,17 +4,22 @@ import sqlalchemy
|
||||
import argon2
|
||||
import jwt
|
||||
import datetime
|
||||
import typing
|
||||
import asyncio
|
||||
|
||||
from ....core import app
|
||||
from ....entity.persistence import user
|
||||
from ....utils import constants
|
||||
from ....entity.errors import account as account_errors
|
||||
|
||||
|
||||
class UserService:
|
||||
ap: app.Application
|
||||
_create_user_lock: asyncio.Lock
|
||||
|
||||
def __init__(self, ap: app.Application) -> None:
|
||||
self.ap = ap
|
||||
self._create_user_lock = asyncio.Lock()
|
||||
|
||||
async def is_initialized(self) -> bool:
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(user.User).limit(1))
|
||||
@@ -28,7 +33,7 @@ class UserService:
|
||||
hashed_password = ph.hash(password)
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(user.User).values(user=user_email, password=hashed_password)
|
||||
sqlalchemy.insert(user.User).values(user=user_email, password=hashed_password, account_type='local')
|
||||
)
|
||||
|
||||
async def get_user_by_email(self, user_email: str) -> user.User | None:
|
||||
@@ -39,6 +44,15 @@ class UserService:
|
||||
result_list = result.all()
|
||||
return result_list[0] if result_list is not None and len(result_list) > 0 else None
|
||||
|
||||
async def get_user_by_space_account_uuid(self, space_account_uuid: str) -> user.User | None:
|
||||
"""Get user by Space account UUID"""
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(user.User).where(user.User.space_account_uuid == space_account_uuid)
|
||||
)
|
||||
|
||||
result_list = result.all()
|
||||
return result_list[0] if result_list is not None and len(result_list) > 0 else None
|
||||
|
||||
async def authenticate(self, user_email: str, password: str) -> str | None:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(user.User).where(user.User.user == user_email)
|
||||
@@ -51,6 +65,10 @@ class UserService:
|
||||
|
||||
user_obj = result_list[0]
|
||||
|
||||
# Check if this is a Space account
|
||||
if user_obj.account_type == 'space':
|
||||
raise ValueError('请使用 Space 账户登录')
|
||||
|
||||
ph = argon2.PasswordHasher()
|
||||
|
||||
ph.verify(user_obj.password, password)
|
||||
@@ -90,6 +108,10 @@ class UserService:
|
||||
if user_obj is None:
|
||||
raise ValueError('User not found')
|
||||
|
||||
# Space accounts cannot change password locally
|
||||
if user_obj.account_type == 'space':
|
||||
raise ValueError('Space account cannot change password locally')
|
||||
|
||||
ph.verify(user_obj.password, current_password)
|
||||
|
||||
hashed_password = ph.hash(new_password)
|
||||
@@ -97,3 +119,183 @@ class UserService:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(user.User).where(user.User.user == user_email).values(password=hashed_password)
|
||||
)
|
||||
|
||||
# Space user management
|
||||
|
||||
async def create_or_update_space_user(
|
||||
self,
|
||||
space_account_uuid: str,
|
||||
email: str,
|
||||
access_token: str,
|
||||
refresh_token: str,
|
||||
api_key: str,
|
||||
expires_in: int = 0,
|
||||
) -> user.User:
|
||||
"""Create or update a Space user account (only if system not initialized or user exists)"""
|
||||
expires_at = datetime.datetime.now() + datetime.timedelta(seconds=expires_in) if expires_in > 0 else None
|
||||
|
||||
async with self._create_user_lock:
|
||||
# Check if user with this Space UUID already exists
|
||||
existing_user = await self.get_user_by_space_account_uuid(space_account_uuid)
|
||||
|
||||
if existing_user:
|
||||
# Update existing user's tokens
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(user.User)
|
||||
.where(user.User.space_account_uuid == space_account_uuid)
|
||||
.values(
|
||||
space_access_token=access_token,
|
||||
space_refresh_token=refresh_token,
|
||||
space_api_key=api_key,
|
||||
space_access_token_expires_at=expires_at,
|
||||
)
|
||||
)
|
||||
await self.ap.provider_service.update_space_model_provider_api_keys(api_key)
|
||||
return await self.get_user_by_space_account_uuid(space_account_uuid)
|
||||
|
||||
# Check if user with same email exists
|
||||
existing_email_user = await self.get_user_by_email(email)
|
||||
if existing_email_user:
|
||||
# Update existing user to link with Space account
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(user.User)
|
||||
.where(user.User.user == email)
|
||||
.values(
|
||||
account_type='space',
|
||||
space_account_uuid=space_account_uuid,
|
||||
space_access_token=access_token,
|
||||
space_refresh_token=refresh_token,
|
||||
space_api_key=api_key,
|
||||
space_access_token_expires_at=expires_at,
|
||||
)
|
||||
)
|
||||
await self.ap.provider_service.update_space_model_provider_api_keys(api_key)
|
||||
return await self.get_user_by_email(email)
|
||||
|
||||
# Check if system is already initialized
|
||||
is_initialized = await self.is_initialized()
|
||||
if is_initialized:
|
||||
raise account_errors.AccountEmailMismatchError()
|
||||
|
||||
# Create new Space user (first time initialization)
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(user.User).values(
|
||||
user=email,
|
||||
password='', # Space users don't have local password
|
||||
account_type='space',
|
||||
space_account_uuid=space_account_uuid,
|
||||
space_access_token=access_token,
|
||||
space_refresh_token=refresh_token,
|
||||
space_api_key=api_key,
|
||||
space_access_token_expires_at=expires_at,
|
||||
)
|
||||
)
|
||||
await self.ap.provider_service.update_space_model_provider_api_keys(api_key)
|
||||
|
||||
return await self.get_user_by_space_account_uuid(space_account_uuid)
|
||||
|
||||
async def authenticate_space_user(
|
||||
self, access_token: str, refresh_token: str, expires_in: int = 0
|
||||
) -> typing.Tuple[str, user.User]:
|
||||
"""Authenticate with Space and return JWT token"""
|
||||
# Get user info from Space using raw API (token just obtained, no need to validate)
|
||||
user_info = await self.ap.space_service.get_user_info_raw(access_token)
|
||||
|
||||
account = user_info.get('account', {})
|
||||
api_key = user_info.get('api_key', '')
|
||||
|
||||
space_account_uuid = account.get('uuid')
|
||||
email = account.get('email')
|
||||
|
||||
if not space_account_uuid or not email:
|
||||
raise ValueError('Invalid Space user info')
|
||||
|
||||
# Create or update Space user in local database
|
||||
user_obj = await self.create_or_update_space_user(
|
||||
space_account_uuid=space_account_uuid,
|
||||
email=email,
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
api_key=api_key,
|
||||
expires_in=expires_in,
|
||||
)
|
||||
|
||||
# Generate JWT token
|
||||
jwt_token = await self.generate_jwt_token(email)
|
||||
|
||||
return jwt_token, user_obj
|
||||
|
||||
async def get_first_user(self) -> user.User | None:
|
||||
"""Get the first user (for single-user mode)"""
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(user.User).limit(1))
|
||||
result_list = result.all()
|
||||
return result_list[0] if result_list else None
|
||||
|
||||
async def set_password(self, user_email: str, new_password: str, current_password: str | None = None) -> None:
|
||||
"""Set or change password for a user"""
|
||||
ph = argon2.PasswordHasher()
|
||||
user_obj = await self.get_user_by_email(user_email)
|
||||
|
||||
if user_obj is None:
|
||||
raise ValueError('User not found')
|
||||
|
||||
# If user already has a password, verify current password
|
||||
has_password = bool(user_obj.password and user_obj.password.strip())
|
||||
if has_password:
|
||||
if not current_password:
|
||||
raise ValueError('Current password is required')
|
||||
ph.verify(user_obj.password, current_password)
|
||||
|
||||
hashed_password = ph.hash(new_password)
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(user.User).where(user.User.user == user_email).values(password=hashed_password)
|
||||
)
|
||||
|
||||
async def bind_space_account(self, user_email: str, code: str) -> user.User:
|
||||
"""Bind Space account to existing local account"""
|
||||
# Exchange code for tokens
|
||||
token_data = await self.ap.space_service.exchange_oauth_code(code)
|
||||
access_token = token_data.get('access_token')
|
||||
refresh_token = token_data.get('refresh_token')
|
||||
expires_in = token_data.get('expires_in', 0)
|
||||
|
||||
if not access_token:
|
||||
raise ValueError('Failed to get access token from Space')
|
||||
|
||||
expires_at = datetime.datetime.now() + datetime.timedelta(seconds=expires_in) if expires_in > 0 else None
|
||||
|
||||
# Get Space user info (token just obtained, use raw API)
|
||||
user_info = await self.ap.space_service.get_user_info_raw(access_token)
|
||||
account = user_info.get('account', {})
|
||||
api_key = user_info.get('api_key', '')
|
||||
|
||||
space_account_uuid = account.get('uuid')
|
||||
space_email = account.get('email')
|
||||
|
||||
if not space_account_uuid or not space_email:
|
||||
raise ValueError('Invalid Space user info')
|
||||
|
||||
# Check if this Space account is already bound to another user
|
||||
existing_space_user = await self.get_user_by_space_account_uuid(space_account_uuid)
|
||||
if existing_space_user and existing_space_user.user != user_email:
|
||||
raise ValueError('This Space account is already bound to another user')
|
||||
|
||||
# Update local account to Space account
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(user.User)
|
||||
.where(user.User.user == user_email)
|
||||
.values(
|
||||
user=space_email, # Update email to Space email
|
||||
account_type='space',
|
||||
space_account_uuid=space_account_uuid,
|
||||
space_access_token=access_token,
|
||||
space_refresh_token=refresh_token,
|
||||
space_api_key=api_key,
|
||||
space_access_token_expires_at=expires_at,
|
||||
)
|
||||
)
|
||||
|
||||
# Update Space model provider API keys
|
||||
await self.ap.provider_service.update_space_model_provider_api_keys(api_key)
|
||||
|
||||
return await self.get_user_by_email(space_email)
|
||||
|
||||
@@ -19,7 +19,9 @@ from ..utils import version as version_mgr, proxy as proxy_mgr
|
||||
from ..persistence import mgr as persistencemgr
|
||||
from ..api.http.controller import main as http_controller
|
||||
from ..api.http.service import user as user_service
|
||||
from ..api.http.service import space as space_service
|
||||
from ..api.http.service import model as model_service
|
||||
from ..api.http.service import provider as provider_service
|
||||
from ..api.http.service import pipeline as pipeline_service
|
||||
from ..api.http.service import bot as bot_service
|
||||
from ..api.http.service import knowledge as knowledge_service
|
||||
@@ -27,6 +29,7 @@ from ..api.http.service import mcp as mcp_service
|
||||
from ..api.http.service import apikey as apikey_service
|
||||
from ..api.http.service import webhook as webhook_service
|
||||
from ..api.http.service import external_kb as external_kb_service
|
||||
from ..api.http.service import monitoring as monitoring_service
|
||||
from ..discover import engine as discover_engine
|
||||
from ..storage import mgr as storagemgr
|
||||
from ..utils import logcache
|
||||
@@ -34,6 +37,7 @@ from . import taskmgr
|
||||
from . import entities as core_entities
|
||||
from ..rag.knowledge import kbmgr as rag_mgr
|
||||
from ..vector import mgr as vectordb_mgr
|
||||
from ..telemetry import telemetry as telemetry_module
|
||||
|
||||
|
||||
class Application:
|
||||
@@ -75,6 +79,8 @@ class Application:
|
||||
|
||||
instance_config: config_mgr.ConfigManager = None
|
||||
|
||||
instance_id: config_mgr.ConfigManager = None # used to identify the instance
|
||||
|
||||
# ======= Metadata config manager =======
|
||||
|
||||
sensitive_meta: config_mgr.ConfigManager = None
|
||||
@@ -114,10 +120,14 @@ class Application:
|
||||
|
||||
user_service: user_service.UserService = None
|
||||
|
||||
space_service: space_service.SpaceService = None
|
||||
|
||||
llm_model_service: model_service.LLMModelsService = None
|
||||
|
||||
embedding_models_service: model_service.EmbeddingModelsService = None
|
||||
|
||||
provider_service: provider_service.ModelProviderService = None
|
||||
|
||||
pipeline_service: pipeline_service.PipelineService = None
|
||||
|
||||
bot_service: bot_service.BotService = None
|
||||
@@ -132,6 +142,10 @@ class Application:
|
||||
|
||||
webhook_service: webhook_service.WebhookService = None
|
||||
|
||||
telemetry: telemetry_module.TelemetryManager = None
|
||||
|
||||
monitoring_service: monitoring_service.MonitoringService = None
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class('dingtalk_card_auto_layout', 41)
|
||||
class DingTalkCardAutoLayoutMigration(migration.Migration):
|
||||
"""迁移"""
|
||||
|
||||
async def need_migrate(self) -> bool:
|
||||
"""判断当前环境是否需要运行此迁移"""
|
||||
return True
|
||||
|
||||
async def run(self):
|
||||
"""执行迁移"""
|
||||
self.ap.platform_cfg.data['platform-adapters']['app']['dingtalk']['card_auto_layout'] = False
|
||||
await self.ap.platform_cfg.dump_config()
|
||||
@@ -16,7 +16,9 @@ from ...platform.webhook_pusher import WebhookPusher
|
||||
from ...persistence import mgr as persistencemgr
|
||||
from ...api.http.controller import main as http_controller
|
||||
from ...api.http.service import user as user_service
|
||||
from ...api.http.service import space as space_service
|
||||
from ...api.http.service import model as model_service
|
||||
from ...api.http.service import provider as provider_service
|
||||
from ...api.http.service import pipeline as pipeline_service
|
||||
from ...api.http.service import bot as bot_service
|
||||
from ...api.http.service import knowledge as knowledge_service
|
||||
@@ -24,11 +26,13 @@ from ...api.http.service import mcp as mcp_service
|
||||
from ...api.http.service import apikey as apikey_service
|
||||
from ...api.http.service import webhook as webhook_service
|
||||
from ...api.http.service import external_kb as external_kb_service
|
||||
from ...api.http.service import monitoring as monitoring_service
|
||||
from ...discover import engine as discover_engine
|
||||
from ...storage import mgr as storagemgr
|
||||
from ...utils import logcache
|
||||
from ...vector import mgr as vectordb_mgr
|
||||
from .. import taskmgr
|
||||
from ...telemetry import telemetry as telemetry_module
|
||||
|
||||
|
||||
@stage.stage_class('BuildAppStage')
|
||||
@@ -43,6 +47,42 @@ class BuildAppStage(stage.BootingStage):
|
||||
discover.discover_blueprint('templates/components.yaml')
|
||||
ap.discover = discover
|
||||
|
||||
user_service_inst = user_service.UserService(ap)
|
||||
ap.user_service = user_service_inst
|
||||
|
||||
space_service_inst = space_service.SpaceService(ap)
|
||||
ap.space_service = space_service_inst
|
||||
|
||||
llm_model_service_inst = model_service.LLMModelsService(ap)
|
||||
ap.llm_model_service = llm_model_service_inst
|
||||
|
||||
embedding_models_service_inst = model_service.EmbeddingModelsService(ap)
|
||||
ap.embedding_models_service = embedding_models_service_inst
|
||||
|
||||
provider_service_inst = provider_service.ModelProviderService(ap)
|
||||
ap.provider_service = provider_service_inst
|
||||
|
||||
pipeline_service_inst = pipeline_service.PipelineService(ap)
|
||||
ap.pipeline_service = pipeline_service_inst
|
||||
|
||||
bot_service_inst = bot_service.BotService(ap)
|
||||
ap.bot_service = bot_service_inst
|
||||
|
||||
knowledge_service_inst = knowledge_service.KnowledgeService(ap)
|
||||
ap.knowledge_service = knowledge_service_inst
|
||||
|
||||
external_kb_service_inst = external_kb_service.ExternalKBService(ap)
|
||||
ap.external_kb_service = external_kb_service_inst
|
||||
|
||||
mcp_service_inst = mcp_service.MCPService(ap)
|
||||
ap.mcp_service = mcp_service_inst
|
||||
|
||||
apikey_service_inst = apikey_service.ApiKeyService(ap)
|
||||
ap.apikey_service = apikey_service_inst
|
||||
|
||||
webhook_service_inst = webhook_service.WebhookService(ap)
|
||||
ap.webhook_service = webhook_service_inst
|
||||
|
||||
proxy_mgr = proxy.ProxyManager(ap)
|
||||
await proxy_mgr.initialize()
|
||||
ap.proxy_mgr = proxy_mgr
|
||||
@@ -64,13 +104,18 @@ class BuildAppStage(stage.BootingStage):
|
||||
ap.persistence_mgr = persistence_mgr_inst
|
||||
await persistence_mgr_inst.initialize()
|
||||
|
||||
# Telemetry manager: attach to app so other components can call via self.ap.telemetry
|
||||
telemetry_inst = telemetry_module.TelemetryManager(ap)
|
||||
await telemetry_inst.initialize()
|
||||
ap.telemetry = telemetry_inst
|
||||
|
||||
cmd_mgr_inst = cmdmgr.CommandManager(ap)
|
||||
await cmd_mgr_inst.initialize()
|
||||
ap.cmd_mgr = cmd_mgr_inst
|
||||
|
||||
llm_model_mgr_inst = llm_model_mgr.ModelManager(ap)
|
||||
await llm_model_mgr_inst.initialize()
|
||||
ap.model_mgr = llm_model_mgr_inst
|
||||
await llm_model_mgr_inst.initialize()
|
||||
|
||||
llm_session_mgr_inst = llm_session_mgr.SessionManager(ap)
|
||||
await llm_session_mgr_inst.initialize()
|
||||
@@ -105,35 +150,8 @@ class BuildAppStage(stage.BootingStage):
|
||||
await http_ctrl.initialize()
|
||||
ap.http_ctrl = http_ctrl
|
||||
|
||||
user_service_inst = user_service.UserService(ap)
|
||||
ap.user_service = user_service_inst
|
||||
|
||||
llm_model_service_inst = model_service.LLMModelsService(ap)
|
||||
ap.llm_model_service = llm_model_service_inst
|
||||
|
||||
embedding_models_service_inst = model_service.EmbeddingModelsService(ap)
|
||||
ap.embedding_models_service = embedding_models_service_inst
|
||||
|
||||
pipeline_service_inst = pipeline_service.PipelineService(ap)
|
||||
ap.pipeline_service = pipeline_service_inst
|
||||
|
||||
bot_service_inst = bot_service.BotService(ap)
|
||||
ap.bot_service = bot_service_inst
|
||||
|
||||
knowledge_service_inst = knowledge_service.KnowledgeService(ap)
|
||||
ap.knowledge_service = knowledge_service_inst
|
||||
|
||||
external_kb_service_inst = external_kb_service.ExternalKBService(ap)
|
||||
ap.external_kb_service = external_kb_service_inst
|
||||
|
||||
mcp_service_inst = mcp_service.MCPService(ap)
|
||||
ap.mcp_service = mcp_service_inst
|
||||
|
||||
apikey_service_inst = apikey_service.ApiKeyService(ap)
|
||||
ap.apikey_service = apikey_service_inst
|
||||
|
||||
webhook_service_inst = webhook_service.WebhookService(ap)
|
||||
ap.webhook_service = webhook_service_inst
|
||||
monitoring_service_inst = monitoring_service.MonitoringService(ap)
|
||||
ap.monitoring_service = monitoring_service_inst
|
||||
|
||||
async def runtime_disconnect_callback(connector: plugin_connector.PluginRuntimeConnector) -> None:
|
||||
await asyncio.sleep(3)
|
||||
|
||||
@@ -2,8 +2,11 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
from langbot.pkg.utils import constants
|
||||
import yaml
|
||||
import importlib.resources as resources
|
||||
import uuid
|
||||
import time
|
||||
|
||||
from .. import stage, app
|
||||
from ..bootutils import config
|
||||
@@ -142,6 +145,22 @@ class LoadConfigStage(stage.BootingStage):
|
||||
|
||||
await ap.instance_config.dump_config()
|
||||
|
||||
# load or generate instance id
|
||||
ap.instance_id = await config.load_json_config(
|
||||
'data/labels/instance_id.json',
|
||||
template_data={
|
||||
'instance_id': f'instance_{str(uuid.uuid4())}',
|
||||
'instance_create_ts': int(time.time()),
|
||||
},
|
||||
completion=False,
|
||||
)
|
||||
|
||||
constants.instance_id = ap.instance_id.data['instance_id']
|
||||
|
||||
print(f'LangBot instance id: {constants.instance_id}')
|
||||
|
||||
await ap.instance_id.dump_config()
|
||||
|
||||
ap.sensitive_meta = await config.load_json_config(
|
||||
'data/metadata/sensitive-words.json',
|
||||
'metadata/sensitive-words.json',
|
||||
|
||||
0
src/langbot/pkg/entity/dto/__init__.py
Normal file
0
src/langbot/pkg/entity/dto/__init__.py
Normal file
49
src/langbot/pkg/entity/dto/space_model.py
Normal file
49
src/langbot/pkg/entity/dto/space_model.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# [
|
||||
# {
|
||||
# "uuid": "7652ebdb-54dc-412c-a830-e9268ac88471",
|
||||
# "model_id": "claude-opus-4-5-20251101",
|
||||
# "display_name": {
|
||||
# "en_US": "claude-opus-4-5-20251101",
|
||||
# "zh_Hans": "claude-opus-4-5-20251101"
|
||||
# },
|
||||
# "description": {},
|
||||
# "provider": "anthropic",
|
||||
# "category": "chat",
|
||||
# "icon_url": "Claude.Color",
|
||||
# "tags": {},
|
||||
# "is_featured": true,
|
||||
# "featured_order": 999,
|
||||
# "model_ratio": 2.5,
|
||||
# "completion_ratio": 5,
|
||||
# "quota_type": 0,
|
||||
# "model_price": 0,
|
||||
# "input_credits": 500,
|
||||
# "output_credits": 2500,
|
||||
# "vendor_id": 1,
|
||||
# "vendor_name": "Anthropic",
|
||||
# "vendor_icon": "Claude.Color",
|
||||
# "supported_endpoints": [
|
||||
# "anthropic",
|
||||
# "openai"
|
||||
# ],
|
||||
# "status": "active",
|
||||
# "metadata": null,
|
||||
# "created_at": "2025-12-30T22:23:38.337207+08:00",
|
||||
# "updated_at": "2025-12-30T22:23:38.337207+08:00"
|
||||
# }
|
||||
# ]
|
||||
|
||||
import pydantic
|
||||
|
||||
|
||||
class SpaceModel(pydantic.BaseModel):
|
||||
uuid: str
|
||||
model_id: str
|
||||
provider: str
|
||||
category: str # chat / embedding
|
||||
llm_abilities: list[str] | None = None
|
||||
is_featured: bool = False
|
||||
featured_order: int = 0
|
||||
status: str
|
||||
created_at: str | None = None
|
||||
updated_at: str | None = None
|
||||
6
src/langbot/pkg/entity/errors/account.py
Normal file
6
src/langbot/pkg/entity/errors/account.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class AccountEmailMismatchError(Exception):
|
||||
def __str__(self):
|
||||
return 'Account email mismatch'
|
||||
@@ -7,3 +7,11 @@ class RequesterNotFoundError(Exception):
|
||||
|
||||
def __str__(self):
|
||||
return f'Requester {self.requester_name} not found'
|
||||
|
||||
|
||||
class ProviderNotFoundError(Exception):
|
||||
def __init__(self, provider_name: str):
|
||||
self.provider_name = provider_name
|
||||
|
||||
def __str__(self):
|
||||
return f'Provider {self.provider_name} not found'
|
||||
|
||||
@@ -9,7 +9,7 @@ class MCPServer(Base):
|
||||
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
|
||||
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
enable = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=False)
|
||||
mode = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) # stdio, sse
|
||||
mode = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) # stdio, sse, http
|
||||
extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
||||
updated_at = sqlalchemy.Column(
|
||||
|
||||
@@ -3,6 +3,25 @@ import sqlalchemy
|
||||
from .base import Base
|
||||
|
||||
|
||||
class ModelProvider(Base):
|
||||
"""Model provider"""
|
||||
|
||||
__tablename__ = 'model_providers'
|
||||
|
||||
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
|
||||
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
requester = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
base_url = sqlalchemy.Column(sqlalchemy.String(512), nullable=False)
|
||||
api_keys = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=[])
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
||||
updated_at = sqlalchemy.Column(
|
||||
sqlalchemy.DateTime,
|
||||
nullable=False,
|
||||
server_default=sqlalchemy.func.now(),
|
||||
onupdate=sqlalchemy.func.now(),
|
||||
)
|
||||
|
||||
|
||||
class LLMModel(Base):
|
||||
"""LLM model"""
|
||||
|
||||
@@ -10,12 +29,10 @@ class LLMModel(Base):
|
||||
|
||||
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
|
||||
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
description = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
requester = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
requester_config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
|
||||
api_keys = sqlalchemy.Column(sqlalchemy.JSON, nullable=False)
|
||||
provider_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
abilities = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=[])
|
||||
extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
|
||||
prefered_ranking = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0)
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
||||
updated_at = sqlalchemy.Column(
|
||||
sqlalchemy.DateTime,
|
||||
@@ -26,17 +43,15 @@ class LLMModel(Base):
|
||||
|
||||
|
||||
class EmbeddingModel(Base):
|
||||
"""Embedding 模型"""
|
||||
"""Embedding model"""
|
||||
|
||||
__tablename__ = 'embedding_models'
|
||||
|
||||
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
|
||||
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
description = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
requester = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
requester_config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
|
||||
api_keys = sqlalchemy.Column(sqlalchemy.JSON, nullable=False)
|
||||
provider_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
|
||||
prefered_ranking = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0)
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
||||
updated_at = sqlalchemy.Column(
|
||||
sqlalchemy.DateTime,
|
||||
|
||||
105
src/langbot/pkg/entity/persistence/monitoring.py
Normal file
105
src/langbot/pkg/entity/persistence/monitoring.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import sqlalchemy
|
||||
|
||||
from .base import Base
|
||||
|
||||
|
||||
class MonitoringMessage(Base):
|
||||
"""Monitoring message records"""
|
||||
|
||||
__tablename__ = 'monitoring_messages'
|
||||
|
||||
id = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True)
|
||||
timestamp = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, index=True)
|
||||
bot_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, index=True)
|
||||
bot_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
pipeline_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, index=True)
|
||||
pipeline_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
message_content = sqlalchemy.Column(sqlalchemy.Text, nullable=False)
|
||||
session_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, index=True)
|
||||
status = sqlalchemy.Column(sqlalchemy.String(50), nullable=False) # success, error, pending
|
||||
level = sqlalchemy.Column(sqlalchemy.String(50), nullable=False) # info, warning, error, debug
|
||||
platform = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
user_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
runner_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) # Runner name for this query
|
||||
variables = sqlalchemy.Column(sqlalchemy.Text, nullable=True) # Query variables as JSON string
|
||||
|
||||
|
||||
class MonitoringLLMCall(Base):
|
||||
"""LLM call records"""
|
||||
|
||||
__tablename__ = 'monitoring_llm_calls'
|
||||
|
||||
id = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True)
|
||||
timestamp = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, index=True)
|
||||
model_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
input_tokens = sqlalchemy.Column(sqlalchemy.Integer, nullable=False)
|
||||
output_tokens = sqlalchemy.Column(sqlalchemy.Integer, nullable=False)
|
||||
total_tokens = sqlalchemy.Column(sqlalchemy.Integer, nullable=False)
|
||||
duration = sqlalchemy.Column(sqlalchemy.Integer, nullable=False) # milliseconds
|
||||
cost = sqlalchemy.Column(sqlalchemy.Float, nullable=True)
|
||||
status = sqlalchemy.Column(sqlalchemy.String(50), nullable=False) # success, error
|
||||
bot_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, index=True)
|
||||
bot_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
pipeline_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, index=True)
|
||||
pipeline_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
session_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
error_message = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
|
||||
message_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True) # Associated message ID
|
||||
|
||||
|
||||
class MonitoringSession(Base):
|
||||
"""Session tracking records"""
|
||||
|
||||
__tablename__ = 'monitoring_sessions'
|
||||
|
||||
session_id = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True)
|
||||
bot_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, index=True)
|
||||
bot_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
pipeline_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, index=True)
|
||||
pipeline_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
message_count = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0)
|
||||
start_time = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, index=True)
|
||||
last_activity = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, index=True)
|
||||
is_active = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=True, index=True)
|
||||
platform = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
user_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
|
||||
|
||||
class MonitoringError(Base):
|
||||
"""Error log records"""
|
||||
|
||||
__tablename__ = 'monitoring_errors'
|
||||
|
||||
id = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True)
|
||||
timestamp = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, index=True)
|
||||
error_type = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
error_message = sqlalchemy.Column(sqlalchemy.Text, nullable=False)
|
||||
bot_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, index=True)
|
||||
bot_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
pipeline_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, index=True)
|
||||
pipeline_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
session_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
stack_trace = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
|
||||
message_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True) # Associated message ID
|
||||
|
||||
|
||||
class MonitoringEmbeddingCall(Base):
|
||||
"""Embedding call records"""
|
||||
|
||||
__tablename__ = 'monitoring_embedding_calls'
|
||||
|
||||
id = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True)
|
||||
timestamp = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, index=True)
|
||||
model_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
prompt_tokens = sqlalchemy.Column(sqlalchemy.Integer, nullable=False)
|
||||
total_tokens = sqlalchemy.Column(sqlalchemy.Integer, nullable=False)
|
||||
duration = sqlalchemy.Column(sqlalchemy.Integer, nullable=False) # milliseconds
|
||||
input_count = sqlalchemy.Column(sqlalchemy.Integer, nullable=False) # Number of input texts
|
||||
status = sqlalchemy.Column(sqlalchemy.String(50), nullable=False) # success, error
|
||||
error_message = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
|
||||
# Optional context fields
|
||||
knowledge_base_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
query_text = sqlalchemy.Column(sqlalchemy.Text, nullable=True) # For retrieval calls
|
||||
session_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
message_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
call_type = sqlalchemy.Column(sqlalchemy.String(50), nullable=True) # embedding, retrieve
|
||||
@@ -11,6 +11,7 @@ class LegacyPipeline(Base):
|
||||
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
|
||||
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
description = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
emoji = sqlalchemy.Column(sqlalchemy.String(10), nullable=True, default='⚙️')
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
||||
updated_at = sqlalchemy.Column(
|
||||
sqlalchemy.DateTime,
|
||||
|
||||
@@ -7,6 +7,7 @@ class KnowledgeBase(Base):
|
||||
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
|
||||
name = sqlalchemy.Column(sqlalchemy.String, index=True)
|
||||
description = sqlalchemy.Column(sqlalchemy.Text)
|
||||
emoji = sqlalchemy.Column(sqlalchemy.String(10), nullable=True, default='📚')
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, default=sqlalchemy.func.now())
|
||||
updated_at = sqlalchemy.Column(sqlalchemy.DateTime, default=sqlalchemy.func.now(), onupdate=sqlalchemy.func.now())
|
||||
embedding_model_uuid = sqlalchemy.Column(sqlalchemy.String, default='')
|
||||
@@ -35,6 +36,7 @@ class ExternalKnowledgeBase(Base):
|
||||
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
|
||||
name = sqlalchemy.Column(sqlalchemy.String, index=True)
|
||||
description = sqlalchemy.Column(sqlalchemy.Text)
|
||||
emoji = sqlalchemy.Column(sqlalchemy.String(10), nullable=True, default='🔗')
|
||||
plugin_author = sqlalchemy.Column(sqlalchemy.String, nullable=False)
|
||||
plugin_name = sqlalchemy.Column(sqlalchemy.String, nullable=False)
|
||||
retriever_name = sqlalchemy.Column(sqlalchemy.String, nullable=False)
|
||||
|
||||
@@ -9,6 +9,17 @@ class User(Base):
|
||||
id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True)
|
||||
user = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
password = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
|
||||
# Account type: 'local' (default) or 'space'
|
||||
account_type = sqlalchemy.Column(sqlalchemy.String(32), nullable=False, server_default='local')
|
||||
|
||||
# Space account fields (nullable, only used when account_type='space')
|
||||
space_account_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
space_access_token = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
|
||||
space_refresh_token = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
|
||||
space_access_token_expires_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
|
||||
space_api_key = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
||||
updated_at = sqlalchemy.Column(
|
||||
sqlalchemy.DateTime,
|
||||
|
||||
@@ -9,7 +9,7 @@ import sqlalchemy.ext.asyncio as sqlalchemy_asyncio
|
||||
import sqlalchemy
|
||||
|
||||
from . import database, migration
|
||||
from ..entity.persistence import base, pipeline, metadata
|
||||
from ..entity.persistence import base, pipeline, metadata, model as persistence_model
|
||||
from ..entity import persistence
|
||||
from ..core import app
|
||||
from ..utils import constants, importutil
|
||||
@@ -79,6 +79,7 @@ class PersistenceManager:
|
||||
self.ap.logger.info(f'Successfully upgraded database to version {last_migration_number}.')
|
||||
|
||||
await self.write_default_pipeline()
|
||||
await self.write_space_model_providers()
|
||||
|
||||
async def create_tables(self):
|
||||
# create tables
|
||||
@@ -123,7 +124,42 @@ class PersistenceManager:
|
||||
|
||||
await self.execute_async(sqlalchemy.insert(pipeline.LegacyPipeline).values(pipeline_data))
|
||||
|
||||
# =================================
|
||||
async def write_space_model_providers(self):
|
||||
space_models_gateway_api_url = self.ap.instance_config.data.get('space', {}).get(
|
||||
'models_gateway_api_url', 'https://api.langbot.cloud/v1'
|
||||
)
|
||||
|
||||
# write space model providers
|
||||
result = await self.execute_async(
|
||||
sqlalchemy.select(persistence_model.ModelProvider).where(
|
||||
persistence_model.ModelProvider.requester == 'space-chat-completions'
|
||||
)
|
||||
)
|
||||
exists_space_chat_completions_model_provider = result.first()
|
||||
|
||||
# api keys will be set/updated when the oauth callback
|
||||
if exists_space_chat_completions_model_provider is None:
|
||||
self.ap.logger.info('Creating space model providers...')
|
||||
space_chat_completions_model_provider = {
|
||||
'uuid': '00000000-0000-0000-0000-000000000000',
|
||||
'name': 'LangBot Models',
|
||||
'requester': 'space-chat-completions',
|
||||
'base_url': space_models_gateway_api_url,
|
||||
'api_keys': [],
|
||||
}
|
||||
|
||||
await self.execute_async(
|
||||
sqlalchemy.insert(persistence_model.ModelProvider).values(space_chat_completions_model_provider)
|
||||
)
|
||||
else:
|
||||
if exists_space_chat_completions_model_provider.base_url != space_models_gateway_api_url:
|
||||
await self.execute_async(
|
||||
sqlalchemy.update(persistence_model.ModelProvider)
|
||||
.where(persistence_model.ModelProvider.uuid == exists_space_chat_completions_model_provider.uuid)
|
||||
.values({'base_url': space_models_gateway_api_url})
|
||||
)
|
||||
|
||||
# =================================
|
||||
|
||||
async def execute_async(self, *args, **kwargs) -> sqlalchemy.engine.cursor.CursorResult:
|
||||
async with self.get_db_engine().connect() as conn:
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
import sqlalchemy
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class(14)
|
||||
class DBMigrateSpaceAccountSupport(migration.DBMigration):
|
||||
"""Add Space account support fields to users table"""
|
||||
|
||||
async def upgrade(self):
|
||||
"""Upgrade"""
|
||||
# Get all column names from the users table
|
||||
columns = []
|
||||
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text("SELECT column_name FROM information_schema.columns WHERE table_name = 'users';")
|
||||
)
|
||||
all_result = result.fetchall()
|
||||
columns = [row[0] for row in all_result]
|
||||
else:
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.text('PRAGMA table_info(users);'))
|
||||
all_result = result.fetchall()
|
||||
columns = [row[1] for row in all_result]
|
||||
|
||||
# Add account_type column
|
||||
if 'account_type' not in columns:
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text("ALTER TABLE users ADD COLUMN account_type VARCHAR(32) DEFAULT 'local' NOT NULL")
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text("ALTER TABLE users ADD COLUMN account_type VARCHAR(32) DEFAULT 'local' NOT NULL")
|
||||
)
|
||||
|
||||
# Add space_account_uuid column
|
||||
if 'space_account_uuid' not in columns:
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('ALTER TABLE users ADD COLUMN space_account_uuid VARCHAR(255)')
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('ALTER TABLE users ADD COLUMN space_account_uuid VARCHAR(255)')
|
||||
)
|
||||
|
||||
# Add space_access_token column
|
||||
if 'space_access_token' not in columns:
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('ALTER TABLE users ADD COLUMN space_access_token TEXT')
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('ALTER TABLE users ADD COLUMN space_access_token TEXT')
|
||||
)
|
||||
|
||||
# Add space_refresh_token column
|
||||
if 'space_refresh_token' not in columns:
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('ALTER TABLE users ADD COLUMN space_refresh_token TEXT')
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('ALTER TABLE users ADD COLUMN space_refresh_token TEXT')
|
||||
)
|
||||
|
||||
# Add space_access_token_expires_at column
|
||||
if 'space_access_token_expires_at' not in columns:
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('ALTER TABLE users ADD COLUMN space_access_token_expires_at TIMESTAMP')
|
||||
)
|
||||
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('ALTER TABLE users ADD COLUMN space_access_token_expires_at DATETIME')
|
||||
)
|
||||
|
||||
# Add space_api_key column
|
||||
if 'space_api_key' not in columns:
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('ALTER TABLE users ADD COLUMN space_api_key VARCHAR(255)')
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('ALTER TABLE users ADD COLUMN space_api_key VARCHAR(255)')
|
||||
)
|
||||
|
||||
async def downgrade(self):
|
||||
"""Downgrade"""
|
||||
pass
|
||||
@@ -0,0 +1,15 @@
|
||||
from .. import migration
|
||||
|
||||
|
||||
# this is a deprecated migration
|
||||
@migration.migration_class(15)
|
||||
class DBMigrateModelSourceTracking(migration.DBMigration):
|
||||
"""Add source tracking fields to models tables for Space integration"""
|
||||
|
||||
async def upgrade(self):
|
||||
"""Upgrade"""
|
||||
pass
|
||||
|
||||
async def downgrade(self):
|
||||
"""Downgrade"""
|
||||
pass
|
||||
@@ -0,0 +1,305 @@
|
||||
import uuid as uuid_lib
|
||||
|
||||
import sqlalchemy
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class(16)
|
||||
class DBMigrateModelProviderRefactor(migration.DBMigration):
|
||||
"""Refactor model structure: create providers from existing models and update references"""
|
||||
|
||||
async def upgrade(self):
|
||||
"""Upgrade"""
|
||||
# Step 1: Create model_providers table if not exists
|
||||
await self._create_providers_table()
|
||||
|
||||
# Step 2: Migrate existing models to use providers
|
||||
await self._migrate_llm_models()
|
||||
await self._migrate_embedding_models()
|
||||
|
||||
# Step 3: Remove deprecated columns
|
||||
await self._cleanup_columns()
|
||||
|
||||
async def _create_providers_table(self):
|
||||
"""Create model_providers table"""
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text("""
|
||||
CREATE TABLE IF NOT EXISTS model_providers (
|
||||
uuid VARCHAR(255) PRIMARY KEY,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
requester VARCHAR(255) NOT NULL,
|
||||
base_url VARCHAR(512) NOT NULL,
|
||||
api_keys JSONB NOT NULL DEFAULT '[]',
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text("""
|
||||
CREATE TABLE IF NOT EXISTS model_providers (
|
||||
uuid VARCHAR(255) PRIMARY KEY,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
requester VARCHAR(255) NOT NULL,
|
||||
base_url VARCHAR(512) NOT NULL,
|
||||
api_keys JSON NOT NULL DEFAULT '[]',
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
)
|
||||
|
||||
async def _migrate_llm_models(self):
|
||||
"""Migrate LLM models to use providers"""
|
||||
llm_columns = await self._get_columns('llm_models')
|
||||
|
||||
# Add provider_uuid column if not exists
|
||||
if 'provider_uuid' not in llm_columns:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('ALTER TABLE llm_models ADD COLUMN provider_uuid VARCHAR(255)')
|
||||
)
|
||||
|
||||
# Add prefered_ranking column if not exists
|
||||
if 'prefered_ranking' not in llm_columns:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('ALTER TABLE llm_models ADD COLUMN prefered_ranking INTEGER NOT NULL DEFAULT 0')
|
||||
)
|
||||
|
||||
# Only migrate if old columns exist
|
||||
if 'requester' not in llm_columns:
|
||||
return
|
||||
|
||||
# Get all LLM models with old structure
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('SELECT uuid, name, requester, requester_config, api_keys FROM llm_models')
|
||||
)
|
||||
models = result.fetchall()
|
||||
|
||||
# Create providers and update models
|
||||
provider_cache = {} # (requester, base_url, api_keys_str) -> provider_uuid
|
||||
|
||||
for model in models:
|
||||
model_uuid, model_name, requester, requester_config, api_keys = model
|
||||
|
||||
# Extract base_url from requester_config
|
||||
base_url = ''
|
||||
if requester_config:
|
||||
if isinstance(requester_config, str):
|
||||
import json
|
||||
|
||||
requester_config = json.loads(requester_config)
|
||||
base_url = requester_config.get('base_url', '') or requester_config.get('base-url', '')
|
||||
|
||||
# Parse api_keys if it's a string
|
||||
if isinstance(api_keys, str):
|
||||
import json
|
||||
|
||||
try:
|
||||
api_keys = json.loads(api_keys)
|
||||
except Exception:
|
||||
api_keys = []
|
||||
if not api_keys:
|
||||
api_keys = []
|
||||
|
||||
# Create cache key
|
||||
api_keys_str = str(sorted(api_keys)) if api_keys else '[]'
|
||||
cache_key = (requester, base_url, api_keys_str)
|
||||
|
||||
if cache_key in provider_cache:
|
||||
provider_uuid = provider_cache[cache_key]
|
||||
else:
|
||||
# Create new provider
|
||||
provider_uuid = str(uuid_lib.uuid4())
|
||||
provider_name = f'{requester}'
|
||||
if base_url:
|
||||
# Extract domain for name
|
||||
try:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(base_url)
|
||||
provider_name = parsed.netloc or requester
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
import json
|
||||
|
||||
api_keys_json = json.dumps(api_keys) if api_keys else '[]'
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text("""
|
||||
INSERT INTO model_providers (uuid, name, requester, base_url, api_keys)
|
||||
VALUES (:uuid, :name, :requester, :base_url, :api_keys)
|
||||
"""),
|
||||
{
|
||||
'uuid': provider_uuid,
|
||||
'name': provider_name,
|
||||
'requester': requester,
|
||||
'base_url': base_url,
|
||||
'api_keys': api_keys_json,
|
||||
},
|
||||
)
|
||||
provider_cache[cache_key] = provider_uuid
|
||||
|
||||
# Update model with provider_uuid
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('UPDATE llm_models SET provider_uuid = :provider_uuid WHERE uuid = :uuid'),
|
||||
{'provider_uuid': provider_uuid, 'uuid': model_uuid},
|
||||
)
|
||||
|
||||
async def _migrate_embedding_models(self):
|
||||
"""Migrate embedding models to use providers"""
|
||||
embedding_columns = await self._get_columns('embedding_models')
|
||||
|
||||
# Add provider_uuid column if not exists
|
||||
if 'provider_uuid' not in embedding_columns:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('ALTER TABLE embedding_models ADD COLUMN provider_uuid VARCHAR(255)')
|
||||
)
|
||||
|
||||
# Add prefered_ranking column if not exists
|
||||
if 'prefered_ranking' not in embedding_columns:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('ALTER TABLE embedding_models ADD COLUMN prefered_ranking INTEGER NOT NULL DEFAULT 0')
|
||||
)
|
||||
|
||||
# Only migrate if old columns exist
|
||||
if 'requester' not in embedding_columns:
|
||||
return
|
||||
|
||||
# Get all embedding models with old structure
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('SELECT uuid, name, requester, requester_config, api_keys FROM embedding_models')
|
||||
)
|
||||
models = result.fetchall()
|
||||
|
||||
# Get existing providers
|
||||
provider_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('SELECT uuid, requester, base_url, api_keys FROM model_providers')
|
||||
)
|
||||
existing_providers = provider_result.fetchall()
|
||||
|
||||
provider_cache = {}
|
||||
for p in existing_providers:
|
||||
p_uuid, p_requester, p_base_url, p_api_keys = p
|
||||
api_keys_str = str(sorted(p_api_keys)) if p_api_keys else '[]'
|
||||
provider_cache[(p_requester, p_base_url, api_keys_str)] = p_uuid
|
||||
|
||||
for model in models:
|
||||
model_uuid, model_name, requester, requester_config, api_keys = model
|
||||
|
||||
base_url = ''
|
||||
if requester_config:
|
||||
if isinstance(requester_config, str):
|
||||
import json
|
||||
|
||||
requester_config = json.loads(requester_config)
|
||||
base_url = requester_config.get('base_url', '') or requester_config.get('base-url', '')
|
||||
|
||||
# Parse api_keys if it's a string
|
||||
if isinstance(api_keys, str):
|
||||
import json
|
||||
|
||||
try:
|
||||
api_keys = json.loads(api_keys)
|
||||
except Exception:
|
||||
api_keys = []
|
||||
if not api_keys:
|
||||
api_keys = []
|
||||
|
||||
api_keys_str = str(sorted(api_keys)) if api_keys else '[]'
|
||||
cache_key = (requester, base_url, api_keys_str)
|
||||
|
||||
if cache_key in provider_cache:
|
||||
provider_uuid = provider_cache[cache_key]
|
||||
else:
|
||||
provider_uuid = str(uuid_lib.uuid4())
|
||||
provider_name = f'{requester}'
|
||||
if base_url:
|
||||
try:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(base_url)
|
||||
provider_name = parsed.netloc or requester
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
import json
|
||||
|
||||
api_keys_json = json.dumps(api_keys) if api_keys else '[]'
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text("""
|
||||
INSERT INTO model_providers (uuid, name, requester, base_url, api_keys)
|
||||
VALUES (:uuid, :name, :requester, :base_url, :api_keys)
|
||||
"""),
|
||||
{
|
||||
'uuid': provider_uuid,
|
||||
'name': provider_name,
|
||||
'requester': requester,
|
||||
'base_url': base_url,
|
||||
'api_keys': api_keys_json,
|
||||
},
|
||||
)
|
||||
provider_cache[cache_key] = provider_uuid
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('UPDATE embedding_models SET provider_uuid = :provider_uuid WHERE uuid = :uuid'),
|
||||
{'provider_uuid': provider_uuid, 'uuid': model_uuid},
|
||||
)
|
||||
|
||||
async def _cleanup_columns(self):
|
||||
"""Remove deprecated columns from model tables"""
|
||||
|
||||
llm_columns = await self._get_columns('llm_models')
|
||||
deprecated_llm_cols = ['requester', 'requester_config', 'api_keys', 'description', 'source', 'space_model_id']
|
||||
for col in deprecated_llm_cols:
|
||||
if col in llm_columns:
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(f'ALTER TABLE llm_models DROP COLUMN IF EXISTS {col}')
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(f'ALTER TABLE llm_models DROP COLUMN {col}')
|
||||
)
|
||||
|
||||
embedding_columns = await self._get_columns('embedding_models')
|
||||
deprecated_embedding_cols = [
|
||||
'requester',
|
||||
'requester_config',
|
||||
'api_keys',
|
||||
'description',
|
||||
'source',
|
||||
'space_model_id',
|
||||
]
|
||||
for col in deprecated_embedding_cols:
|
||||
if col in embedding_columns:
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(f'ALTER TABLE embedding_models DROP COLUMN IF EXISTS {col}')
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(f'ALTER TABLE embedding_models DROP COLUMN {col}')
|
||||
)
|
||||
|
||||
async def _get_columns(self, table_name: str) -> list:
|
||||
"""Get column names for a table"""
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
f"SELECT column_name FROM information_schema.columns WHERE table_name = '{table_name}';"
|
||||
)
|
||||
)
|
||||
all_result = result.fetchall()
|
||||
return [row[0] for row in all_result]
|
||||
else:
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.text(f'PRAGMA table_info({table_name});'))
|
||||
all_result = result.fetchall()
|
||||
return [row[1] for row in all_result]
|
||||
|
||||
async def downgrade(self):
|
||||
"""Downgrade"""
|
||||
pass
|
||||
@@ -0,0 +1,25 @@
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class(17)
|
||||
class MoveCloudServiceUrl(migration.DBMigration):
|
||||
"""迁移云服务 URL 配置"""
|
||||
|
||||
async def upgrade(self):
|
||||
"""升级"""
|
||||
if 'space' not in self.ap.instance_config.data:
|
||||
self.ap.instance_config.data['space'] = {
|
||||
'url': 'https://space.langbot.app',
|
||||
'models_gateway_api_url': 'https://api.langbot.cloud/v1',
|
||||
'oauth_authorize_url': 'https://space.langbot.app/auth/authorize',
|
||||
'disable_models_service': False,
|
||||
}
|
||||
|
||||
if 'plugin' in self.ap.instance_config.data:
|
||||
self.ap.instance_config.data['plugin'].pop('cloud_service_url', None)
|
||||
|
||||
await self.ap.instance_config.dump_config()
|
||||
|
||||
async def downgrade(self):
|
||||
"""降级"""
|
||||
pass
|
||||
@@ -0,0 +1,58 @@
|
||||
import sqlalchemy
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class(18)
|
||||
class DBMigrateAddEmojiSupport(migration.DBMigration):
|
||||
"""Add emoji field to knowledge_bases, external_knowledge_bases and legacy_pipelines tables"""
|
||||
|
||||
async def upgrade(self):
|
||||
"""Upgrade"""
|
||||
# Add emoji field to knowledge_bases
|
||||
await self._add_emoji_to_table('knowledge_bases', '📚')
|
||||
|
||||
# Add emoji field to external_knowledge_bases
|
||||
await self._add_emoji_to_table('external_knowledge_bases', '🔗')
|
||||
|
||||
# Add emoji field to legacy_pipelines
|
||||
await self._add_emoji_to_table('legacy_pipelines', '⚙️')
|
||||
|
||||
async def _add_emoji_to_table(self, table_name: str, default_emoji: str):
|
||||
"""Add emoji column to specified table if it doesn't exist"""
|
||||
# Get all column names from the table
|
||||
columns = []
|
||||
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
f"SELECT column_name FROM information_schema.columns WHERE table_name = '{table_name}';"
|
||||
)
|
||||
)
|
||||
all_result = result.fetchall()
|
||||
columns = [row[0] for row in all_result]
|
||||
else:
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.text(f'PRAGMA table_info({table_name});'))
|
||||
all_result = result.fetchall()
|
||||
columns = [row[1] for row in all_result]
|
||||
|
||||
# Check and add emoji column
|
||||
if 'emoji' not in columns:
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(f"ALTER TABLE {table_name} ADD COLUMN emoji VARCHAR(10) DEFAULT '{default_emoji}'")
|
||||
)
|
||||
else:
|
||||
# SQLite doesn't support DEFAULT with emoji directly in ALTER TABLE
|
||||
# Add column without default first
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(f'ALTER TABLE {table_name} ADD COLUMN emoji VARCHAR(10)')
|
||||
)
|
||||
|
||||
# Set default emoji value for existing records
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(f"UPDATE {table_name} SET emoji = '{default_emoji}' WHERE emoji IS NULL")
|
||||
)
|
||||
|
||||
async def downgrade(self):
|
||||
"""Downgrade"""
|
||||
pass
|
||||
270
src/langbot/pkg/pipeline/monitoring_helper.py
Normal file
270
src/langbot/pkg/pipeline/monitoring_helper.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""
|
||||
Monitoring helper for recording events during pipeline execution.
|
||||
This module provides convenient methods to record monitoring data
|
||||
without cluttering the main pipeline code.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import traceback
|
||||
import typing
|
||||
import time
|
||||
import json
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ..core import app
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
class MonitoringHelper:
|
||||
"""Helper class for monitoring operations"""
|
||||
|
||||
@staticmethod
|
||||
async def record_query_start(
|
||||
ap: app.Application,
|
||||
query: pipeline_query.Query,
|
||||
bot_id: str,
|
||||
bot_name: str,
|
||||
pipeline_id: str,
|
||||
pipeline_name: str,
|
||||
runner_name: str | None = None,
|
||||
) -> str:
|
||||
"""Record the start of query processing, returns message_id"""
|
||||
try:
|
||||
# Check if session exists, if not, record session start
|
||||
session_id = f'{query.launcher_type}_{query.launcher_id}'
|
||||
|
||||
# Try to record message
|
||||
# Use JSON serialization to preserve message chain structure (including image URLs, etc.)
|
||||
if hasattr(query, 'message_chain') and hasattr(query.message_chain, 'model_dump'):
|
||||
message_content = json.dumps(query.message_chain.model_dump(), ensure_ascii=False)
|
||||
else:
|
||||
message_content = str(query)
|
||||
|
||||
# Variables will be updated in record_query_success after preproc stage sets them
|
||||
# Here we just record None, the full variables will be set when query completes
|
||||
|
||||
message_id = await ap.monitoring_service.record_message(
|
||||
bot_id=bot_id,
|
||||
bot_name=bot_name,
|
||||
pipeline_id=pipeline_id,
|
||||
pipeline_name=pipeline_name,
|
||||
message_content=message_content,
|
||||
session_id=session_id,
|
||||
status='pending',
|
||||
level='info',
|
||||
platform=query.launcher_type.value
|
||||
if hasattr(query.launcher_type, 'value')
|
||||
else str(query.launcher_type),
|
||||
user_id=query.sender_id,
|
||||
runner_name=runner_name,
|
||||
variables=None, # Will be updated in record_query_success
|
||||
)
|
||||
|
||||
# Update session activity or create new session if it doesn't exist
|
||||
# Always pass pipeline info to handle pipeline switches
|
||||
session_updated = await ap.monitoring_service.update_session_activity(
|
||||
session_id,
|
||||
pipeline_id=pipeline_id,
|
||||
pipeline_name=pipeline_name,
|
||||
)
|
||||
if not session_updated:
|
||||
# Session doesn't exist, create it
|
||||
await ap.monitoring_service.record_session_start(
|
||||
session_id=session_id,
|
||||
bot_id=bot_id,
|
||||
bot_name=bot_name,
|
||||
pipeline_id=pipeline_id,
|
||||
pipeline_name=pipeline_name,
|
||||
platform=query.launcher_type.value
|
||||
if hasattr(query.launcher_type, 'value')
|
||||
else str(query.launcher_type),
|
||||
user_id=query.sender_id,
|
||||
)
|
||||
|
||||
return message_id
|
||||
except Exception as e:
|
||||
ap.logger.error(f'Failed to record query start: {e}')
|
||||
return ''
|
||||
|
||||
@staticmethod
|
||||
async def record_query_success(
|
||||
ap: app.Application,
|
||||
message_id: str,
|
||||
query: pipeline_query.Query | None = None,
|
||||
):
|
||||
"""Record successful query processing by updating message status and variables"""
|
||||
try:
|
||||
if message_id:
|
||||
# Serialize query.variables (filtering out internal variables)
|
||||
query_variables_str = None
|
||||
if query and hasattr(query, 'variables') and query.variables:
|
||||
filtered_vars = {k: v for k, v in query.variables.items() if not k.startswith('_')}
|
||||
if filtered_vars:
|
||||
try:
|
||||
query_variables_str = json.dumps(filtered_vars, ensure_ascii=False, default=str)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await ap.monitoring_service.update_message_status(
|
||||
message_id=message_id,
|
||||
status='success',
|
||||
variables=query_variables_str,
|
||||
)
|
||||
except Exception as e:
|
||||
ap.logger.error(f'Failed to record query success: {e}')
|
||||
|
||||
@staticmethod
|
||||
async def record_query_error(
|
||||
ap: app.Application,
|
||||
query: pipeline_query.Query,
|
||||
bot_id: str,
|
||||
bot_name: str,
|
||||
pipeline_id: str,
|
||||
pipeline_name: str,
|
||||
error: Exception,
|
||||
runner_name: str | None = None,
|
||||
) -> str:
|
||||
"""Record query processing error, returns message_id"""
|
||||
try:
|
||||
session_id = f'{query.launcher_type}_{query.launcher_id}'
|
||||
|
||||
# Record error message
|
||||
message_id = await ap.monitoring_service.record_message(
|
||||
bot_id=bot_id,
|
||||
bot_name=bot_name,
|
||||
pipeline_id=pipeline_id,
|
||||
pipeline_name=pipeline_name,
|
||||
message_content=f'Error: {str(error)}',
|
||||
session_id=session_id,
|
||||
status='error',
|
||||
level='error',
|
||||
platform=query.launcher_type.value
|
||||
if hasattr(query.launcher_type, 'value')
|
||||
else str(query.launcher_type),
|
||||
user_id=query.sender_id,
|
||||
runner_name=runner_name,
|
||||
)
|
||||
|
||||
# Record error log
|
||||
await ap.monitoring_service.record_error(
|
||||
bot_id=bot_id,
|
||||
bot_name=bot_name,
|
||||
pipeline_id=pipeline_id,
|
||||
pipeline_name=pipeline_name,
|
||||
error_type=type(error).__name__,
|
||||
error_message=str(error),
|
||||
session_id=session_id,
|
||||
stack_trace=traceback.format_exc(),
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
return message_id
|
||||
except Exception as e:
|
||||
ap.logger.error(f'Failed to record query error: {e}')
|
||||
return ''
|
||||
|
||||
@staticmethod
|
||||
async def record_llm_call(
|
||||
ap: app.Application,
|
||||
query: pipeline_query.Query,
|
||||
bot_id: str,
|
||||
bot_name: str,
|
||||
pipeline_id: str,
|
||||
pipeline_name: str,
|
||||
model_name: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
duration_ms: int,
|
||||
status: str = 'success',
|
||||
cost: float | None = None,
|
||||
error_message: str | None = None,
|
||||
message_id: str | None = None,
|
||||
):
|
||||
"""Record LLM call"""
|
||||
try:
|
||||
session_id = f'{query.launcher_type}_{query.launcher_id}'
|
||||
|
||||
await ap.monitoring_service.record_llm_call(
|
||||
bot_id=bot_id,
|
||||
bot_name=bot_name,
|
||||
pipeline_id=pipeline_id,
|
||||
pipeline_name=pipeline_name,
|
||||
session_id=session_id,
|
||||
model_name=model_name,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
duration=duration_ms,
|
||||
status=status,
|
||||
cost=cost,
|
||||
error_message=error_message,
|
||||
message_id=message_id,
|
||||
)
|
||||
except Exception as e:
|
||||
ap.logger.error(f'Failed to record LLM call: {e}')
|
||||
|
||||
|
||||
class LLMCallMonitor:
|
||||
"""Context manager for monitoring LLM calls"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ap: app.Application,
|
||||
query: pipeline_query.Query,
|
||||
bot_id: str,
|
||||
bot_name: str,
|
||||
pipeline_id: str,
|
||||
pipeline_name: str,
|
||||
model_name: str,
|
||||
):
|
||||
self.ap = ap
|
||||
self.query = query
|
||||
self.bot_id = bot_id
|
||||
self.bot_name = bot_name
|
||||
self.pipeline_id = pipeline_id
|
||||
self.pipeline_name = pipeline_name
|
||||
self.model_name = model_name
|
||||
self.start_time = None
|
||||
self.input_tokens = 0
|
||||
self.output_tokens = 0
|
||||
|
||||
async def __aenter__(self):
|
||||
self.start_time = time.time()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
duration_ms = int((time.time() - self.start_time) * 1000)
|
||||
|
||||
if exc_type is not None:
|
||||
# Error occurred
|
||||
await MonitoringHelper.record_llm_call(
|
||||
ap=self.ap,
|
||||
query=self.query,
|
||||
bot_id=self.bot_id,
|
||||
bot_name=self.bot_name,
|
||||
pipeline_id=self.pipeline_id,
|
||||
pipeline_name=self.pipeline_name,
|
||||
model_name=self.model_name,
|
||||
input_tokens=self.input_tokens,
|
||||
output_tokens=self.output_tokens,
|
||||
duration_ms=duration_ms,
|
||||
status='error',
|
||||
error_message=str(exc_val) if exc_val else None,
|
||||
)
|
||||
else:
|
||||
# Success
|
||||
await MonitoringHelper.record_llm_call(
|
||||
ap=self.ap,
|
||||
query=self.query,
|
||||
bot_id=self.bot_id,
|
||||
bot_name=self.bot_name,
|
||||
pipeline_id=self.pipeline_id,
|
||||
pipeline_name=self.pipeline_name,
|
||||
model_name=self.model_name,
|
||||
input_tokens=self.input_tokens,
|
||||
output_tokens=self.output_tokens,
|
||||
duration_ms=duration_ms,
|
||||
status='success',
|
||||
)
|
||||
|
||||
return False # Don't suppress exceptions
|
||||
@@ -115,6 +115,25 @@ class RuntimePipeline:
|
||||
# Store bound plugins and MCP servers in query for filtering
|
||||
query.variables['_pipeline_bound_plugins'] = self.bound_plugins
|
||||
query.variables['_pipeline_bound_mcp_servers'] = self.bound_mcp_servers
|
||||
|
||||
# Record query start for monitoring
|
||||
try:
|
||||
# Get bot name from bot_uuid
|
||||
bot_name = 'WebChat'
|
||||
if query.bot_uuid:
|
||||
try:
|
||||
bot = await self.ap.bot_service.get_bot(query.bot_uuid, include_secret=False)
|
||||
if bot:
|
||||
bot_name = bot.get('name', 'Unknown')
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Store for later use in process_query
|
||||
query.variables['_monitoring_bot_name'] = bot_name
|
||||
query.variables['_monitoring_pipeline_name'] = self.pipeline_entity.name
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Failed to prepare monitoring data: {e}')
|
||||
|
||||
await self.process_query(query)
|
||||
|
||||
async def _check_output(self, query: pipeline_query.Query, result: pipeline_entities.StageProcessResult):
|
||||
@@ -131,7 +150,7 @@ class RuntimePipeline:
|
||||
query.message_event, platform_events.GroupMessage
|
||||
):
|
||||
result.user_notice.insert(0, platform_message.At(target=query.message_event.sender.id))
|
||||
if await query.adapter.is_stream_output_supported():
|
||||
if await query.adapter.is_stream_output_supported() and query.resp_messages:
|
||||
await query.adapter.reply_message_chunk(
|
||||
message_source=query.message_event,
|
||||
bot_message=query.resp_messages[-1],
|
||||
@@ -151,6 +170,37 @@ class RuntimePipeline:
|
||||
self.ap.logger.info(result.console_notice)
|
||||
if result.error_notice:
|
||||
self.ap.logger.error(result.error_notice)
|
||||
# Mark query as having error
|
||||
query.variables['_monitoring_has_error'] = True
|
||||
# Record error to monitoring system
|
||||
try:
|
||||
bot_name = query.variables.get('_monitoring_bot_name', 'Unknown')
|
||||
pipeline_name = query.variables.get('_monitoring_pipeline_name', 'Unknown')
|
||||
message_id = query.variables.get('_monitoring_message_id', '')
|
||||
session_id = f'{query.launcher_type}_{query.launcher_id}'
|
||||
|
||||
# Update message status to error
|
||||
if message_id:
|
||||
await self.ap.monitoring_service.update_message_status(
|
||||
message_id=message_id,
|
||||
status='error',
|
||||
level='error',
|
||||
)
|
||||
|
||||
# Record error log
|
||||
await self.ap.monitoring_service.record_error(
|
||||
bot_id=query.bot_uuid or 'unknown',
|
||||
bot_name=bot_name,
|
||||
pipeline_id=self.pipeline_entity.uuid,
|
||||
pipeline_name=pipeline_name,
|
||||
error_type='PipelineError',
|
||||
error_message=result.error_notice,
|
||||
session_id=session_id,
|
||||
stack_trace=result.debug_notice if result.debug_notice else None,
|
||||
message_id=message_id,
|
||||
)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Failed to record error to monitoring: {e}')
|
||||
|
||||
async def _execute_from_stage(
|
||||
self,
|
||||
@@ -221,6 +271,34 @@ class RuntimePipeline:
|
||||
|
||||
async def process_query(self, query: pipeline_query.Query):
|
||||
"""处理请求"""
|
||||
# Get monitoring metadata
|
||||
bot_name = query.variables.get('_monitoring_bot_name', 'Unknown')
|
||||
pipeline_name = query.variables.get('_monitoring_pipeline_name', 'Unknown')
|
||||
|
||||
# Get runner name from pipeline config
|
||||
runner_name = None
|
||||
if query.pipeline_config and 'ai' in query.pipeline_config and 'runner' in query.pipeline_config['ai']:
|
||||
runner_name = query.pipeline_config['ai']['runner'].get('runner')
|
||||
|
||||
# Record query start and store message_id
|
||||
message_id = ''
|
||||
try:
|
||||
from . import monitoring_helper
|
||||
|
||||
message_id = await monitoring_helper.MonitoringHelper.record_query_start(
|
||||
ap=self.ap,
|
||||
query=query,
|
||||
bot_id=query.bot_uuid or 'unknown',
|
||||
bot_name=bot_name,
|
||||
pipeline_id=self.pipeline_entity.uuid,
|
||||
pipeline_name=pipeline_name,
|
||||
runner_name=runner_name,
|
||||
)
|
||||
# Store message_id in query variables for LLM call monitoring
|
||||
query.variables['_monitoring_message_id'] = message_id
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Failed to record query start: {e}')
|
||||
|
||||
try:
|
||||
# Get bound plugins for this pipeline
|
||||
bound_plugins = query.variables.get('_pipeline_bound_plugins', None)
|
||||
@@ -249,10 +327,40 @@ class RuntimePipeline:
|
||||
self.ap.logger.debug(f'Processing query {query.query_id}')
|
||||
|
||||
await self._execute_from_stage(0, query)
|
||||
|
||||
# Record query success only if no error occurred during processing
|
||||
if not query.variables.get('_monitoring_has_error', False):
|
||||
try:
|
||||
await monitoring_helper.MonitoringHelper.record_query_success(
|
||||
ap=self.ap,
|
||||
message_id=message_id,
|
||||
query=query,
|
||||
)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Failed to record query success: {e}')
|
||||
|
||||
except Exception as e:
|
||||
inst_name = query.current_stage_name if query.current_stage_name else 'unknown'
|
||||
self.ap.logger.error(f'Error processing query {query.query_id} stage={inst_name} : {e}')
|
||||
self.ap.logger.error(f'Traceback: {traceback.format_exc()}')
|
||||
|
||||
# Record query error
|
||||
try:
|
||||
from . import monitoring_helper
|
||||
|
||||
await monitoring_helper.MonitoringHelper.record_query_error(
|
||||
ap=self.ap,
|
||||
query=query,
|
||||
bot_id=query.bot_uuid or 'unknown',
|
||||
bot_name=bot_name,
|
||||
pipeline_id=self.pipeline_entity.uuid,
|
||||
pipeline_name=pipeline_name,
|
||||
error=e,
|
||||
runner_name=runner_name,
|
||||
)
|
||||
except Exception as me:
|
||||
self.ap.logger.error(f'Failed to record query error: {me}')
|
||||
|
||||
finally:
|
||||
self.ap.logger.debug(f'Query {query.query_id} processed')
|
||||
del self.ap.query_pool.cached_queries[query.query_id]
|
||||
|
||||
@@ -3,6 +3,8 @@ from __future__ import annotations
|
||||
import uuid
|
||||
import typing
|
||||
import traceback
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
from .. import handler
|
||||
@@ -10,7 +12,7 @@ from ... import entities
|
||||
from ....provider import runner as runner_module
|
||||
|
||||
import langbot_plugin.api.entities.events as events
|
||||
from ....utils import importutil
|
||||
from ....utils import importutil, constants
|
||||
from ....provider import runners
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
@@ -84,6 +86,9 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
break
|
||||
else:
|
||||
raise ValueError(f'Request Runner not found: {query.pipeline_config["ai"]["runner"]["runner"]}')
|
||||
# Mark start time for telemetry
|
||||
start_ts = time.time()
|
||||
|
||||
if is_stream:
|
||||
resp_message_id = uuid.uuid4()
|
||||
chunk_count = 0 # Track streaming chunks to reduce excessive logging
|
||||
@@ -140,7 +145,8 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
|
||||
query.session.using_conversation.messages.extend(query.resp_messages)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Conversation({query.query_id}) Request Failed: {type(e).__name__} {str(e)}')
|
||||
error_info = f'{traceback.format_exc()}'
|
||||
self.ap.logger.error(f'Conversation({query.query_id}) Request Failed: {error_info}')
|
||||
traceback.print_exc()
|
||||
|
||||
hide_exception_info = query.pipeline_config['output']['misc']['hide-exception']
|
||||
@@ -153,5 +159,47 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
debug_notice=traceback.format_exc(),
|
||||
)
|
||||
finally:
|
||||
# TODO statistics
|
||||
pass
|
||||
# Telemetry reporting: collect minimal per-query execution info and send asynchronously
|
||||
try:
|
||||
end_ts = time.time()
|
||||
duration_ms = None
|
||||
if 'start_ts' in locals():
|
||||
duration_ms = int((end_ts - start_ts) * 1000)
|
||||
|
||||
adapter_name = query.adapter.__class__.__name__ if hasattr(query, 'adapter') else None
|
||||
runner_name = (
|
||||
query.pipeline_config.get('ai', {}).get('runner', {}).get('runner')
|
||||
if query.pipeline_config
|
||||
else None
|
||||
)
|
||||
|
||||
# Model name if using localagent
|
||||
model_name = None
|
||||
try:
|
||||
if runner_name == 'local-agent' and getattr(query, 'use_llm_model_uuid', None):
|
||||
m = await self.ap.model_mgr.get_model_by_uuid(query.use_llm_model_uuid)
|
||||
if m and getattr(m, 'model_entity', None):
|
||||
model_name = getattr(m.model_entity, 'name', None)
|
||||
except Exception:
|
||||
model_name = None
|
||||
|
||||
pipeline_plugins = query.variables.get('_pipeline_bound_plugins', None)
|
||||
|
||||
payload = {
|
||||
'query_id': query.query_id,
|
||||
'adapter': adapter_name,
|
||||
'runner': runner_name,
|
||||
'duration_ms': duration_ms,
|
||||
'model_name': model_name,
|
||||
'version': constants.semantic_version,
|
||||
'instance_id': constants.instance_id,
|
||||
'pipeline_plugins': pipeline_plugins,
|
||||
'error': locals().get('error_info', None),
|
||||
'timestamp': datetime.utcnow().isoformat(),
|
||||
}
|
||||
|
||||
# Send telemetry asynchronously and do not block pipeline via app's telemetry manager
|
||||
await self.ap.telemetry.start_send_task(payload)
|
||||
except Exception as ex:
|
||||
# Ensure telemetry issues do not affect normal flow
|
||||
self.ap.logger.warning(f'Failed to send telemetry: {ex}')
|
||||
|
||||
@@ -75,10 +75,17 @@ class RuntimeBot:
|
||||
|
||||
# Only add to query pool if no webhook requested to skip pipeline
|
||||
if not skip_pipeline:
|
||||
launcher_id = event.sender.id
|
||||
|
||||
if hasattr(adapter, 'get_launcher_id'):
|
||||
custom_launcher_id = adapter.get_launcher_id(event)
|
||||
if custom_launcher_id:
|
||||
launcher_id = custom_launcher_id
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
bot_uuid=self.bot_entity.uuid,
|
||||
launcher_type=provider_session.LauncherTypes.PERSON,
|
||||
launcher_id=event.sender.id,
|
||||
launcher_id=launcher_id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain,
|
||||
@@ -86,7 +93,7 @@ class RuntimeBot:
|
||||
pipeline_uuid=self.bot_entity.use_pipeline_uuid,
|
||||
)
|
||||
else:
|
||||
await self.logger.info(f'Pipeline skipped for person message due to webhook response')
|
||||
await self.logger.info('Pipeline skipped for person message due to webhook response')
|
||||
|
||||
async def on_group_message(
|
||||
event: platform_events.GroupMessage,
|
||||
@@ -111,10 +118,17 @@ class RuntimeBot:
|
||||
|
||||
# Only add to query pool if no webhook requested to skip pipeline
|
||||
if not skip_pipeline:
|
||||
launcher_id = event.group.id
|
||||
|
||||
if hasattr(adapter, 'get_launcher_id'):
|
||||
custom_launcher_id = adapter.get_launcher_id(event)
|
||||
if custom_launcher_id:
|
||||
launcher_id = custom_launcher_id
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
bot_uuid=self.bot_entity.uuid,
|
||||
launcher_type=provider_session.LauncherTypes.GROUP,
|
||||
launcher_id=event.group.id,
|
||||
launcher_id=launcher_id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain,
|
||||
@@ -122,7 +136,7 @@ class RuntimeBot:
|
||||
pipeline_uuid=self.bot_entity.use_pipeline_uuid,
|
||||
)
|
||||
else:
|
||||
await self.logger.info(f'Pipeline skipped for group message due to webhook response')
|
||||
await self.logger.info('Pipeline skipped for group message due to webhook response')
|
||||
|
||||
self.adapter.register_listener(platform_events.FriendMessage, on_friend_message)
|
||||
self.adapter.register_listener(platform_events.GroupMessage, on_group_message)
|
||||
|
||||
@@ -231,7 +231,10 @@ class DingTalkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
card_template_id = self.config['card_template_id']
|
||||
incoming_message = event.source_platform_object.incoming_message
|
||||
# message_id = incoming_message.message_id
|
||||
card_instance, card_instance_id = await self.bot.create_and_card(card_template_id, incoming_message)
|
||||
card_auto_layout = self.config.get('card_ auto_layout', False)
|
||||
card_instance, card_instance_id = await self.bot.create_and_card(
|
||||
card_template_id, incoming_message, card_auto_layout=card_auto_layout
|
||||
)
|
||||
self.card_instance_id_dict[message_id] = (card_instance, card_instance_id)
|
||||
return True
|
||||
|
||||
|
||||
@@ -56,6 +56,13 @@ spec:
|
||||
type: boolean
|
||||
required: true
|
||||
default: false
|
||||
- name: card_auto_layout
|
||||
label:
|
||||
en_US: Card Auto Layout
|
||||
zh_Hans: 卡片宽屏自动布局
|
||||
type: boolean
|
||||
required: false
|
||||
default: false
|
||||
- name: card_template_id
|
||||
label:
|
||||
en_US: card template id
|
||||
|
||||
@@ -244,7 +244,6 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
|
||||
lb_msg_list.append(platform_message.Source(id=message.message_id, time=msg_create_time))
|
||||
|
||||
|
||||
if message.message_type == 'text':
|
||||
element_list = []
|
||||
|
||||
@@ -310,7 +309,11 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
]
|
||||
elif message.message_type == 'audio':
|
||||
message_content['content'] = [
|
||||
{'tag': 'audio', 'file_key': message_content['file_key'], "duration": message_content.get('duration',0)}
|
||||
{
|
||||
'tag': 'audio',
|
||||
'file_key': message_content['file_key'],
|
||||
'duration': message_content.get('duration', 0),
|
||||
}
|
||||
]
|
||||
|
||||
for ele in message_content['content']:
|
||||
@@ -367,12 +370,9 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
audio_bytes = response.file.read()
|
||||
audio_base64 = base64.b64encode(audio_bytes).decode()
|
||||
|
||||
|
||||
# Get content type from response headers
|
||||
content_type = response.raw.headers.get('content-type', 'audio/mpeg')
|
||||
|
||||
|
||||
|
||||
mime_main = content_type.split(';')[0].strip()
|
||||
ext = mimetypes.guess_extension(mime_main) or '.bin'
|
||||
temp_dir = tempfile.gettempdir()
|
||||
@@ -418,7 +418,6 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
file_bytes = response.file.read()
|
||||
file_base64 = base64.b64encode(file_bytes).decode()
|
||||
|
||||
|
||||
file_format = response.raw.headers['content-type']
|
||||
|
||||
file_size = len(file_bytes)
|
||||
@@ -453,7 +452,6 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
return platform_message.MessageChain(lb_msg_list)
|
||||
|
||||
|
||||
|
||||
@@ -76,6 +76,7 @@ class OfficialAccountAdapter(abstract_platform_adapter.AbstractMessagePlatformAd
|
||||
AppID=config['AppID'],
|
||||
logger=logger,
|
||||
unified_mode=True,
|
||||
api_base_url=config.get('api_base_url', 'https://api.weixin.qq.com'),
|
||||
)
|
||||
elif config['Mode'] == 'passive':
|
||||
bot = OAClientForLongerResponse(
|
||||
@@ -86,6 +87,7 @@ class OfficialAccountAdapter(abstract_platform_adapter.AbstractMessagePlatformAd
|
||||
LoadingMessage=config.get('LoadingMessage', ''),
|
||||
logger=logger,
|
||||
unified_mode=True,
|
||||
api_base_url=config.get('api_base_url', 'https://api.weixin.qq.com'),
|
||||
)
|
||||
else:
|
||||
raise KeyError('请设置微信公众号通信模式')
|
||||
|
||||
@@ -53,6 +53,16 @@ spec:
|
||||
type: string
|
||||
required: true
|
||||
default: "AI正在思考中,请发送任意内容获取回复。"
|
||||
- name: api_base_url
|
||||
label:
|
||||
en_US: API Base URL
|
||||
zh_Hans: API 基础 URL
|
||||
description:
|
||||
en_US: API Base URL, used for accessing the Official Account API. If you are deploying in an internal network environment and accessing the Official Account API through a reverse proxy, please fill in this item according to the documentation.
|
||||
zh_Hans: 可选,若您部署在内网环境并通过反向代理访问微信公众号 API,可根据文档修改此项
|
||||
type: string
|
||||
required: false
|
||||
default: "https://api.weixin.qq.com"
|
||||
execution:
|
||||
python:
|
||||
path: ./officialaccount.py
|
||||
|
||||
@@ -85,6 +85,26 @@ class TelegramMessageConverter(abstract_platform_adapter.AbstractMessageConverte
|
||||
)
|
||||
)
|
||||
|
||||
if message.voice:
|
||||
if message.caption:
|
||||
message_components.extend(parse_message_text(message.caption))
|
||||
|
||||
file = await message.voice.get_file()
|
||||
|
||||
file_bytes = None
|
||||
file_format = message.voice.mime_type or 'audio/ogg'
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(file.file_path) as response:
|
||||
file_bytes = await response.read()
|
||||
|
||||
message_components.append(
|
||||
platform_message.Voice(
|
||||
base64=f'data:{file_format};base64,{base64.b64encode(file_bytes).decode("utf-8")}',
|
||||
length=message.voice.duration,
|
||||
)
|
||||
)
|
||||
|
||||
return platform_message.MessageChain(message_components)
|
||||
|
||||
|
||||
@@ -159,7 +179,9 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
|
||||
application = ApplicationBuilder().token(config['token']).build()
|
||||
bot = application.bot
|
||||
application.add_handler(MessageHandler(filters.TEXT | (filters.COMMAND) | filters.PHOTO, telegram_callback))
|
||||
application.add_handler(
|
||||
MessageHandler(filters.TEXT | (filters.COMMAND) | filters.PHOTO | filters.VOICE, telegram_callback)
|
||||
)
|
||||
super().__init__(
|
||||
config=config,
|
||||
logger=logger,
|
||||
@@ -197,6 +219,10 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
}
|
||||
if self.config['markdown_card'] is True:
|
||||
args['parse_mode'] = 'MarkdownV2'
|
||||
|
||||
if message_source.source_platform_object.message.message_thread_id:
|
||||
args['message_thread_id'] = message_source.source_platform_object.message.message_thread_id
|
||||
|
||||
if quote_origin:
|
||||
args['reply_to_message_id'] = message_source.source_platform_object.message.id
|
||||
|
||||
@@ -216,8 +242,6 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
components = await TelegramMessageConverter.yiri2target(message, self.bot)
|
||||
args = {}
|
||||
message_id = message_source.source_platform_object.message.id
|
||||
if quote_origin:
|
||||
args['reply_to_message_id'] = message_source.source_platform_object.message.id
|
||||
|
||||
component = components[0]
|
||||
if message_id not in self.msg_stream_id: # 当消息回复第一次时,发送新消息
|
||||
@@ -233,6 +257,12 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
'chat_id': message_source.source_platform_object.effective_chat.id,
|
||||
'text': content,
|
||||
}
|
||||
if message_source.source_platform_object.message.message_thread_id:
|
||||
args['message_thread_id'] = message_source.source_platform_object.message.message_thread_id
|
||||
|
||||
if quote_origin:
|
||||
args['reply_to_message_id'] = message_source.source_platform_object.message.id
|
||||
|
||||
if self.config['markdown_card'] is True:
|
||||
args['parse_mode'] = 'MarkdownV2'
|
||||
|
||||
@@ -260,6 +290,24 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
# self.seq = 1 # 消息回复结束之后重置seq
|
||||
self.msg_stream_id.pop(message_id) # 消息回复结束之后删除流式消息id
|
||||
|
||||
def get_launcher_id(self, event: platform_events.MessageEvent) -> str | None:
|
||||
if not isinstance(event.source_platform_object, Update):
|
||||
return None
|
||||
|
||||
message = event.source_platform_object.message
|
||||
if not message:
|
||||
return None
|
||||
|
||||
# specifically handle telegram forum topic and private thread(not supported by official client yet but supported by bot api)
|
||||
if message.message_thread_id:
|
||||
# check if it is a group
|
||||
if isinstance(event, platform_events.GroupMessage):
|
||||
return f'{event.group.id}#{message.message_thread_id}'
|
||||
elif isinstance(event, platform_events.FriendMessage):
|
||||
return f'{event.sender.id}#{message.message_thread_id}'
|
||||
|
||||
return None
|
||||
|
||||
async def is_stream_output_supported(self) -> bool:
|
||||
is_stream = False
|
||||
if self.config.get('enable-stream-reply', None):
|
||||
|
||||
@@ -15,6 +15,58 @@ import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||
|
||||
|
||||
def split_string_by_bytes(text, limit=2048, encoding='utf-8'):
|
||||
"""
|
||||
Splits a string into a list of strings, where each part is at most 'limit' bytes.
|
||||
|
||||
Args:
|
||||
text (str): The original string to split.
|
||||
limit (int): The maximum byte size for each split part.
|
||||
encoding (str): The encoding to use (default is 'utf-8').
|
||||
|
||||
Returns:
|
||||
list: A list of split strings.
|
||||
"""
|
||||
# 1. Encode the entire string into bytes
|
||||
bytes_data = text.encode(encoding)
|
||||
total_len = len(bytes_data)
|
||||
|
||||
parts = []
|
||||
start = 0
|
||||
|
||||
while start < total_len:
|
||||
# 2. Determine the end index for the current chunk
|
||||
# It shouldn't exceed the total length
|
||||
end = min(start + limit, total_len)
|
||||
|
||||
# 3. Slice the byte array
|
||||
chunk = bytes_data[start:end]
|
||||
|
||||
# 4. Attempt to decode the chunk
|
||||
# Use errors='ignore' to drop any partial bytes at the end of the chunk
|
||||
# (e.g., if a 3-byte character was cut after the 2nd byte)
|
||||
part_str = chunk.decode(encoding, errors='ignore')
|
||||
|
||||
# 5. Calculate the actual byte length of the successfully decoded string
|
||||
# This tells us exactly where the valid character boundary ended
|
||||
part_bytes = part_str.encode(encoding)
|
||||
part_len = len(part_bytes)
|
||||
|
||||
# Safety check: Prevent infinite loop if limit is too small (e.g., limit=1 for a Chinese char)
|
||||
if part_len == 0 and end < total_len:
|
||||
# Force advance by 1 byte to consume the un-decodable byte or raise error
|
||||
# Here we just treat it as a part to avoid stuck loops, though it might be invalid
|
||||
start += 1
|
||||
continue
|
||||
|
||||
parts.append(part_str)
|
||||
|
||||
# 6. Move the start pointer by the actual length consumed
|
||||
start += part_len
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
class WecomMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
@staticmethod
|
||||
async def yiri2target(message_chain: platform_message.MessageChain, bot: WecomClient):
|
||||
@@ -22,11 +74,15 @@ class WecomMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
|
||||
for msg in message_chain:
|
||||
if type(msg) is platform_message.Plain:
|
||||
content_list.append(
|
||||
{
|
||||
'type': 'text',
|
||||
'content': msg.text,
|
||||
}
|
||||
chunks = split_string_by_bytes(msg.text)
|
||||
content_list.extend(
|
||||
[
|
||||
{
|
||||
'type': 'text',
|
||||
'content': chunk,
|
||||
}
|
||||
for chunk in chunks
|
||||
]
|
||||
)
|
||||
elif type(msg) is platform_message.Image:
|
||||
content_list.append(
|
||||
@@ -170,6 +226,7 @@ class WecomAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
contacts_secret=config['contacts_secret'],
|
||||
logger=logger,
|
||||
unified_mode=True,
|
||||
api_base_url=config.get('api_base_url', 'https://qyapi.weixin.qq.com/cgi-bin'),
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
|
||||
@@ -46,6 +46,16 @@ spec:
|
||||
type: string
|
||||
required: true
|
||||
default: ""
|
||||
- name: api_base_url
|
||||
label:
|
||||
en_US: API Base URL
|
||||
zh_Hans: API 基础 URL
|
||||
description:
|
||||
en_US: API Base URL, used for accessing the WeCom API. If you are deploying in an internal network environment and accessing the WeCom Customer Service API through a reverse proxy, please fill in this item according to the documentation.
|
||||
zh_Hans: 可选,若您部署在内网环境并通过反向代理访问企业微信 API,可根据文档填写此项
|
||||
type: string
|
||||
required: false
|
||||
default: "https://qyapi.weixin.qq.com/cgi-bin"
|
||||
execution:
|
||||
python:
|
||||
path: ./wecom.py
|
||||
|
||||
@@ -141,6 +141,7 @@ class WecomCSAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
EncodingAESKey=config['EncodingAESKey'],
|
||||
logger=logger,
|
||||
unified_mode=True,
|
||||
api_base_url=config.get('api_base_url', 'https://qyapi.weixin.qq.com/cgi-bin'),
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
|
||||
@@ -39,6 +39,16 @@ spec:
|
||||
type: string
|
||||
required: true
|
||||
default: ""
|
||||
- name: api_base_url
|
||||
label:
|
||||
en_US: API Base URL
|
||||
zh_Hans: API 基础 URL
|
||||
description:
|
||||
en_US: API Base URL, used for accessing the WeCom API. If you are deploying in an internal network environment and accessing the WeCom Customer Service API through a reverse proxy, please fill in this item according to the documentation.
|
||||
zh_Hans: 可选,若您部署在内网环境并通过反向代理访问企业微信 API,可根据文档修改此项
|
||||
type: string
|
||||
required: false
|
||||
default: "https://qyapi.weixin.qq.com/cgi-bin"
|
||||
execution:
|
||||
python:
|
||||
path: ./wecomcs.py
|
||||
|
||||
@@ -56,7 +56,7 @@ class WebhookPusher:
|
||||
# Check if any webhook responded with skip_pipeline=true
|
||||
for result in results:
|
||||
if isinstance(result, dict) and result.get('skip_pipeline') is True:
|
||||
self.logger.info(f'Webhook responded with skip_pipeline=true, skipping pipeline for person message')
|
||||
self.logger.info('Webhook responded with skip_pipeline=true, skipping pipeline for person message')
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -103,7 +103,7 @@ class WebhookPusher:
|
||||
# Check if any webhook responded with skip_pipeline=true
|
||||
for result in results:
|
||||
if isinstance(result, dict) and result.get('skip_pipeline') is True:
|
||||
self.logger.info(f'Webhook responded with skip_pipeline=true, skipping pipeline for group message')
|
||||
self.logger.info('Webhook responded with skip_pipeline=true, skipping pipeline for group message')
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@@ -324,7 +324,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
messages_obj = [provider_message.Message.model_validate(message) for message in messages]
|
||||
funcs_obj = [resource_tool.LLMTool.model_validate(func) for func in funcs]
|
||||
|
||||
result = await llm_model.requester.invoke_llm(
|
||||
result = await llm_model.provider.invoke_llm(
|
||||
query=None,
|
||||
model=llm_model,
|
||||
messages=messages_obj,
|
||||
|
||||
@@ -9,22 +9,24 @@ from ...discover import engine
|
||||
from . import token
|
||||
from ...entity.persistence import model as persistence_model
|
||||
from ...entity.errors import provider as provider_errors
|
||||
|
||||
FETCH_MODEL_LIST_URL = 'https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list'
|
||||
from async_lru import alru_cache
|
||||
|
||||
|
||||
class ModelManager:
|
||||
"""模型管理器"""
|
||||
"""Model manager"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
provider_dict: dict[str, requester.RuntimeProvider]
|
||||
"""运行时模型提供商字典, uuid -> RuntimeProvider"""
|
||||
|
||||
llm_models: list[requester.RuntimeLLMModel]
|
||||
|
||||
embedding_models: list[requester.RuntimeEmbeddingModel]
|
||||
|
||||
requester_components: list[engine.Component]
|
||||
|
||||
requester_dict: dict[str, type[requester.ProviderAPIRequester]] # cache
|
||||
requester_dict: dict[str, type[requester.ProviderAPIRequester]]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
@@ -36,7 +38,6 @@ class ModelManager:
|
||||
async def initialize(self):
|
||||
self.requester_components = self.ap.discover.get_components_by_kind('LLMAPIRequester')
|
||||
|
||||
# forge requester class dict
|
||||
requester_dict: dict[str, type[requester.ProviderAPIRequester]] = {}
|
||||
for component in self.requester_components:
|
||||
requester_dict[component.metadata.name] = component.get_python_component_class()
|
||||
@@ -45,139 +46,343 @@ class ModelManager:
|
||||
|
||||
await self.load_models_from_db()
|
||||
|
||||
# Check if space models service is disabled
|
||||
space_config = self.ap.instance_config.data.get('space', {})
|
||||
if space_config.get('disable_models_service', False):
|
||||
self.ap.logger.info('LangBot Space Models service is disabled, skipping sync.')
|
||||
return
|
||||
|
||||
try:
|
||||
await self.sync_new_models_from_space()
|
||||
except Exception as e:
|
||||
self.ap.logger.warning('Failed to sync new models from LangBot Space, model list may not be updated.')
|
||||
self.ap.logger.warning(f' - Error: {e}')
|
||||
|
||||
async def load_models_from_db(self):
|
||||
"""从数据库加载模型"""
|
||||
"""Load models from database"""
|
||||
self.ap.logger.info('Loading models from db...')
|
||||
|
||||
self.llm_models = []
|
||||
self.embedding_models = []
|
||||
|
||||
# llm models
|
||||
# Load all providers first
|
||||
self.provider_dict = {}
|
||||
providers_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.ModelProvider)
|
||||
)
|
||||
for provider in providers_result.all():
|
||||
try:
|
||||
runtime_provider = await self.load_provider(provider)
|
||||
self.provider_dict[provider.uuid] = runtime_provider
|
||||
except provider_errors.RequesterNotFoundError as e:
|
||||
self.ap.logger.warning(f'Requester {e.requester_name} not found, skipping provider {provider.uuid}')
|
||||
continue
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Failed to load provider {provider.uuid}: {e}\n{traceback.format_exc()}')
|
||||
|
||||
# Load LLM models
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel))
|
||||
llm_models = result.all()
|
||||
for llm_model in llm_models:
|
||||
try:
|
||||
await self.load_llm_model(llm_model)
|
||||
except provider_errors.RequesterNotFoundError as e:
|
||||
self.ap.logger.warning(f'Requester {e.requester_name} not found, skipping llm model {llm_model.uuid}')
|
||||
provider = self.provider_dict.get(llm_model.provider_uuid)
|
||||
if provider is None:
|
||||
self.ap.logger.warning(f'Provider {llm_model.provider_uuid} not found for model {llm_model.uuid}')
|
||||
continue
|
||||
runtime_llm_model = await self.load_llm_model_with_provider(llm_model, provider)
|
||||
self.llm_models.append(runtime_llm_model)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Failed to load model {llm_model.uuid}: {e}\n{traceback.format_exc()}')
|
||||
|
||||
# embedding models
|
||||
# Load embedding models
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.EmbeddingModel))
|
||||
embedding_models = result.all()
|
||||
for embedding_model in embedding_models:
|
||||
try:
|
||||
await self.load_embedding_model(embedding_model)
|
||||
except provider_errors.RequesterNotFoundError as e:
|
||||
self.ap.logger.warning(
|
||||
f'Requester {e.requester_name} not found, skipping embedding model {embedding_model.uuid}'
|
||||
)
|
||||
provider = self.provider_dict.get(embedding_model.provider_uuid)
|
||||
if provider is None:
|
||||
self.ap.logger.warning(
|
||||
f'Provider {embedding_model.provider_uuid} not found for model {embedding_model.uuid}'
|
||||
)
|
||||
continue
|
||||
runtime_embedding_model = await self.load_embedding_model_with_provider(embedding_model, provider)
|
||||
self.embedding_models.append(runtime_embedding_model)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Failed to load model {embedding_model.uuid}: {e}\n{traceback.format_exc()}')
|
||||
|
||||
async def init_runtime_llm_model(
|
||||
async def sync_new_models_from_space(self):
|
||||
"""Sync models from Space"""
|
||||
space_model_provider = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.ModelProvider).where(
|
||||
persistence_model.ModelProvider.requester == 'space-chat-completions'
|
||||
)
|
||||
)
|
||||
result = space_model_provider.first()
|
||||
if result is None:
|
||||
raise provider_errors.ProviderNotFoundError('LangBot Models')
|
||||
|
||||
space_model_provider = result
|
||||
|
||||
# get the latest models from space
|
||||
space_models = await self.ap.space_service.get_models()
|
||||
|
||||
exists_llm_models_uuids = [m['uuid'] for m in await self.ap.llm_model_service.get_llm_models()]
|
||||
exists_embedding_models_uuids = [
|
||||
m['uuid'] for m in await self.ap.embedding_models_service.get_embedding_models()
|
||||
]
|
||||
|
||||
for space_model in space_models:
|
||||
if space_model.category == 'chat':
|
||||
uuid = space_model.uuid
|
||||
|
||||
if uuid in exists_llm_models_uuids:
|
||||
continue
|
||||
|
||||
# model will be automatically loaded
|
||||
await self.ap.llm_model_service.create_llm_model(
|
||||
{
|
||||
'uuid': space_model.uuid,
|
||||
'name': space_model.model_id,
|
||||
'provider_uuid': space_model_provider.uuid,
|
||||
'abilities': space_model.llm_abilities or [],
|
||||
'extra_args': {},
|
||||
'prefered_ranking': space_model.featured_order,
|
||||
},
|
||||
preserve_uuid=True,
|
||||
auto_set_to_default_pipeline=False,
|
||||
)
|
||||
|
||||
elif space_model.category == 'embedding':
|
||||
uuid = space_model.uuid
|
||||
|
||||
if uuid in exists_embedding_models_uuids:
|
||||
continue
|
||||
|
||||
# model will be automatically loaded
|
||||
await self.ap.embedding_models_service.create_embedding_model(
|
||||
{
|
||||
'uuid': space_model.uuid,
|
||||
'name': space_model.model_id,
|
||||
'provider_uuid': space_model_provider.uuid,
|
||||
'extra_args': {},
|
||||
'prefered_ranking': space_model.featured_order,
|
||||
},
|
||||
preserve_uuid=True,
|
||||
)
|
||||
|
||||
async def init_temporary_runtime_llm_model(
|
||||
self,
|
||||
model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict,
|
||||
):
|
||||
"""初始化运行时 LLM 模型"""
|
||||
if isinstance(model_info, sqlalchemy.Row):
|
||||
model_info = persistence_model.LLMModel(**model_info._mapping)
|
||||
elif isinstance(model_info, dict):
|
||||
model_info = persistence_model.LLMModel(**model_info)
|
||||
model_info: dict,
|
||||
) -> requester.RuntimeLLMModel:
|
||||
"""Initialize runtime LLM model from dict (for testing)"""
|
||||
provider_info = model_info.get('provider', {})
|
||||
|
||||
if model_info.requester not in self.requester_dict:
|
||||
raise provider_errors.RequesterNotFoundError(model_info.requester)
|
||||
|
||||
requester_inst = self.requester_dict[model_info.requester](ap=self.ap, config=model_info.requester_config)
|
||||
|
||||
await requester_inst.initialize()
|
||||
runtime_provider = await self.load_provider(provider_info)
|
||||
|
||||
runtime_llm_model = requester.RuntimeLLMModel(
|
||||
model_entity=model_info,
|
||||
token_mgr=token.TokenManager(
|
||||
name=model_info.uuid,
|
||||
tokens=model_info.api_keys,
|
||||
model_entity=persistence_model.LLMModel(
|
||||
uuid=model_info.get('uuid', ''),
|
||||
name=model_info.get('name', ''),
|
||||
provider_uuid='',
|
||||
abilities=model_info.get('abilities', []),
|
||||
extra_args=model_info.get('extra_args', {}),
|
||||
),
|
||||
requester=requester_inst,
|
||||
provider=runtime_provider,
|
||||
)
|
||||
|
||||
return runtime_llm_model
|
||||
|
||||
async def init_runtime_embedding_model(
|
||||
async def init_temporary_runtime_embedding_model(
|
||||
self,
|
||||
model_info: persistence_model.EmbeddingModel | sqlalchemy.Row[persistence_model.EmbeddingModel] | dict,
|
||||
):
|
||||
"""初始化运行时 Embedding 模型"""
|
||||
if isinstance(model_info, sqlalchemy.Row):
|
||||
model_info = persistence_model.EmbeddingModel(**model_info._mapping)
|
||||
elif isinstance(model_info, dict):
|
||||
model_info = persistence_model.EmbeddingModel(**model_info)
|
||||
|
||||
if model_info.requester not in self.requester_dict:
|
||||
raise provider_errors.RequesterNotFoundError(model_info.requester)
|
||||
|
||||
requester_inst = self.requester_dict[model_info.requester](ap=self.ap, config=model_info.requester_config)
|
||||
|
||||
await requester_inst.initialize()
|
||||
model_info: dict,
|
||||
) -> requester.RuntimeEmbeddingModel:
|
||||
"""Initialize runtime embedding model from dict (for testing)"""
|
||||
provider_info = model_info.get('provider', {})
|
||||
runtime_provider = await self.load_provider(provider_info)
|
||||
|
||||
runtime_embedding_model = requester.RuntimeEmbeddingModel(
|
||||
model_entity=model_info,
|
||||
token_mgr=token.TokenManager(
|
||||
name=model_info.uuid,
|
||||
tokens=model_info.api_keys,
|
||||
model_entity=persistence_model.EmbeddingModel(
|
||||
uuid=model_info.get('uuid', ''),
|
||||
name=model_info.get('name', ''),
|
||||
provider_uuid='',
|
||||
extra_args=model_info.get('extra_args', {}),
|
||||
),
|
||||
requester=requester_inst,
|
||||
provider=runtime_provider,
|
||||
)
|
||||
|
||||
return runtime_embedding_model
|
||||
|
||||
async def load_llm_model(
|
||||
self,
|
||||
model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict,
|
||||
):
|
||||
"""加载 LLM 模型"""
|
||||
runtime_llm_model = await self.init_runtime_llm_model(model_info)
|
||||
self.llm_models.append(runtime_llm_model)
|
||||
async def load_provider(
|
||||
self, provider_info: persistence_model.ModelProvider | sqlalchemy.Row | dict
|
||||
) -> requester.RuntimeProvider:
|
||||
"""Load provider from dict"""
|
||||
if isinstance(provider_info, sqlalchemy.Row):
|
||||
provider_entity = persistence_model.ModelProvider(**provider_info._mapping)
|
||||
elif isinstance(provider_info, dict):
|
||||
provider_entity = persistence_model.ModelProvider(**provider_info)
|
||||
else:
|
||||
provider_entity = provider_info
|
||||
|
||||
async def load_embedding_model(
|
||||
self,
|
||||
model_info: persistence_model.EmbeddingModel | sqlalchemy.Row[persistence_model.EmbeddingModel] | dict,
|
||||
):
|
||||
"""加载 Embedding 模型"""
|
||||
runtime_embedding_model = await self.init_runtime_embedding_model(model_info)
|
||||
self.embedding_models.append(runtime_embedding_model)
|
||||
if provider_entity.requester not in self.requester_dict:
|
||||
raise provider_errors.RequesterNotFoundError(provider_entity.requester)
|
||||
|
||||
requester_inst = self.requester_dict[provider_entity.requester](
|
||||
ap=self.ap, config={'base_url': provider_entity.base_url}
|
||||
)
|
||||
await requester_inst.initialize()
|
||||
|
||||
token_mgr = token.TokenManager(name=provider_entity.uuid, tokens=provider_entity.api_keys or [])
|
||||
|
||||
provider = requester.RuntimeProvider(
|
||||
provider_entity=provider_entity,
|
||||
token_mgr=token_mgr,
|
||||
requester=requester_inst,
|
||||
)
|
||||
return provider
|
||||
|
||||
async def remove_provider(self, provider_uuid: str):
|
||||
"""Remove provider
|
||||
|
||||
This method will not consider the models using this provider,
|
||||
because the models should be removed by the caller.
|
||||
"""
|
||||
del self.provider_dict[provider_uuid]
|
||||
|
||||
async def reload_provider(self, provider_uuid: str):
|
||||
"""Reload provider"""
|
||||
provider_entity = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.ModelProvider).where(
|
||||
persistence_model.ModelProvider.uuid == provider_uuid
|
||||
)
|
||||
)
|
||||
provider_entity = provider_entity.first()
|
||||
if provider_entity is None:
|
||||
raise provider_errors.ProviderNotFoundError(provider_uuid)
|
||||
|
||||
new_runtime_provider = await self.load_provider(provider_entity)
|
||||
|
||||
# update refs in runtime models
|
||||
for model in self.llm_models:
|
||||
if model.provider.provider_entity.uuid == provider_uuid:
|
||||
model.provider = new_runtime_provider
|
||||
for model in self.embedding_models:
|
||||
if model.provider.provider_entity.uuid == provider_uuid:
|
||||
model.provider = new_runtime_provider
|
||||
|
||||
# update ref in provider dict
|
||||
self.provider_dict[provider_uuid] = new_runtime_provider
|
||||
|
||||
async def load_llm_model_with_provider(
|
||||
self,
|
||||
model_info: persistence_model.LLMModel | sqlalchemy.Row,
|
||||
provider: requester.RuntimeProvider,
|
||||
) -> requester.RuntimeLLMModel:
|
||||
"""Load LLM model with provider info"""
|
||||
if isinstance(model_info, sqlalchemy.Row):
|
||||
model_info = persistence_model.LLMModel(**model_info._mapping)
|
||||
|
||||
runtime_llm_model = requester.RuntimeLLMModel(
|
||||
model_entity=model_info,
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
return runtime_llm_model
|
||||
|
||||
async def load_embedding_model_with_provider(
|
||||
self,
|
||||
model_info: persistence_model.EmbeddingModel | sqlalchemy.Row,
|
||||
provider: requester.RuntimeProvider,
|
||||
) -> requester.RuntimeEmbeddingModel:
|
||||
"""Load embedding model with provider info"""
|
||||
if isinstance(model_info, sqlalchemy.Row):
|
||||
model_info = persistence_model.EmbeddingModel(**model_info._mapping)
|
||||
|
||||
runtime_embedding_model = requester.RuntimeEmbeddingModel(
|
||||
model_entity=model_info,
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
return runtime_embedding_model
|
||||
|
||||
async def load_llm_model(self, model_info: dict):
|
||||
"""Load LLM model from dict (with provider info)"""
|
||||
provider_info = model_info.get('provider', {})
|
||||
if not provider_info:
|
||||
raise ValueError('Provider info is required')
|
||||
|
||||
model_entity = persistence_model.LLMModel(
|
||||
uuid=model_info.get('uuid', ''),
|
||||
name=model_info.get('name', ''),
|
||||
provider_uuid=model_info.get('provider_uuid', ''),
|
||||
abilities=model_info.get('abilities', []),
|
||||
extra_args=model_info.get('extra_args', {}),
|
||||
)
|
||||
|
||||
provider_entity = persistence_model.ModelProvider(
|
||||
uuid=provider_info.get('uuid', ''),
|
||||
name=provider_info.get('name', ''),
|
||||
requester=provider_info.get('requester', ''),
|
||||
base_url=provider_info.get('base_url', ''),
|
||||
api_keys=provider_info.get('api_keys', []),
|
||||
)
|
||||
|
||||
await self.load_llm_model_with_provider(model_entity, provider_entity)
|
||||
|
||||
async def load_embedding_model(self, model_info: dict):
|
||||
"""Load embedding model from dict (with provider info)"""
|
||||
provider_info = model_info.get('provider', {})
|
||||
if not provider_info:
|
||||
raise ValueError('Provider info is required')
|
||||
|
||||
model_entity = persistence_model.EmbeddingModel(
|
||||
uuid=model_info.get('uuid', ''),
|
||||
name=model_info.get('name', ''),
|
||||
provider_uuid=model_info.get('provider_uuid', ''),
|
||||
extra_args=model_info.get('extra_args', {}),
|
||||
)
|
||||
|
||||
provider_entity = persistence_model.ModelProvider(
|
||||
uuid=provider_info.get('uuid', ''),
|
||||
name=provider_info.get('name', ''),
|
||||
requester=provider_info.get('requester', ''),
|
||||
base_url=provider_info.get('base_url', ''),
|
||||
api_keys=provider_info.get('api_keys', []),
|
||||
)
|
||||
|
||||
await self.load_embedding_model_with_provider(model_entity, provider_entity)
|
||||
|
||||
@alru_cache(ttl=60 * 5)
|
||||
async def get_model_by_uuid(self, uuid: str) -> requester.RuntimeLLMModel:
|
||||
"""通过uuid获取 LLM 模型"""
|
||||
"""Get LLM model by uuid"""
|
||||
for model in self.llm_models:
|
||||
if model.model_entity.uuid == uuid:
|
||||
return model
|
||||
raise ValueError(f'LLM model {uuid} not found')
|
||||
|
||||
@alru_cache(ttl=60 * 5)
|
||||
async def get_embedding_model_by_uuid(self, uuid: str) -> requester.RuntimeEmbeddingModel:
|
||||
"""通过uuid获取 Embedding 模型"""
|
||||
"""Get embedding model by uuid"""
|
||||
for model in self.embedding_models:
|
||||
if model.model_entity.uuid == uuid:
|
||||
return model
|
||||
raise ValueError(f'Embedding model {uuid} not found')
|
||||
|
||||
async def remove_llm_model(self, model_uuid: str):
|
||||
"""移除 LLM 模型"""
|
||||
"""Remove LLM model"""
|
||||
for model in self.llm_models:
|
||||
if model.model_entity.uuid == model_uuid:
|
||||
self.llm_models.remove(model)
|
||||
return
|
||||
|
||||
async def remove_embedding_model(self, model_uuid: str):
|
||||
"""移除 Embedding 模型"""
|
||||
"""Remove embedding model"""
|
||||
for model in self.embedding_models:
|
||||
if model.model_entity.uuid == model_uuid:
|
||||
self.embedding_models.remove(model)
|
||||
return
|
||||
|
||||
def get_available_requesters_info(self, model_type: str) -> list[dict]:
|
||||
"""获取所有可用的请求器"""
|
||||
"""Get all available requesters"""
|
||||
if model_type != '':
|
||||
return [
|
||||
component.to_plain_dict()
|
||||
@@ -188,14 +393,14 @@ class ModelManager:
|
||||
return [component.to_plain_dict() for component in self.requester_components]
|
||||
|
||||
def get_available_requester_info_by_name(self, name: str) -> dict | None:
|
||||
"""通过名称获取请求器信息"""
|
||||
"""Get requester info by name"""
|
||||
for component in self.requester_components:
|
||||
if component.metadata.name == name:
|
||||
return component.to_plain_dict()
|
||||
return None
|
||||
|
||||
def get_available_requester_manifest_by_name(self, name: str) -> engine.Component | None:
|
||||
"""通过名称获取请求器清单"""
|
||||
"""Get requester manifest by name"""
|
||||
for component in self.requester_components:
|
||||
if component.metadata.name == name:
|
||||
return component
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import typing
|
||||
import time
|
||||
|
||||
from ...core import app
|
||||
from ...entity.persistence import model as persistence_model
|
||||
@@ -11,11 +12,11 @@ import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
|
||||
|
||||
class RuntimeLLMModel:
|
||||
"""运行时模型"""
|
||||
class RuntimeProvider:
|
||||
"""运行时模型提供商"""
|
||||
|
||||
model_entity: persistence_model.LLMModel
|
||||
"""模型数据"""
|
||||
provider_entity: persistence_model.ModelProvider
|
||||
"""提供商数据"""
|
||||
|
||||
token_mgr: token.TokenManager
|
||||
"""api key管理器"""
|
||||
@@ -25,14 +26,245 @@ class RuntimeLLMModel:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_entity: persistence_model.LLMModel,
|
||||
provider_entity: persistence_model.ModelProvider,
|
||||
token_mgr: token.TokenManager,
|
||||
requester: ProviderAPIRequester,
|
||||
):
|
||||
self.model_entity = model_entity
|
||||
self.provider_entity = provider_entity
|
||||
self.token_mgr = token_mgr
|
||||
self.requester = requester
|
||||
|
||||
async def invoke_llm(
|
||||
self,
|
||||
query: pipeline_query.Query,
|
||||
model: RuntimeLLMModel,
|
||||
messages: typing.List[provider_message.Message],
|
||||
funcs: typing.List[resource_tool.LLMTool] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> provider_message.Message:
|
||||
"""Bridge method for invoking LLM with monitoring"""
|
||||
# Start timing for monitoring
|
||||
start_time = time.time()
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
status = 'success'
|
||||
error_message = None
|
||||
|
||||
try:
|
||||
# Call the underlying requester
|
||||
result = await self.requester.invoke_llm(
|
||||
query=query,
|
||||
model=model,
|
||||
messages=messages,
|
||||
funcs=funcs,
|
||||
extra_args=extra_args,
|
||||
remove_think=remove_think,
|
||||
)
|
||||
|
||||
# Try to extract token usage if the requester returns it
|
||||
# For requesters that return tuple (message, usage_info)
|
||||
if isinstance(result, tuple):
|
||||
msg, usage_info = result
|
||||
if usage_info:
|
||||
input_tokens = usage_info.get('input_tokens', 0)
|
||||
output_tokens = usage_info.get('output_tokens', 0)
|
||||
return msg
|
||||
else:
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
status = 'error'
|
||||
error_message = str(e)
|
||||
raise
|
||||
finally:
|
||||
# Record LLM call monitoring data (only if query is provided)
|
||||
if query is not None:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# Import monitoring helper
|
||||
try:
|
||||
from ...pipeline import monitoring_helper
|
||||
|
||||
# Get monitoring metadata from query variables
|
||||
if query.variables:
|
||||
bot_name = query.variables.get('_monitoring_bot_name', 'Unknown')
|
||||
pipeline_name = query.variables.get('_monitoring_pipeline_name', 'Unknown')
|
||||
message_id = query.variables.get('_monitoring_message_id')
|
||||
else:
|
||||
bot_name = 'Unknown'
|
||||
pipeline_name = 'Unknown'
|
||||
message_id = None
|
||||
|
||||
await monitoring_helper.MonitoringHelper.record_llm_call(
|
||||
ap=self.requester.ap,
|
||||
query=query,
|
||||
bot_id=query.bot_uuid or 'unknown',
|
||||
bot_name=bot_name,
|
||||
pipeline_id=query.pipeline_uuid or 'unknown',
|
||||
pipeline_name=pipeline_name,
|
||||
model_name=model.model_entity.name,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
duration_ms=duration_ms,
|
||||
status=status,
|
||||
error_message=error_message,
|
||||
message_id=message_id,
|
||||
)
|
||||
except Exception as monitor_err:
|
||||
self.requester.ap.logger.error(f'[Monitoring] Failed to record LLM call: {monitor_err}')
|
||||
|
||||
async def invoke_llm_stream(
|
||||
self,
|
||||
query: pipeline_query.Query,
|
||||
model: RuntimeLLMModel,
|
||||
messages: typing.List[provider_message.Message],
|
||||
funcs: typing.List[resource_tool.LLMTool] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> provider_message.MessageChunk:
|
||||
"""Bridge method for invoking LLM stream with monitoring"""
|
||||
# Start timing for monitoring
|
||||
start_time = time.time()
|
||||
status = 'success'
|
||||
error_message = None
|
||||
# Note: Stream doesn't easily provide token counts, set to 0
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
|
||||
try:
|
||||
# Stream the response
|
||||
async for chunk in self.requester.invoke_llm_stream(
|
||||
query=query,
|
||||
model=model,
|
||||
messages=messages,
|
||||
funcs=funcs,
|
||||
extra_args=extra_args,
|
||||
remove_think=remove_think,
|
||||
):
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
status = 'error'
|
||||
error_message = str(e)
|
||||
raise
|
||||
finally:
|
||||
# Record LLM call monitoring data (only if query is provided)
|
||||
if query is not None:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# Import monitoring helper
|
||||
try:
|
||||
from ...pipeline import monitoring_helper
|
||||
|
||||
# Get monitoring metadata from query variables
|
||||
if query.variables:
|
||||
bot_name = query.variables.get('_monitoring_bot_name', 'Unknown')
|
||||
pipeline_name = query.variables.get('_monitoring_pipeline_name', 'Unknown')
|
||||
message_id = query.variables.get('_monitoring_message_id')
|
||||
else:
|
||||
bot_name = 'Unknown'
|
||||
pipeline_name = 'Unknown'
|
||||
message_id = None
|
||||
|
||||
await monitoring_helper.MonitoringHelper.record_llm_call(
|
||||
ap=self.requester.ap,
|
||||
query=query,
|
||||
bot_id=query.bot_uuid or 'unknown',
|
||||
bot_name=bot_name,
|
||||
pipeline_id=query.pipeline_uuid or 'unknown',
|
||||
pipeline_name=pipeline_name,
|
||||
model_name=model.model_entity.name,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
duration_ms=duration_ms,
|
||||
status=status,
|
||||
error_message=error_message,
|
||||
message_id=message_id,
|
||||
)
|
||||
except Exception as monitor_err:
|
||||
self.requester.ap.logger.error(f'[Monitoring] Failed to record LLM stream call: {monitor_err}')
|
||||
|
||||
async def invoke_embedding(
|
||||
self,
|
||||
model: RuntimeEmbeddingModel,
|
||||
input_text: typing.List[str],
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
knowledge_base_id: str | None = None,
|
||||
query_text: str | None = None,
|
||||
session_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
call_type: str | None = None,
|
||||
) -> typing.List[typing.List[float]]:
|
||||
"""Bridge method for invoking embedding with monitoring"""
|
||||
# Start timing for monitoring
|
||||
start_time = time.time()
|
||||
prompt_tokens = 0
|
||||
total_tokens = 0
|
||||
status = 'success'
|
||||
error_message = None
|
||||
|
||||
try:
|
||||
# Call the underlying requester
|
||||
result = await self.requester.invoke_embedding(
|
||||
model=model,
|
||||
input_text=input_text,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
|
||||
# Handle both old format (list only) and new format (tuple with usage)
|
||||
if isinstance(result, tuple):
|
||||
embeddings, usage_info = result
|
||||
if usage_info:
|
||||
prompt_tokens = usage_info.get('prompt_tokens', 0)
|
||||
total_tokens = usage_info.get('total_tokens', 0)
|
||||
return embeddings
|
||||
else:
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
status = 'error'
|
||||
error_message = str(e)
|
||||
raise
|
||||
finally:
|
||||
# Record embedding call monitoring data
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
try:
|
||||
await self.requester.ap.monitoring_service.record_embedding_call(
|
||||
model_name=model.model_entity.name,
|
||||
prompt_tokens=prompt_tokens,
|
||||
total_tokens=total_tokens,
|
||||
duration=duration_ms,
|
||||
input_count=len(input_text),
|
||||
status=status,
|
||||
error_message=error_message,
|
||||
knowledge_base_id=knowledge_base_id,
|
||||
query_text=query_text,
|
||||
session_id=session_id,
|
||||
message_id=message_id,
|
||||
call_type=call_type,
|
||||
)
|
||||
except Exception as monitor_err:
|
||||
self.requester.ap.logger.error(f'[Monitoring] Failed to record embedding call: {monitor_err}')
|
||||
|
||||
|
||||
class RuntimeLLMModel:
|
||||
"""运行时模型"""
|
||||
|
||||
model_entity: persistence_model.LLMModel
|
||||
"""模型数据"""
|
||||
|
||||
provider: RuntimeProvider
|
||||
"""提供商实例"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_entity: persistence_model.LLMModel,
|
||||
provider: RuntimeProvider,
|
||||
):
|
||||
self.model_entity = model_entity
|
||||
self.provider = provider
|
||||
|
||||
|
||||
class RuntimeEmbeddingModel:
|
||||
"""运行时 Embedding 模型"""
|
||||
@@ -40,21 +272,16 @@ class RuntimeEmbeddingModel:
|
||||
model_entity: persistence_model.EmbeddingModel
|
||||
"""模型数据"""
|
||||
|
||||
token_mgr: token.TokenManager
|
||||
"""api key管理器"""
|
||||
|
||||
requester: ProviderAPIRequester
|
||||
"""请求器实例"""
|
||||
provider: RuntimeProvider
|
||||
"""提供商实例"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_entity: persistence_model.EmbeddingModel,
|
||||
token_mgr: token.TokenManager,
|
||||
requester: ProviderAPIRequester,
|
||||
provider: RuntimeProvider,
|
||||
):
|
||||
self.model_entity = model_entity
|
||||
self.token_mgr = token_mgr
|
||||
self.requester = requester
|
||||
self.provider = provider
|
||||
|
||||
|
||||
class ProviderAPIRequester(metaclass=abc.ABCMeta):
|
||||
@@ -128,7 +355,7 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta):
|
||||
model: RuntimeEmbeddingModel,
|
||||
input_text: typing.List[str],
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
) -> typing.List[typing.List[float]]:
|
||||
) -> typing.Union[typing.List[typing.List[float]], tuple[typing.List[typing.List[float]], dict]]:
|
||||
"""调用 Embedding API
|
||||
|
||||
Args:
|
||||
@@ -138,5 +365,6 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta):
|
||||
|
||||
Returns:
|
||||
typing.List[typing.List[float]]: 返回的 embedding 向量
|
||||
或者 tuple[typing.List[typing.List[float]], dict]: 返回 (embedding 向量, usage_info)
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -56,7 +56,7 @@ class AnthropicMessages(requester.ProviderAPIRequester):
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> provider_message.Message:
|
||||
self.client.api_key = model.token_mgr.get_token()
|
||||
self.client.api_key = model.provider.token_mgr.get_token()
|
||||
|
||||
args = extra_args.copy()
|
||||
args['model'] = model.model_entity.name
|
||||
@@ -190,7 +190,7 @@ class AnthropicMessages(requester.ProviderAPIRequester):
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> provider_message.Message:
|
||||
self.client.api_key = model.token_mgr.get_token()
|
||||
self.client.api_key = model.provider.token_mgr.get_token()
|
||||
|
||||
args = extra_args.copy()
|
||||
args['model'] = model.model_entity.name
|
||||
|
||||
@@ -30,7 +30,7 @@ class BailianChatCompletions(modelscopechatcmpl.ModelScopeChatCompletions):
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> provider_message.Message | typing.AsyncGenerator[provider_message.MessageChunk, None]:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
self.client.api_key = use_model.provider.token_mgr.get_token()
|
||||
|
||||
args = {}
|
||||
args['model'] = use_model.model_entity.name
|
||||
@@ -117,7 +117,7 @@ class BailianChatCompletions(modelscopechatcmpl.ModelScopeChatCompletions):
|
||||
if is_use_dashscope_call:
|
||||
response = dashscope.MultiModalConversation.call(
|
||||
# 若没有配置环境变量,请用百炼API Key将下行替换为:api_key = "sk-xxx"
|
||||
api_key=use_model.token_mgr.get_token(),
|
||||
api_key=use_model.provider.token_mgr.get_token(),
|
||||
model=use_model.model_entity.name,
|
||||
messages=messages,
|
||||
result_format='message',
|
||||
|
||||
@@ -4,7 +4,7 @@ import asyncio
|
||||
import typing
|
||||
|
||||
import openai
|
||||
import openai.types.chat.chat_completion as chat_completion
|
||||
import openai.types.chat.chat_completion as chat_completion_module
|
||||
import httpx
|
||||
|
||||
from .. import errors, requester
|
||||
@@ -35,7 +35,7 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
|
||||
self,
|
||||
args: dict,
|
||||
extra_body: dict = {},
|
||||
) -> chat_completion.ChatCompletion:
|
||||
) -> chat_completion_module.ChatCompletion:
|
||||
return await self.client.chat.completions.create(**args, extra_body=extra_body)
|
||||
|
||||
async def _req_stream(
|
||||
@@ -48,9 +48,12 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
|
||||
|
||||
async def _make_msg(
|
||||
self,
|
||||
chat_completion: chat_completion.ChatCompletion,
|
||||
chat_completion: chat_completion_module.ChatCompletion,
|
||||
remove_think: bool = False,
|
||||
) -> provider_message.Message:
|
||||
if not isinstance(chat_completion, chat_completion_module.ChatCompletion):
|
||||
raise TypeError(f'Expected ChatCompletion, got {type(chat_completion).__name__}: {chat_completion[:16]}')
|
||||
|
||||
chatcmpl_message = chat_completion.choices[0].message.model_dump()
|
||||
|
||||
# 确保 role 字段存在且不为 None
|
||||
@@ -130,7 +133,7 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> provider_message.MessageChunk:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
self.client.api_key = use_model.provider.token_mgr.get_token()
|
||||
|
||||
args = {}
|
||||
args['model'] = use_model.model_entity.name
|
||||
@@ -250,8 +253,8 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
|
||||
use_funcs: list[resource_tool.LLMTool] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> provider_message.Message:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
) -> tuple[provider_message.Message, dict]:
|
||||
self.client.api_key = use_model.provider.token_mgr.get_token()
|
||||
|
||||
args = {}
|
||||
args['model'] = use_model.model_entity.name
|
||||
@@ -282,7 +285,14 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
|
||||
# 处理请求结果
|
||||
message = await self._make_msg(resp, remove_think)
|
||||
|
||||
return message
|
||||
# Extract token usage from response
|
||||
usage_info = {}
|
||||
if hasattr(resp, 'usage') and resp.usage:
|
||||
usage_info['input_tokens'] = resp.usage.prompt_tokens or 0
|
||||
usage_info['output_tokens'] = resp.usage.completion_tokens or 0
|
||||
usage_info['total_tokens'] = resp.usage.total_tokens or 0
|
||||
|
||||
return message, usage_info
|
||||
|
||||
async def invoke_llm(
|
||||
self,
|
||||
@@ -292,7 +302,8 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
|
||||
funcs: typing.List[resource_tool.LLMTool] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> provider_message.Message:
|
||||
) -> tuple[provider_message.Message, dict]:
|
||||
"""Invoke LLM and return message with usage info"""
|
||||
req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行
|
||||
for m in messages:
|
||||
msg_dict = m.dict(exclude_none=True)
|
||||
@@ -305,7 +316,7 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
|
||||
req_messages.append(msg_dict)
|
||||
|
||||
try:
|
||||
msg = await self._closure(
|
||||
msg, usage_info = await self._closure(
|
||||
query=query,
|
||||
req_messages=req_messages,
|
||||
use_model=model,
|
||||
@@ -313,31 +324,39 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
|
||||
extra_args=extra_args,
|
||||
remove_think=remove_think,
|
||||
)
|
||||
return msg
|
||||
return msg, usage_info
|
||||
except asyncio.TimeoutError:
|
||||
raise errors.RequesterError('请求超时')
|
||||
except openai.BadRequestError as e:
|
||||
if 'context_length_exceeded' in e.message:
|
||||
raise errors.RequesterError(f'上文过长,请重置会话: {e.message}')
|
||||
error_message = str(e.message) if hasattr(e, 'message') else str(e)
|
||||
if 'context_length_exceeded' in str(e):
|
||||
raise errors.RequesterError(f'上文过长,请重置会话: {error_message}')
|
||||
else:
|
||||
raise errors.RequesterError(f'请求参数错误: {e.message}')
|
||||
raise errors.RequesterError(f'请求参数错误: {error_message}')
|
||||
except openai.AuthenticationError as e:
|
||||
raise errors.RequesterError(f'无效的 api-key: {e.message}')
|
||||
error_message = str(e.message) if hasattr(e, 'message') else str(e)
|
||||
raise errors.RequesterError(f'无效的 api-key: {error_message}')
|
||||
except openai.NotFoundError as e:
|
||||
raise errors.RequesterError(f'请求路径错误: {e.message}')
|
||||
error_message = str(e.message) if hasattr(e, 'message') else str(e)
|
||||
raise errors.RequesterError(f'请求路径错误: {error_message}')
|
||||
except openai.RateLimitError as e:
|
||||
raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
|
||||
error_message = str(e.message) if hasattr(e, 'message') else str(e)
|
||||
raise errors.RequesterError(f'请求过于频繁或余额不足: {error_message}')
|
||||
except openai.APIConnectionError as e:
|
||||
error_message = f'连接错误: {str(e)}'
|
||||
raise errors.RequesterError(error_message)
|
||||
except openai.APIError as e:
|
||||
raise errors.RequesterError(f'请求错误: {e.message}')
|
||||
error_message = str(e.message) if hasattr(e, 'message') else str(e)
|
||||
raise errors.RequesterError(f'请求错误: {error_message}')
|
||||
|
||||
async def invoke_embedding(
|
||||
self,
|
||||
model: requester.RuntimeEmbeddingModel,
|
||||
input_text: list[str],
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
) -> list[list[float]]:
|
||||
"""调用 Embedding API"""
|
||||
self.client.api_key = model.token_mgr.get_token()
|
||||
) -> tuple[list[list[float]], dict]:
|
||||
"""调用 Embedding API, returns (embeddings, usage_info)"""
|
||||
self.client.api_key = model.provider.token_mgr.get_token()
|
||||
|
||||
args = {
|
||||
'model': model.model_entity.name,
|
||||
@@ -352,7 +371,13 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
|
||||
try:
|
||||
resp = await self.client.embeddings.create(**args)
|
||||
|
||||
return [d.embedding for d in resp.data]
|
||||
# Extract usage info
|
||||
usage_info = {}
|
||||
if hasattr(resp, 'usage') and resp.usage:
|
||||
usage_info['prompt_tokens'] = resp.usage.prompt_tokens or 0
|
||||
usage_info['total_tokens'] = resp.usage.total_tokens or 0
|
||||
|
||||
return [d.embedding for d in resp.data], usage_info
|
||||
except asyncio.TimeoutError:
|
||||
raise errors.RequesterError('请求超时')
|
||||
except openai.BadRequestError as e:
|
||||
|
||||
@@ -25,8 +25,8 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
use_funcs: list[resource_tool.LLMTool] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> provider_message.Message:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
) -> tuple[provider_message.Message, dict]:
|
||||
self.client.api_key = use_model.provider.token_mgr.get_token()
|
||||
|
||||
args = {}
|
||||
args['model'] = use_model.model_entity.name
|
||||
@@ -43,7 +43,7 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
# deepseek 不支持多模态,把content都转换成纯文字
|
||||
for m in messages:
|
||||
if 'content' in m and isinstance(m['content'], list):
|
||||
m['content'] = ' '.join([c['text'] for c in m['content']])
|
||||
m['content'] = ' '.join([c['text'] for c in m['content'] if 'text' in c])
|
||||
|
||||
args['messages'] = messages
|
||||
|
||||
@@ -57,4 +57,11 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
# 处理请求结果
|
||||
message = await self._make_msg(resp, remove_think)
|
||||
|
||||
return message
|
||||
# Extract token usage from response
|
||||
usage_info = {}
|
||||
if hasattr(resp, 'usage') and resp.usage:
|
||||
usage_info['input_tokens'] = resp.usage.prompt_tokens or 0
|
||||
usage_info['output_tokens'] = resp.usage.completion_tokens or 0
|
||||
usage_info['total_tokens'] = resp.usage.total_tokens or 0
|
||||
|
||||
return message, usage_info
|
||||
|
||||
@@ -29,7 +29,7 @@ class GeminiChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> provider_message.MessageChunk:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
self.client.api_key = use_model.provider.token_mgr.get_token()
|
||||
|
||||
args = {}
|
||||
args['model'] = use_model.model_entity.name
|
||||
|
||||
@@ -109,7 +109,7 @@ class JieKouAIChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> provider_message.Message | typing.AsyncGenerator[provider_message.MessageChunk, None]:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
self.client.api_key = use_model.provider.token_mgr.get_token()
|
||||
|
||||
args = {}
|
||||
args['model'] = use_model.model_entity.name
|
||||
|
||||
@@ -130,8 +130,8 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester):
|
||||
use_funcs: list[resource_tool.LLMTool] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> provider_message.Message:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
) -> tuple[provider_message.Message, dict]:
|
||||
self.client.api_key = use_model.provider.token_mgr.get_token()
|
||||
|
||||
args = {}
|
||||
args['model'] = use_model.model_entity.name
|
||||
@@ -162,7 +162,10 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester):
|
||||
# 处理请求结果
|
||||
message = await self._make_msg(resp)
|
||||
|
||||
return message
|
||||
# ModelScope uses streaming, usage info not available
|
||||
usage_info = {}
|
||||
|
||||
return message, usage_info
|
||||
|
||||
async def _req_stream(
|
||||
self,
|
||||
@@ -181,7 +184,7 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester):
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> provider_message.Message | typing.AsyncGenerator[provider_message.MessageChunk, None]:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
self.client.api_key = use_model.provider.token_mgr.get_token()
|
||||
|
||||
args = {}
|
||||
args['model'] = use_model.model_entity.name
|
||||
|
||||
@@ -26,8 +26,8 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
use_funcs: list[resource_tool.LLMTool] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> provider_message.Message:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
) -> tuple[provider_message.Message, dict]:
|
||||
self.client.api_key = use_model.provider.token_mgr.get_token()
|
||||
|
||||
args = {}
|
||||
args['model'] = use_model.model_entity.name
|
||||
@@ -57,4 +57,11 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
# 处理请求结果
|
||||
message = await self._make_msg(resp, remove_think)
|
||||
|
||||
return message
|
||||
# Extract token usage from response
|
||||
usage_info = {}
|
||||
if hasattr(resp, 'usage') and resp.usage:
|
||||
usage_info['input_tokens'] = resp.usage.prompt_tokens or 0
|
||||
usage_info['output_tokens'] = resp.usage.completion_tokens or 0
|
||||
usage_info['total_tokens'] = resp.usage.total_tokens or 0
|
||||
|
||||
return message, usage_info
|
||||
|
||||
@@ -109,7 +109,7 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> provider_message.Message | typing.AsyncGenerator[provider_message.MessageChunk, None]:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
self.client.api_key = use_model.provider.token_mgr.get_token()
|
||||
|
||||
args = {}
|
||||
args['model'] = use_model.model_entity.name
|
||||
|
||||
@@ -51,9 +51,10 @@ class SeekDBEmbedding(requester.ProviderAPIRequester):
|
||||
await self.initialize()
|
||||
|
||||
if self._embedding_function is None:
|
||||
raise RuntimeError("SeekDB embedding function initialization failed")
|
||||
raise RuntimeError('SeekDB embedding function initialization failed')
|
||||
|
||||
return self._embedding_function(input_text)
|
||||
except Exception as e:
|
||||
from .. import errors
|
||||
|
||||
raise errors.RequesterError(f'SeekDB embedding failed: {str(e)}')
|
||||
|
||||
BIN
src/langbot/pkg/provider/modelmgr/requesters/space.webp
Normal file
BIN
src/langbot/pkg/provider/modelmgr/requesters/space.webp
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 14 KiB |
@@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import openai
|
||||
|
||||
from . import chatcmpl
|
||||
|
||||
|
||||
class LangBotSpaceChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
"""LangBot Space ChatCompletion API 请求器"""
|
||||
|
||||
client: openai.AsyncClient
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base_url': 'https://api.langbot.cloud/v1',
|
||||
'timeout': 120,
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
apiVersion: v1
|
||||
kind: LLMAPIRequester
|
||||
metadata:
|
||||
name: space-chat-completions
|
||||
label:
|
||||
en_US: Space
|
||||
zh_Hans: Space
|
||||
icon: space.webp
|
||||
spec:
|
||||
config:
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_Hans: 基础 URL
|
||||
type: string
|
||||
required: true
|
||||
default: https://api.langbot.cloud/v1
|
||||
- name: timeout
|
||||
label:
|
||||
en_US: Timeout
|
||||
zh_Hans: 超时时间
|
||||
type: integer
|
||||
required: true
|
||||
default: 120
|
||||
support_type:
|
||||
- llm
|
||||
- text-embedding
|
||||
provider_category: maas
|
||||
execution:
|
||||
python:
|
||||
path: ./spacechatcmpl.py
|
||||
attr: LangBotSpaceChatCompletions
|
||||
@@ -18,6 +18,8 @@ class TokenManager:
|
||||
self.using_token_index = 0
|
||||
|
||||
def get_token(self) -> str:
|
||||
if len(self.tokens) == 0:
|
||||
return ''
|
||||
return self.tokens[self.using_token_index]
|
||||
|
||||
def next_token(self):
|
||||
|
||||
@@ -118,6 +118,7 @@ class DashScopeAPIRunner(runner.RequestRunner):
|
||||
stream=True, # 流式输出
|
||||
incremental_output=True, # 增量输出,使用流式输出需要开启增量输出
|
||||
session_id=query.session.using_conversation.uuid, # 会话ID用于,多轮对话
|
||||
enable_thinking=has_thoughts,
|
||||
has_thoughts=has_thoughts,
|
||||
# rag_options={ # 主要用于文件交互,暂不支持
|
||||
# "session_file_ids": ["FILE_ID1"], # FILE_ID1 替换为实际的临时文件ID,逗号隔开多个
|
||||
@@ -141,14 +142,14 @@ class DashScopeAPIRunner(runner.RequestRunner):
|
||||
# 获取流式传输的output
|
||||
stream_output = chunk.get('output', {})
|
||||
stream_think = stream_output.get('thoughts', [])
|
||||
if stream_think[0].get('thought'):
|
||||
if stream_think and stream_think[0].get('thought'):
|
||||
if not think_start:
|
||||
think_start = True
|
||||
pending_content += f'<think>\n{stream_think[0].get("thought")}'
|
||||
else:
|
||||
# 继续输出 reasoning_content
|
||||
pending_content += stream_think[0].get('thought')
|
||||
elif stream_think[0].get('thought') == '' and not think_end:
|
||||
elif (not stream_think or stream_think[0].get('thought') == '') and not think_end:
|
||||
think_end = True
|
||||
pending_content += '\n</think>\n'
|
||||
if stream_output.get('text') is not None:
|
||||
|
||||
@@ -289,12 +289,16 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
yield msg
|
||||
if chunk['event'] == 'message_file':
|
||||
if chunk['type'] == 'image' and chunk['belongs_to'] == 'assistant':
|
||||
base_url = self.dify_client.base_url
|
||||
# 检查URL是否已经是完整的连接
|
||||
if chunk['url'].startswith('http://') or chunk['url'].startswith('https://'):
|
||||
image_url = chunk['url']
|
||||
else:
|
||||
base_url = self.dify_client.base_url
|
||||
|
||||
if base_url.endswith('/v1'):
|
||||
base_url = base_url[:-3]
|
||||
if base_url.endswith('/v1'):
|
||||
base_url = base_url[:-3]
|
||||
|
||||
image_url = base_url + chunk['url']
|
||||
image_url = base_url + chunk['url']
|
||||
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
@@ -529,7 +533,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
think_end = True
|
||||
elif think_end or not think_start:
|
||||
pending_agent_message += chunk['answer']
|
||||
if think_start:
|
||||
if think_start and not think_end:
|
||||
continue
|
||||
|
||||
else:
|
||||
@@ -559,12 +563,16 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
if chunk['event'] == 'message_file':
|
||||
message_idx += 1
|
||||
if chunk['type'] == 'image' and chunk['belongs_to'] == 'assistant':
|
||||
base_url = self.dify_client.base_url
|
||||
# 检查URL是否已经是完整的连接
|
||||
if chunk['url'].startswith('http://') or chunk['url'].startswith('https://'):
|
||||
image_url = chunk['url']
|
||||
else:
|
||||
base_url = self.dify_client.base_url
|
||||
|
||||
if base_url.endswith('/v1'):
|
||||
base_url = base_url[:-3]
|
||||
if base_url.endswith('/v1'):
|
||||
base_url = base_url[:-3]
|
||||
|
||||
image_url = base_url + chunk['url']
|
||||
image_url = base_url + chunk['url']
|
||||
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
|
||||
@@ -130,7 +130,7 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
if not is_stream:
|
||||
# 非流式输出,直接请求
|
||||
|
||||
msg = await use_llm_model.requester.invoke_llm(
|
||||
msg = await use_llm_model.provider.invoke_llm(
|
||||
query,
|
||||
use_llm_model,
|
||||
req_messages,
|
||||
@@ -147,7 +147,7 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
accumulated_content = '' # 从开始累积的所有内容
|
||||
last_role = 'assistant'
|
||||
msg_sequence = 1
|
||||
async for msg in use_llm_model.requester.invoke_llm_stream(
|
||||
async for msg in use_llm_model.provider.invoke_llm_stream(
|
||||
query,
|
||||
use_llm_model,
|
||||
req_messages,
|
||||
@@ -212,19 +212,34 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
try:
|
||||
func = tool_call.function
|
||||
|
||||
parameters = json.loads(func.arguments)
|
||||
if func.arguments:
|
||||
parameters = json.loads(func.arguments)
|
||||
else:
|
||||
parameters = {}
|
||||
|
||||
func_ret = await self.ap.tool_mgr.execute_func_call(func.name, parameters, query=query)
|
||||
|
||||
# Handle return value content
|
||||
tool_content = None
|
||||
if (
|
||||
isinstance(func_ret, list)
|
||||
and len(func_ret) > 0
|
||||
and isinstance(func_ret[0], provider_message.ContentElement)
|
||||
):
|
||||
tool_content = func_ret
|
||||
else:
|
||||
tool_content = json.dumps(func_ret, ensure_ascii=False)
|
||||
|
||||
if is_stream:
|
||||
msg = provider_message.MessageChunk(
|
||||
role='tool',
|
||||
content=json.dumps(func_ret, ensure_ascii=False),
|
||||
content=tool_content,
|
||||
tool_call_id=tool_call.id,
|
||||
)
|
||||
else:
|
||||
msg = provider_message.Message(
|
||||
role='tool',
|
||||
content=json.dumps(func_ret, ensure_ascii=False),
|
||||
content=tool_content,
|
||||
tool_call_id=tool_call.id,
|
||||
)
|
||||
|
||||
@@ -250,7 +265,7 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
last_role = 'assistant'
|
||||
msg_sequence = first_end_sequence
|
||||
|
||||
async for msg in use_llm_model.requester.invoke_llm_stream(
|
||||
async for msg in use_llm_model.provider.invoke_llm_stream(
|
||||
query,
|
||||
use_llm_model,
|
||||
req_messages,
|
||||
@@ -306,7 +321,7 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
)
|
||||
else:
|
||||
# 处理完所有调用,再次请求
|
||||
msg = await use_llm_model.requester.invoke_llm(
|
||||
msg = await use_llm_model.provider.invoke_llm(
|
||||
query,
|
||||
use_llm_model,
|
||||
req_messages,
|
||||
|
||||
@@ -68,15 +68,16 @@ class N8nServiceAPIRunner(runner.RequestRunner):
|
||||
|
||||
return plain_text
|
||||
|
||||
async def _process_stream_response(self, response: aiohttp.ClientResponse) -> typing.AsyncGenerator[
|
||||
provider_message.Message, None]:
|
||||
async def _process_stream_response(
|
||||
self, response: aiohttp.ClientResponse
|
||||
) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""处理流式响应——支持部分 JSON 和多个 JSON 对象在同一 chunk 的情况"""
|
||||
full_content = ""
|
||||
full_content = ''
|
||||
chunk_idx = 0
|
||||
is_final = False
|
||||
message_idx = 0
|
||||
|
||||
buffer = ""
|
||||
buffer = ''
|
||||
decoder = json.JSONDecoder()
|
||||
|
||||
async for raw_chunk in response.content.iter_chunked(1024):
|
||||
@@ -129,7 +130,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
|
||||
preview = chunk_str[:200]
|
||||
except Exception:
|
||||
preview = '<unavailable>'
|
||||
self.ap.logger.warning(f"Failed to process chunk: {e}; chunk preview: {preview}")
|
||||
self.ap.logger.warning(f'Failed to process chunk: {e}; chunk preview: {preview}')
|
||||
|
||||
# 流结束后,尝试解析残余 buffer
|
||||
if buffer:
|
||||
@@ -151,7 +152,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
|
||||
)
|
||||
except Exception as e:
|
||||
preview = buffer[:200]
|
||||
self.ap.logger.warning(f"Failed to parse remaining buffer: {e}; buffer preview: {preview}")
|
||||
self.ap.logger.warning(f'Failed to parse remaining buffer: {e}; buffer preview: {preview}')
|
||||
|
||||
async def _call_webhook(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""调用n8n webhook"""
|
||||
@@ -165,7 +166,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
|
||||
# 准备请求数据
|
||||
payload = {
|
||||
# 基本消息内容
|
||||
'chatInput' :plain_text, # 考虑到之前用户直接用的message model这里添加新键
|
||||
'chatInput': plain_text, # 考虑到之前用户直接用的message model这里添加新键
|
||||
'message': plain_text,
|
||||
'user_message_text': plain_text,
|
||||
'conversation_id': query.session.using_conversation.uuid,
|
||||
@@ -217,57 +218,49 @@ class N8nServiceAPIRunner(runner.RequestRunner):
|
||||
|
||||
# 调用webhook
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if is_stream:
|
||||
# 流式请求
|
||||
async with session.post(
|
||||
self.webhook_url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
auth=auth,
|
||||
timeout=self.timeout
|
||||
) as response:
|
||||
if is_stream:
|
||||
# 流式请求
|
||||
async with session.post(
|
||||
self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
self.ap.logger.error(f'n8n webhook call failed: {response.status}, {error_text}')
|
||||
raise Exception(f'n8n webhook call failed: {response.status}, {error_text}')
|
||||
|
||||
# 处理流式响应
|
||||
async for chunk in self._process_stream_response(response):
|
||||
yield chunk
|
||||
else:
|
||||
async with session.post(
|
||||
self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout
|
||||
) as response:
|
||||
try:
|
||||
async for chunk in self._process_stream_response(response):
|
||||
output_content = chunk.content if chunk.is_final else ''
|
||||
except:
|
||||
# 非流式请求(保持原有逻辑)
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
self.ap.logger.error(f'n8n webhook call failed: {response.status}, {error_text}')
|
||||
raise Exception(f'n8n webhook call failed: {response.status}, {error_text}')
|
||||
|
||||
# 处理流式响应
|
||||
async for chunk in self._process_stream_response(response):
|
||||
yield chunk
|
||||
else:
|
||||
async with session.post(
|
||||
self.webhook_url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
auth=auth,
|
||||
timeout=self.timeout
|
||||
) as response:
|
||||
try:
|
||||
async for chunk in self._process_stream_response(response):
|
||||
output_content = chunk.content if chunk.is_final else ''
|
||||
except:
|
||||
# 非流式请求(保持原有逻辑)
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
self.ap.logger.error(f'n8n webhook call failed: {response.status}, {error_text}')
|
||||
raise Exception(f'n8n webhook call failed: {response.status}, {error_text}')
|
||||
# 解析响应
|
||||
response_data = await response.json()
|
||||
self.ap.logger.debug(f'n8n webhook response: {response_data}')
|
||||
|
||||
# 解析响应
|
||||
response_data = await response.json()
|
||||
self.ap.logger.debug(f'n8n webhook response: {response_data}')
|
||||
# 从响应中提取输出
|
||||
if self.output_key in response_data:
|
||||
output_content = response_data[self.output_key]
|
||||
else:
|
||||
# 如果没有指定的输出键,则使用整个响应
|
||||
output_content = json.dumps(response_data, ensure_ascii=False)
|
||||
|
||||
# 从响应中提取输出
|
||||
if self.output_key in response_data:
|
||||
output_content = response_data[self.output_key]
|
||||
else:
|
||||
# 如果没有指定的输出键,则使用整个响应
|
||||
output_content = json.dumps(response_data, ensure_ascii=False)
|
||||
|
||||
# 返回消息
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=output_content,
|
||||
)
|
||||
# 返回消息
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=output_content,
|
||||
)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'n8n webhook call exception: {str(e)}')
|
||||
raise N8nAPIError(f'n8n webhook call exception: {str(e)}')
|
||||
@@ -275,4 +268,4 @@ class N8nServiceAPIRunner(runner.RequestRunner):
|
||||
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""运行请求"""
|
||||
async for msg in self._call_webhook(query):
|
||||
yield msg
|
||||
yield msg
|
||||
|
||||
@@ -7,14 +7,18 @@ import traceback
|
||||
from langbot_plugin.api.entities.events import pipeline_query
|
||||
import sqlalchemy
|
||||
import asyncio
|
||||
import httpx
|
||||
|
||||
import uuid as uuid_module
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.streamable_http import streamable_http_client
|
||||
|
||||
from .. import loader
|
||||
from ....core import app
|
||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
from ....entity.persistence import mcp as persistence_mcp
|
||||
|
||||
|
||||
@@ -35,7 +39,7 @@ class RuntimeMCPSession:
|
||||
|
||||
server_config: dict
|
||||
|
||||
session: ClientSession
|
||||
session: ClientSession | None
|
||||
|
||||
exit_stack: AsyncExitStack
|
||||
|
||||
@@ -52,6 +56,8 @@ class RuntimeMCPSession:
|
||||
|
||||
_ready_event: asyncio.Event
|
||||
|
||||
error_message: str | None = None
|
||||
|
||||
def __init__(self, server_name: str, server_config: dict, enable: bool, ap: app.Application):
|
||||
self.server_name = server_name
|
||||
self.server_uuid = server_config.get('uuid', '')
|
||||
@@ -100,6 +106,24 @@ class RuntimeMCPSession:
|
||||
|
||||
await self.session.initialize()
|
||||
|
||||
async def _init_streamable_http_server(self):
|
||||
transport = await self.exit_stack.enter_async_context(
|
||||
streamable_http_client(
|
||||
self.server_config['url'],
|
||||
http_client=httpx.AsyncClient(
|
||||
headers=self.server_config.get('headers', {}),
|
||||
timeout=self.server_config.get('timeout', 10),
|
||||
follow_redirects=True,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
read, write, _ = transport
|
||||
|
||||
self.session = await self.exit_stack.enter_async_context(ClientSession(read, write))
|
||||
|
||||
await self.session.initialize()
|
||||
|
||||
async def _lifecycle_loop(self):
|
||||
"""在后台任务中管理整个MCP会话的生命周期"""
|
||||
try:
|
||||
@@ -107,6 +131,8 @@ class RuntimeMCPSession:
|
||||
await self._init_stdio_python_server()
|
||||
elif self.server_config['mode'] == 'sse':
|
||||
await self._init_sse_server()
|
||||
elif self.server_config['mode'] == 'http':
|
||||
await self._init_streamable_http_server()
|
||||
else:
|
||||
raise ValueError(f'无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}')
|
||||
|
||||
@@ -122,6 +148,7 @@ class RuntimeMCPSession:
|
||||
|
||||
except Exception as e:
|
||||
self.status = MCPSessionStatus.ERROR
|
||||
self.error_message = str(e)
|
||||
self.ap.logger.error(f'Error in MCP session lifecycle {self.server_name}: {e}\n{traceback.format_exc()}')
|
||||
# 即使出错也要设置ready事件,让start()方法知道初始化已完成
|
||||
self._ready_event.set()
|
||||
@@ -154,6 +181,9 @@ class RuntimeMCPSession:
|
||||
raise Exception('Connection failed, please check URL')
|
||||
|
||||
async def refresh(self):
|
||||
if not self.session:
|
||||
return
|
||||
|
||||
self.functions.clear()
|
||||
|
||||
tools = await self.session.list_tools()
|
||||
@@ -163,18 +193,36 @@ class RuntimeMCPSession:
|
||||
for tool in tools.tools:
|
||||
|
||||
async def func(*, _tool=tool, **kwargs):
|
||||
if not self.session:
|
||||
raise Exception('MCP session is not connected')
|
||||
|
||||
result = await self.session.call_tool(_tool.name, kwargs)
|
||||
if result.isError:
|
||||
raise Exception(result.content[0].text)
|
||||
return result.content[0].text
|
||||
error_texts = []
|
||||
for content in result.content:
|
||||
if content.type == 'text':
|
||||
error_texts.append(content.text)
|
||||
raise Exception('\n'.join(error_texts) if error_texts else 'Unknown error from MCP tool')
|
||||
|
||||
result_contents: list[provider_message.ContentElement] = []
|
||||
for content in result.content:
|
||||
if content.type == 'text':
|
||||
result_contents.append(provider_message.ContentElement.from_text(content.text))
|
||||
elif content.type == 'image':
|
||||
result_contents.append(provider_message.ContentElement.from_image_base64(content.image_base64))
|
||||
elif content.type == 'resource':
|
||||
# TODO: Handle resource content
|
||||
pass
|
||||
|
||||
return result_contents
|
||||
|
||||
func.__name__ = tool.name
|
||||
|
||||
self.functions.append(
|
||||
resource_tool.LLMTool(
|
||||
name=tool.name,
|
||||
human_desc=tool.description,
|
||||
description=tool.description,
|
||||
human_desc=tool.description or '',
|
||||
description=tool.description or '',
|
||||
parameters=tool.inputSchema,
|
||||
func=func,
|
||||
)
|
||||
@@ -186,6 +234,7 @@ class RuntimeMCPSession:
|
||||
def get_runtime_info_dict(self) -> dict:
|
||||
return {
|
||||
'status': self.status.value,
|
||||
'error_message': self.error_message,
|
||||
'tool_count': len(self.get_tools()),
|
||||
'tools': [
|
||||
{
|
||||
@@ -287,6 +336,11 @@ class MCPLoader(loader.ToolLoader):
|
||||
- enable: 是否启用
|
||||
- extra_args: 额外的配置参数 (可选)
|
||||
"""
|
||||
uuid_ = server_config.get('uuid')
|
||||
if not uuid_:
|
||||
self.ap.logger.warning('Server UUID is None for MCP server, maybe testing in the config page.')
|
||||
uuid_ = str(uuid_module.uuid4())
|
||||
server_config['uuid'] = uuid_
|
||||
|
||||
name = server_config['name']
|
||||
uuid = server_config['uuid']
|
||||
|
||||
@@ -32,12 +32,20 @@ class Embedder(BaseService):
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.Chunk).values(chunk_dicts))
|
||||
|
||||
# get embeddings
|
||||
embeddings_list: list[list[float]] = await embedding_model.requester.invoke_embedding(
|
||||
model=embedding_model,
|
||||
input_text=chunks,
|
||||
extra_args={}, # TODO: add extra args
|
||||
)
|
||||
# get embeddings (batch size limit: 64 for OpenAI)
|
||||
MAX_BATCH_SIZE = 64
|
||||
embeddings_list: list[list[float]] = []
|
||||
|
||||
for i in range(0, len(chunks), MAX_BATCH_SIZE):
|
||||
batch = chunks[i : i + MAX_BATCH_SIZE]
|
||||
batch_embeddings = await embedding_model.provider.invoke_embedding(
|
||||
model=embedding_model,
|
||||
input_text=batch,
|
||||
extra_args={}, # TODO: add extra args
|
||||
knowledge_base_id=kb_id,
|
||||
call_type='embedding',
|
||||
)
|
||||
embeddings_list.extend(batch_embeddings)
|
||||
|
||||
# save embeddings to vdb
|
||||
await self.ap.vector_db_mgr.vector_db.add_embeddings(kb_id, chunk_ids, embeddings_list, chunk_dicts)
|
||||
|
||||
@@ -19,10 +19,13 @@ class Retriever(base_service.BaseService):
|
||||
f"Retrieving for query: '{query[:10]}' with k={k} using {embedding_model.model_entity.uuid}"
|
||||
)
|
||||
|
||||
query_embedding: list[float] = await embedding_model.requester.invoke_embedding(
|
||||
query_embedding: list[float] = await embedding_model.provider.invoke_embedding(
|
||||
model=embedding_model,
|
||||
input_text=[query],
|
||||
extra_args={}, # TODO: add extra args
|
||||
knowledge_base_id=kb_id,
|
||||
query_text=query,
|
||||
call_type='retrieve',
|
||||
)
|
||||
|
||||
vector_results = await self.ap.vector_db_mgr.vector_db.search(kb_id, query_embedding[0], k)
|
||||
|
||||
0
src/langbot/pkg/telemetry/__init__.py
Normal file
0
src/langbot/pkg/telemetry/__init__.py
Normal file
121
src/langbot/pkg/telemetry/telemetry.py
Normal file
121
src/langbot/pkg/telemetry/telemetry.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import httpx
|
||||
from ..core import app as core_app
|
||||
|
||||
|
||||
class TelemetryManager:
|
||||
"""TelemetryManager handles sending telemetry for a given application instance.
|
||||
|
||||
Usage:
|
||||
telemetry = TelemetryManager(ap)
|
||||
await telemetry.send({ ... })
|
||||
"""
|
||||
|
||||
send_tasks: list[asyncio.Task] = []
|
||||
|
||||
def __init__(self, ap: core_app.Application):
|
||||
self.ap = ap
|
||||
|
||||
self.telemetry_config = {}
|
||||
|
||||
async def initialize(self):
|
||||
self.telemetry_config = self.ap.instance_config.data.get('space', {})
|
||||
|
||||
async def start_send_task(self, payload: dict):
|
||||
task = asyncio.create_task(self.send(payload))
|
||||
self.send_tasks.append(task)
|
||||
|
||||
async def send(self, payload: dict):
|
||||
"""Send telemetry payload to configured telemetry server (non-blocking).
|
||||
|
||||
Expects ap.instance_config.data.telemetry to have:
|
||||
- enabled: bool
|
||||
- server: str (base URL, e.g. https://space.example.com)
|
||||
- timeout_seconds: optional int, overall request timeout (default 10)
|
||||
|
||||
Posts to {server.rstrip('/')}/api/v1/telemetry as JSON. Failures are logged but do not raise.
|
||||
"""
|
||||
|
||||
try:
|
||||
cfg = self.telemetry_config
|
||||
if not cfg:
|
||||
return
|
||||
if cfg.get('disable_telemetry', False):
|
||||
return
|
||||
server = cfg.get('url', '')
|
||||
if not server:
|
||||
return
|
||||
|
||||
# Normalize URL
|
||||
url = server.rstrip('/') + '/api/v1/telemetry'
|
||||
|
||||
try:
|
||||
# Sanitize payload so string fields are strings and not nulls
|
||||
sanitized = dict(payload)
|
||||
if 'query_id' in sanitized:
|
||||
try:
|
||||
sanitized['query_id'] = '' if sanitized['query_id'] is None else str(sanitized['query_id'])
|
||||
except Exception:
|
||||
sanitized['query_id'] = str(sanitized.get('query_id', ''))
|
||||
|
||||
for sfield in ('adapter', 'runner', 'model_name', 'version', 'error', 'timestamp'):
|
||||
v = sanitized.get(sfield)
|
||||
sanitized[sfield] = '' if v is None else str(v)
|
||||
|
||||
if 'duration_ms' in sanitized:
|
||||
try:
|
||||
sanitized['duration_ms'] = (
|
||||
int(sanitized['duration_ms']) if sanitized['duration_ms'] is not None else 0
|
||||
)
|
||||
except Exception:
|
||||
sanitized['duration_ms'] = 0
|
||||
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(10)) as client:
|
||||
try:
|
||||
# Use asyncio.wait_for to ensure we always bound the total time
|
||||
resp = await asyncio.wait_for(client.post(url, json=sanitized), timeout=10 + 1)
|
||||
|
||||
if resp.status_code >= 400:
|
||||
self.ap.logger.warning(
|
||||
f'Telemetry post to {url} returned status {resp.status_code} - {resp.text}'
|
||||
)
|
||||
else:
|
||||
# Detect application-level errors inside HTTP 200 responses
|
||||
app_err = False
|
||||
try:
|
||||
j = resp.json()
|
||||
if isinstance(j, dict) and j.get('code') is not None and int(j.get('code')) >= 400:
|
||||
app_err = True
|
||||
self.ap.logger.warning(
|
||||
f'Telemetry post to {url} returned application error code {j.get("code")} - {j.get("msg")}'
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if app_err:
|
||||
self.ap.logger.warning(
|
||||
f'Telemetry post to {url} returned app-level error - response: {resp.text[:200]}'
|
||||
)
|
||||
else:
|
||||
self.ap.logger.debug(
|
||||
f'Telemetry posted to {url}, status {resp.status_code} - response: {resp.text[:200]}'
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
self.ap.logger.warning(f'Telemetry post to {url} timed out')
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(f'Failed to post telemetry to {url}: {e}', exc_info=True)
|
||||
except Exception as e:
|
||||
try:
|
||||
self.ap.logger.warning(
|
||||
f'Failed to create HTTP client for telemetry or sanitize payload: {e}', exc_info=True
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
# Never raise from telemetry; surface as warning for visibility
|
||||
try:
|
||||
self.ap.logger.warning(f'Unexpected telemetry error: {e}', exc_info=True)
|
||||
except Exception:
|
||||
pass
|
||||
@@ -2,9 +2,11 @@ import langbot
|
||||
|
||||
semantic_version = f'v{langbot.__version__}'
|
||||
|
||||
required_database_version = 13
|
||||
required_database_version = 18
|
||||
"""Tag the version of the database schema, used to check if the database needs to be migrated"""
|
||||
|
||||
debug_mode = False
|
||||
|
||||
edition = 'community'
|
||||
|
||||
instance_id = ''
|
||||
|
||||
@@ -37,7 +37,8 @@ class VectorDBManager:
|
||||
milvus_config = kb_config.get('milvus', {})
|
||||
uri = milvus_config.get('uri', './data/milvus.db')
|
||||
token = milvus_config.get('token')
|
||||
self.vector_db = MilvusVectorDatabase(self.ap, uri=uri, token=token)
|
||||
db_name = milvus_config.get('db_name', 'default')
|
||||
self.vector_db = MilvusVectorDatabase(self.ap, uri=uri, token=token, db_name=db_name)
|
||||
self.ap.logger.info('Initialized Milvus vector database backend.')
|
||||
|
||||
elif vdb_type == 'pgvector':
|
||||
@@ -54,12 +55,7 @@ class VectorDBManager:
|
||||
user = pgvector_config.get('user', 'postgres')
|
||||
password = pgvector_config.get('password', 'postgres')
|
||||
self.vector_db = PgVectorDatabase(
|
||||
self.ap,
|
||||
host=host,
|
||||
port=port,
|
||||
database=database,
|
||||
user=user,
|
||||
password=password
|
||||
self.ap, host=host, port=port, database=database, user=user, password=password
|
||||
)
|
||||
self.ap.logger.info('Initialized pgvector database backend.')
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
from typing import Any, Dict
|
||||
from pymilvus import MilvusClient, DataType
|
||||
from pymilvus import MilvusClient, DataType, CollectionSchema, FieldSchema
|
||||
from pymilvus.milvus_client.index import IndexParams
|
||||
from langbot.pkg.vector.vdb import VectorDatabase
|
||||
from langbot.pkg.core import app
|
||||
|
||||
@@ -9,7 +10,7 @@ from langbot.pkg.core import app
|
||||
class MilvusVectorDatabase(VectorDatabase):
|
||||
"""Milvus vector database implementation"""
|
||||
|
||||
def __init__(self, ap: app.Application, uri: str = "milvus.db", token: str = None):
|
||||
def __init__(self, ap: app.Application, uri: str = 'milvus.db', token: str = None, db_name: str = None):
|
||||
"""Initialize Milvus vector database
|
||||
|
||||
Args:
|
||||
@@ -21,77 +22,133 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
self.ap = ap
|
||||
self.uri = uri
|
||||
self.token = token
|
||||
self.db_name = db_name
|
||||
self.client = None
|
||||
self._collections = {}
|
||||
self._collections: set[str] = set()
|
||||
self._initialize_client()
|
||||
|
||||
def _initialize_client(self):
|
||||
"""Initialize Milvus client connection"""
|
||||
try:
|
||||
if self.token:
|
||||
self.client = MilvusClient(uri=self.uri, token=self.token)
|
||||
self.client = MilvusClient(uri=self.uri, token=self.token, db_name=self.db_name)
|
||||
else:
|
||||
self.client = MilvusClient(uri=self.uri)
|
||||
self.ap.logger.info(f"Connected to Milvus at {self.uri}")
|
||||
self.client = MilvusClient(uri=self.uri, db_name=self.db_name)
|
||||
self.ap.logger.info(f'Connected to Milvus at {self.uri}')
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f"Failed to connect to Milvus: {e}")
|
||||
self.ap.logger.error(f'Failed to connect to Milvus: {e}')
|
||||
raise
|
||||
|
||||
async def get_or_create_collection(self, collection: str):
|
||||
"""Get or create a Milvus collection
|
||||
@staticmethod
|
||||
def _normalize_collection_name(collection: str) -> str:
|
||||
"""Normalize collection name to comply with Milvus naming requirements.
|
||||
|
||||
Milvus requirements:
|
||||
- First character must be an underscore or letter
|
||||
- Can only contain numbers, letters and underscores
|
||||
|
||||
Args:
|
||||
collection: Original collection name (e.g., UUID with hyphens)
|
||||
|
||||
Returns:
|
||||
Normalized collection name that complies with Milvus requirements
|
||||
"""
|
||||
# Replace hyphens with underscores
|
||||
normalized = collection.replace('-', '_')
|
||||
|
||||
# If first character is not a letter or underscore, prepend 'kb_'
|
||||
if normalized and not (normalized[0].isalpha() or normalized[0] == '_'):
|
||||
normalized = 'kb_' + normalized
|
||||
|
||||
return normalized
|
||||
|
||||
async def _ensure_vector_index(self, collection: str) -> None:
|
||||
"""Ensure the vector field has an index.
|
||||
|
||||
Args:
|
||||
collection: Normalized collection name
|
||||
"""
|
||||
index_params = IndexParams()
|
||||
index_params.add_index(
|
||||
field_name='vector',
|
||||
index_type='AUTOINDEX',
|
||||
metric_type='COSINE',
|
||||
)
|
||||
await asyncio.to_thread(self.client.create_index, collection_name=collection, index_params=index_params)
|
||||
|
||||
async def _get_or_create_collection_internal(self, collection: str, vector_size: int = None):
|
||||
"""Internal method to get or create a Milvus collection with proper configuration.
|
||||
|
||||
Args:
|
||||
collection: Collection name (corresponds to knowledge base UUID)
|
||||
vector_size: Dimension of the vectors (if None, defaults to 1536)
|
||||
"""
|
||||
# Normalize collection name for Milvus compatibility
|
||||
collection = self._normalize_collection_name(collection)
|
||||
|
||||
if collection in self._collections:
|
||||
return self._collections[collection]
|
||||
return collection
|
||||
|
||||
# Check if collection exists
|
||||
has_collection = await asyncio.to_thread(
|
||||
self.client.has_collection, collection_name=collection
|
||||
)
|
||||
has_collection = await asyncio.to_thread(self.client.has_collection, collection_name=collection)
|
||||
|
||||
if not has_collection:
|
||||
# Create collection with custom schema to support string IDs
|
||||
from pymilvus import CollectionSchema, FieldSchema, DataType
|
||||
# Default dimension if not specified (for backward compatibility)
|
||||
if vector_size is None:
|
||||
vector_size = 1536
|
||||
|
||||
fields = [
|
||||
FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=255),
|
||||
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=1536),
|
||||
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
|
||||
FieldSchema(name="file_id", dtype=DataType.VARCHAR, max_length=255),
|
||||
FieldSchema(name="chunk_uuid", dtype=DataType.VARCHAR, max_length=255),
|
||||
FieldSchema(name='id', dtype=DataType.VARCHAR, is_primary=True, max_length=255),
|
||||
FieldSchema(name='vector', dtype=DataType.FLOAT_VECTOR, dim=vector_size),
|
||||
FieldSchema(name='text', dtype=DataType.VARCHAR, max_length=65535),
|
||||
FieldSchema(name='file_id', dtype=DataType.VARCHAR, max_length=255),
|
||||
FieldSchema(name='chunk_uuid', dtype=DataType.VARCHAR, max_length=255),
|
||||
]
|
||||
|
||||
schema = CollectionSchema(fields=fields, description="LangBot knowledge base vectors")
|
||||
schema = CollectionSchema(fields=fields, description='LangBot knowledge base vectors')
|
||||
|
||||
await asyncio.to_thread(
|
||||
self.client.create_collection,
|
||||
collection_name=collection,
|
||||
schema=schema,
|
||||
metric_type="COSINE",
|
||||
metric_type='COSINE',
|
||||
)
|
||||
|
||||
# Create index for vector field (required for loading/searching)
|
||||
index_params = {
|
||||
"metric_type": "COSINE",
|
||||
"index_type": "AUTOINDEX",
|
||||
"params": {}
|
||||
}
|
||||
await asyncio.to_thread(
|
||||
self.client.create_index,
|
||||
collection_name=collection,
|
||||
field_name="vector",
|
||||
index_params=index_params
|
||||
await self._ensure_vector_index(collection)
|
||||
self.ap.logger.info(
|
||||
f"Created Milvus collection '{collection}' with dimension={vector_size}, index=AUTOINDEX"
|
||||
)
|
||||
|
||||
self.ap.logger.info(f"Created Milvus collection '{collection}' with index")
|
||||
else:
|
||||
# Ensure index exists for existing collection
|
||||
await self._ensure_index_if_missing(collection)
|
||||
self.ap.logger.info(f"Milvus collection '{collection}' already exists")
|
||||
|
||||
self._collections[collection] = collection
|
||||
self._collections.add(collection)
|
||||
return collection
|
||||
|
||||
async def _ensure_index_if_missing(self, collection: str) -> None:
|
||||
"""Check if index exists for collection and create if missing.
|
||||
|
||||
Args:
|
||||
collection: Normalized collection name
|
||||
"""
|
||||
try:
|
||||
indexes = await asyncio.to_thread(self.client.list_indexes, collection_name=collection)
|
||||
if 'vector' not in indexes:
|
||||
await self._ensure_vector_index(collection)
|
||||
self.ap.logger.info(f"Created index for existing Milvus collection '{collection}'")
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(f"Could not verify/create index for collection '{collection}': {e}")
|
||||
|
||||
async def get_or_create_collection(self, collection: str):
|
||||
"""Get or create a Milvus collection (without vector size - will use default).
|
||||
|
||||
Args:
|
||||
collection: Collection name (corresponds to knowledge base UUID)
|
||||
"""
|
||||
collection = self._normalize_collection_name(collection)
|
||||
return await self._get_or_create_collection_internal(collection)
|
||||
|
||||
async def add_embeddings(
|
||||
self,
|
||||
collection: str,
|
||||
@@ -107,45 +164,43 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
embeddings_list: List of embedding vectors
|
||||
metadatas: List of metadata dictionaries for each vector
|
||||
"""
|
||||
await self.get_or_create_collection(collection)
|
||||
collection = self._normalize_collection_name(collection)
|
||||
|
||||
if not embeddings_list:
|
||||
return
|
||||
|
||||
# Ensure collection exists with correct dimension
|
||||
vector_size = len(embeddings_list[0])
|
||||
await self._get_or_create_collection_internal(collection, vector_size)
|
||||
|
||||
# Prepare data in Milvus format
|
||||
data = []
|
||||
for i, vector_id in enumerate(ids):
|
||||
entry = {
|
||||
"id": vector_id,
|
||||
"vector": embeddings_list[i],
|
||||
'id': vector_id,
|
||||
'vector': embeddings_list[i],
|
||||
}
|
||||
# Add metadata fields
|
||||
if metadatas and i < len(metadatas):
|
||||
metadata = metadatas[i]
|
||||
# Add common metadata fields
|
||||
if "text" in metadata:
|
||||
entry["text"] = metadata["text"]
|
||||
if "file_id" in metadata:
|
||||
entry["file_id"] = metadata["file_id"]
|
||||
if "uuid" in metadata:
|
||||
entry["chunk_uuid"] = metadata["uuid"]
|
||||
if 'text' in metadata:
|
||||
entry['text'] = metadata['text']
|
||||
if 'file_id' in metadata:
|
||||
entry['file_id'] = metadata['file_id']
|
||||
if 'uuid' in metadata:
|
||||
entry['chunk_uuid'] = metadata['uuid']
|
||||
data.append(entry)
|
||||
|
||||
# Insert data into Milvus
|
||||
await asyncio.to_thread(
|
||||
self.client.insert,
|
||||
collection_name=collection,
|
||||
data=data
|
||||
)
|
||||
await asyncio.to_thread(self.client.insert, collection_name=collection, data=data)
|
||||
|
||||
# Load collection for searching (Milvus requires this)
|
||||
await asyncio.to_thread(
|
||||
self.client.load_collection,
|
||||
collection_name=collection
|
||||
)
|
||||
await asyncio.to_thread(self.client.load_collection, collection_name=collection)
|
||||
|
||||
self.ap.logger.info(f"Added {len(ids)} embeddings to Milvus collection '{collection}'")
|
||||
|
||||
async def search(
|
||||
self, collection: str, query_embedding: list[float], k: int = 5
|
||||
) -> Dict[str, Any]:
|
||||
async def search(self, collection: str, query_embedding: list[float], k: int = 5) -> Dict[str, Any]:
|
||||
"""Search for similar vectors in Milvus collection
|
||||
|
||||
Args:
|
||||
@@ -156,13 +211,11 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
Returns:
|
||||
Dictionary with search results in Chroma-compatible format
|
||||
"""
|
||||
collection = self._normalize_collection_name(collection)
|
||||
await self.get_or_create_collection(collection)
|
||||
|
||||
# Perform search
|
||||
search_params = {
|
||||
"metric_type": "COSINE",
|
||||
"params": {}
|
||||
}
|
||||
search_params = {'metric_type': 'COSINE', 'params': {}}
|
||||
|
||||
results = await asyncio.to_thread(
|
||||
self.client.search,
|
||||
@@ -170,7 +223,7 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
data=[query_embedding],
|
||||
limit=k,
|
||||
search_params=search_params,
|
||||
output_fields=["text", "file_id", "chunk_uuid"]
|
||||
output_fields=['text', 'file_id', 'chunk_uuid'],
|
||||
)
|
||||
|
||||
# Convert results to Chroma-compatible format
|
||||
@@ -181,30 +234,24 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
|
||||
if results and len(results) > 0:
|
||||
for hit in results[0]:
|
||||
ids.append(hit.get("id", ""))
|
||||
distances.append(hit.get("distance", 0.0))
|
||||
ids.append(hit.get('id', ''))
|
||||
distances.append(hit.get('distance', 0.0))
|
||||
|
||||
# Build metadata from entity fields
|
||||
entity = hit.get("entity", {})
|
||||
entity = hit.get('entity', {})
|
||||
metadata = {}
|
||||
if "text" in entity:
|
||||
metadata["text"] = entity["text"]
|
||||
if "file_id" in entity:
|
||||
metadata["file_id"] = entity["file_id"]
|
||||
if "chunk_uuid" in entity:
|
||||
metadata["uuid"] = entity["chunk_uuid"]
|
||||
if 'text' in entity:
|
||||
metadata['text'] = entity['text']
|
||||
if 'file_id' in entity:
|
||||
metadata['file_id'] = entity['file_id']
|
||||
if 'chunk_uuid' in entity:
|
||||
metadata['uuid'] = entity['chunk_uuid']
|
||||
metadatas.append(metadata)
|
||||
|
||||
# Return in Chroma-compatible format (nested lists)
|
||||
result = {
|
||||
"ids": [ids],
|
||||
"distances": [distances],
|
||||
"metadatas": [metadatas]
|
||||
}
|
||||
result = {'ids': [ids], 'distances': [distances], 'metadatas': [metadatas]}
|
||||
|
||||
self.ap.logger.info(
|
||||
f"Milvus search in '{collection}' returned {len(ids)} results"
|
||||
)
|
||||
self.ap.logger.info(f"Milvus search in '{collection}' returned {len(ids)} results")
|
||||
return result
|
||||
|
||||
async def delete_by_file_id(self, collection: str, file_id: str) -> None:
|
||||
@@ -214,17 +261,12 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
collection: Collection name
|
||||
file_id: File ID to filter deletion
|
||||
"""
|
||||
collection = self._normalize_collection_name(collection)
|
||||
await self.get_or_create_collection(collection)
|
||||
|
||||
# Delete entities matching the file_id
|
||||
await asyncio.to_thread(
|
||||
self.client.delete,
|
||||
collection_name=collection,
|
||||
filter=f'file_id == "{file_id}"'
|
||||
)
|
||||
self.ap.logger.info(
|
||||
f"Deleted embeddings from Milvus collection '{collection}' with file_id: {file_id}"
|
||||
)
|
||||
await asyncio.to_thread(self.client.delete, collection_name=collection, filter=f'file_id == "{file_id}"')
|
||||
self.ap.logger.info(f"Deleted embeddings from Milvus collection '{collection}' with file_id: {file_id}")
|
||||
|
||||
async def delete_collection(self, collection: str):
|
||||
"""Delete a Milvus collection
|
||||
@@ -232,18 +274,15 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
Args:
|
||||
collection: Collection name to delete
|
||||
"""
|
||||
if collection in self._collections:
|
||||
del self._collections[collection]
|
||||
collection = self._normalize_collection_name(collection)
|
||||
|
||||
self._collections.discard(collection)
|
||||
|
||||
# Check if collection exists before attempting deletion
|
||||
has_collection = await asyncio.to_thread(
|
||||
self.client.has_collection, collection_name=collection
|
||||
)
|
||||
has_collection = await asyncio.to_thread(self.client.has_collection, collection_name=collection)
|
||||
|
||||
if has_collection:
|
||||
await asyncio.to_thread(
|
||||
self.client.drop_collection, collection_name=collection
|
||||
)
|
||||
await asyncio.to_thread(self.client.drop_collection, collection_name=collection)
|
||||
self.ap.logger.info(f"Deleted Milvus collection '{collection}'")
|
||||
else:
|
||||
self.ap.logger.warning(f"Milvus collection '{collection}' not found")
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
from typing import Any, Dict
|
||||
from sqlalchemy import create_engine, text, Column, String, Text
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker, Session
|
||||
from sqlalchemy.orm import declarative_base
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from langbot.pkg.vector.vdb import VectorDatabase
|
||||
from langbot.pkg.core import app
|
||||
import uuid
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class PgVectorEntry(Base):
|
||||
"""SQLAlchemy model for pgvector entries"""
|
||||
|
||||
__tablename__ = 'langbot_vectors'
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
@@ -31,11 +30,11 @@ class PgVectorDatabase(VectorDatabase):
|
||||
self,
|
||||
ap: app.Application,
|
||||
connection_string: str = None,
|
||||
host: str = "localhost",
|
||||
host: str = 'localhost',
|
||||
port: int = 5432,
|
||||
database: str = "langbot",
|
||||
user: str = "postgres",
|
||||
password: str = "postgres"
|
||||
database: str = 'langbot',
|
||||
user: str = 'postgres',
|
||||
password: str = 'postgres',
|
||||
):
|
||||
"""Initialize pgvector database
|
||||
|
||||
@@ -54,14 +53,10 @@ class PgVectorDatabase(VectorDatabase):
|
||||
if connection_string:
|
||||
self.connection_string = connection_string
|
||||
else:
|
||||
self.connection_string = (
|
||||
f"postgresql+psycopg://{user}:{password}@{host}:{port}/{database}"
|
||||
)
|
||||
self.connection_string = f'postgresql+psycopg://{user}:{password}@{host}:{port}/{database}'
|
||||
|
||||
self.async_connection_string = self.connection_string.replace(
|
||||
"postgresql://", "postgresql+asyncpg://"
|
||||
).replace(
|
||||
"postgresql+psycopg://", "postgresql+asyncpg://"
|
||||
self.async_connection_string = self.connection_string.replace('postgresql://', 'postgresql+asyncpg://').replace(
|
||||
'postgresql+psycopg://', 'postgresql+asyncpg://'
|
||||
)
|
||||
|
||||
self.engine = None
|
||||
@@ -75,35 +70,25 @@ class PgVectorDatabase(VectorDatabase):
|
||||
"""Initialize database connection and create tables"""
|
||||
try:
|
||||
# Create async engine for async operations
|
||||
self.async_engine = create_async_engine(
|
||||
self.async_connection_string,
|
||||
echo=False,
|
||||
pool_pre_ping=True
|
||||
)
|
||||
self.AsyncSessionLocal = async_sessionmaker(
|
||||
self.async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False
|
||||
)
|
||||
self.async_engine = create_async_engine(self.async_connection_string, echo=False, pool_pre_ping=True)
|
||||
self.AsyncSessionLocal = async_sessionmaker(self.async_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
# Create sync engine for table creation
|
||||
sync_connection_string = self.connection_string.replace(
|
||||
"postgresql+asyncpg://", "postgresql+psycopg://"
|
||||
)
|
||||
sync_connection_string = self.connection_string.replace('postgresql+asyncpg://', 'postgresql+psycopg://')
|
||||
self.engine = create_engine(sync_connection_string, echo=False)
|
||||
|
||||
# Create pgvector extension and tables
|
||||
with self.engine.connect() as conn:
|
||||
# Enable pgvector extension
|
||||
conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
|
||||
conn.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
|
||||
conn.commit()
|
||||
|
||||
# Create tables
|
||||
Base.metadata.create_all(self.engine)
|
||||
|
||||
self.ap.logger.info(f"Connected to PostgreSQL with pgvector")
|
||||
self.ap.logger.info('Connected to PostgreSQL with pgvector')
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f"Failed to connect to PostgreSQL: {e}")
|
||||
self.ap.logger.error(f'Failed to connect to PostgreSQL: {e}')
|
||||
raise
|
||||
|
||||
async def get_or_create_collection(self, collection: str):
|
||||
@@ -144,24 +129,20 @@ class PgVectorDatabase(VectorDatabase):
|
||||
id=vector_id,
|
||||
collection=collection,
|
||||
embedding=embeddings_list[i],
|
||||
text=metadata.get("text", ""),
|
||||
file_id=metadata.get("file_id", ""),
|
||||
chunk_uuid=metadata.get("uuid", "")
|
||||
text=metadata.get('text', ''),
|
||||
file_id=metadata.get('file_id', ''),
|
||||
chunk_uuid=metadata.get('uuid', ''),
|
||||
)
|
||||
session.add(entry)
|
||||
|
||||
await session.commit()
|
||||
self.ap.logger.info(
|
||||
f"Added {len(ids)} embeddings to pgvector collection '{collection}'"
|
||||
)
|
||||
self.ap.logger.info(f"Added {len(ids)} embeddings to pgvector collection '{collection}'")
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
self.ap.logger.error(f"Error adding embeddings to pgvector: {e}")
|
||||
self.ap.logger.error(f'Error adding embeddings to pgvector: {e}')
|
||||
raise
|
||||
|
||||
async def search(
|
||||
self, collection: str, query_embedding: list[float], k: int = 5
|
||||
) -> Dict[str, Any]:
|
||||
async def search(self, collection: str, query_embedding: list[float], k: int = 5) -> Dict[str, Any]:
|
||||
"""Search for similar vectors using cosine distance
|
||||
|
||||
Args:
|
||||
@@ -177,7 +158,7 @@ class PgVectorDatabase(VectorDatabase):
|
||||
async with self.AsyncSessionLocal() as session:
|
||||
try:
|
||||
# Use cosine distance for similarity search
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy import select
|
||||
|
||||
# Query for similar vectors
|
||||
stmt = (
|
||||
@@ -186,7 +167,7 @@ class PgVectorDatabase(VectorDatabase):
|
||||
PgVectorEntry.text,
|
||||
PgVectorEntry.file_id,
|
||||
PgVectorEntry.chunk_uuid,
|
||||
PgVectorEntry.embedding.cosine_distance(query_embedding).label('distance')
|
||||
PgVectorEntry.embedding.cosine_distance(query_embedding).label('distance'),
|
||||
)
|
||||
.filter(PgVectorEntry.collection == collection)
|
||||
.order_by(PgVectorEntry.embedding.cosine_distance(query_embedding))
|
||||
@@ -204,25 +185,17 @@ class PgVectorDatabase(VectorDatabase):
|
||||
for row in rows:
|
||||
ids.append(row.id)
|
||||
distances.append(float(row.distance))
|
||||
metadatas.append({
|
||||
"text": row.text or "",
|
||||
"file_id": row.file_id or "",
|
||||
"uuid": row.chunk_uuid or ""
|
||||
})
|
||||
metadatas.append(
|
||||
{'text': row.text or '', 'file_id': row.file_id or '', 'uuid': row.chunk_uuid or ''}
|
||||
)
|
||||
|
||||
result_dict = {
|
||||
"ids": [ids],
|
||||
"distances": [distances],
|
||||
"metadatas": [metadatas]
|
||||
}
|
||||
result_dict = {'ids': [ids], 'distances': [distances], 'metadatas': [metadatas]}
|
||||
|
||||
self.ap.logger.info(
|
||||
f"pgvector search in '{collection}' returned {len(ids)} results"
|
||||
)
|
||||
self.ap.logger.info(f"pgvector search in '{collection}' returned {len(ids)} results")
|
||||
return result_dict
|
||||
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f"Error searching pgvector: {e}")
|
||||
self.ap.logger.error(f'Error searching pgvector: {e}')
|
||||
raise
|
||||
|
||||
async def delete_by_file_id(self, collection: str, file_id: str) -> None:
|
||||
@@ -239,8 +212,7 @@ class PgVectorDatabase(VectorDatabase):
|
||||
from sqlalchemy import delete
|
||||
|
||||
stmt = delete(PgVectorEntry).where(
|
||||
PgVectorEntry.collection == collection,
|
||||
PgVectorEntry.file_id == file_id
|
||||
PgVectorEntry.collection == collection, PgVectorEntry.file_id == file_id
|
||||
)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
@@ -250,7 +222,7 @@ class PgVectorDatabase(VectorDatabase):
|
||||
)
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
self.ap.logger.error(f"Error deleting from pgvector: {e}")
|
||||
self.ap.logger.error(f'Error deleting from pgvector: {e}')
|
||||
raise
|
||||
|
||||
async def delete_collection(self, collection: str):
|
||||
@@ -266,16 +238,14 @@ class PgVectorDatabase(VectorDatabase):
|
||||
try:
|
||||
from sqlalchemy import delete
|
||||
|
||||
stmt = delete(PgVectorEntry).where(
|
||||
PgVectorEntry.collection == collection
|
||||
)
|
||||
stmt = delete(PgVectorEntry).where(PgVectorEntry.collection == collection)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
self.ap.logger.info(f"Deleted pgvector collection '{collection}'")
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
self.ap.logger.error(f"Error deleting pgvector collection: {e}")
|
||||
self.ap.logger.error(f'Error deleting pgvector collection: {e}')
|
||||
raise
|
||||
|
||||
async def close(self):
|
||||
|
||||
@@ -3,10 +3,8 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
from langbot.pkg.core import app
|
||||
from langbot.pkg.entity.persistence import model as persistence_model
|
||||
from langbot.pkg.vector.vdb import VectorDatabase
|
||||
|
||||
try:
|
||||
@@ -87,14 +85,16 @@ class SeekDBVectorDatabase(VectorDatabase):
|
||||
self._collections: Dict[str, Any] = {}
|
||||
self._collection_configs: Dict[str, HNSWConfiguration] = {}
|
||||
|
||||
self._escape_table = str.maketrans({
|
||||
'\x00': '',
|
||||
'\\': '\\\\',
|
||||
'"': '\\"',
|
||||
'\n': '\\n',
|
||||
'\r': '\\r',
|
||||
'\t': '\\t',
|
||||
})
|
||||
self._escape_table = str.maketrans(
|
||||
{
|
||||
'\x00': '',
|
||||
'\\': '\\\\',
|
||||
'"': '\\"',
|
||||
'\n': '\\n',
|
||||
'\r': '\\r',
|
||||
'\t': '\\t',
|
||||
}
|
||||
)
|
||||
|
||||
async def _get_or_create_collection_internal(self, collection: str, vector_size: int = None) -> Any:
|
||||
"""Internal method to get or create a collection with proper configuration."""
|
||||
@@ -133,8 +133,10 @@ class SeekDBVectorDatabase(VectorDatabase):
|
||||
def _clean_metadata(self, meta: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""SeekDB metadata doesn't support \\ and ", insert will error 3104"""
|
||||
return {
|
||||
k: v.translate(self._escape_table) if isinstance(v, str)
|
||||
else v if v is None or isinstance(v, (int, float, bool))
|
||||
k: v.translate(self._escape_table)
|
||||
if isinstance(v, str)
|
||||
else v
|
||||
if v is None or isinstance(v, (int, float, bool))
|
||||
else str(v)
|
||||
for k, v in meta.items()
|
||||
if v is not None
|
||||
@@ -145,11 +147,7 @@ class SeekDBVectorDatabase(VectorDatabase):
|
||||
return await self._get_or_create_collection_internal(collection)
|
||||
|
||||
async def add_embeddings(
|
||||
self,
|
||||
collection: str,
|
||||
ids: List[str],
|
||||
embeddings_list: List[List[float]],
|
||||
metadatas: List[Dict[str, Any]]
|
||||
self, collection: str, ids: List[str], embeddings_list: List[List[float]], metadatas: List[Dict[str, Any]]
|
||||
) -> None:
|
||||
"""Add vector embeddings to the specified collection.
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ proxy:
|
||||
https: ''
|
||||
system:
|
||||
recovery_key: ''
|
||||
allow_change_password: true
|
||||
allow_modify_login_info: true
|
||||
jwt:
|
||||
expire: 604800
|
||||
secret: ''
|
||||
@@ -51,6 +51,7 @@ vdb:
|
||||
milvus:
|
||||
uri: 'http://127.0.0.1:19530'
|
||||
token: ''
|
||||
db_name: ''
|
||||
pgvector:
|
||||
host: '127.0.0.1'
|
||||
port: 5433
|
||||
@@ -69,5 +70,13 @@ plugin:
|
||||
enable: true
|
||||
runtime_ws_url: 'ws://langbot_plugin_runtime:5400/control/ws'
|
||||
enable_marketplace: true
|
||||
cloud_service_url: 'https://space.langbot.app'
|
||||
display_plugin_debug_url: 'ws://localhost:5401/plugin/debug/ws'
|
||||
space:
|
||||
# Space service URL for OAuth and API
|
||||
url: 'https://space.langbot.app'
|
||||
# Space API URL for model requests (MaaS)
|
||||
models_gateway_api_url: 'https://api.langbot.cloud/v1'
|
||||
# OAuth authorization page URL (user will be redirected here)
|
||||
oauth_authorize_url: 'https://space.langbot.app/auth/authorize'
|
||||
disable_models_service: false
|
||||
disable_telemetry: false
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user