diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index f3389c25..71ef28fc 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -9,7 +9,7 @@ *请在方括号间写`x`以打勾 / Please tick the box with `x`* -- [ ] 阅读仓库[贡献指引](https://github.com/RockChinQ/LangBot/blob/master/CONTRIBUTING.md)了吗? / Have you read the [contribution guide](https://github.com/RockChinQ/LangBot/blob/master/CONTRIBUTING.md)? +- [ ] 阅读仓库[贡献指引](https://github.com/langbot-app/LangBot/blob/master/CONTRIBUTING.md)了吗? / Have you read the [contribution guide](https://github.com/langbot-app/LangBot/blob/master/CONTRIBUTING.md)? - [ ] 与项目所有者沟通过了吗? / Have you communicated with the project maintainer? - [ ] 我确定已自行测试所作的更改,确保功能符合预期。 / I have tested the changes and ensured they work as expected. diff --git a/README.md b/README.md index 6e0fa350..11062524 100644 --- a/README.md +++ b/README.md @@ -1,50 +1,38 @@

-LangBot +LangBot

-RockChinQ%2FLangBot | Trendshift +简体中文 / [English](README_EN.md) / [日本語](README_JP.md) / (PR for your language) + +[![Discord](https://img.shields.io/discord/1335141740050649118?logo=discord&labelColor=%20%235462eb&logoColor=%20%23f5f5f5&color=%20%235462eb)](https://discord.gg/wdNEHETs87) +[![QQ Group](https://img.shields.io/badge/%E7%A4%BE%E5%8C%BAQQ%E7%BE%A4-966235608-blue)](https://qm.qq.com/q/JLi38whHum) +[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/langbot-app/LangBot) +[![GitHub release (latest by date)](https://img.shields.io/github/v/release/langbot-app/LangBot)](https://github.com/langbot-app/LangBot/releases/latest) +python +[![star](https://gitcode.com/RockChinQ/LangBot/star/badge.svg)](https://gitcode.com/RockChinQ/LangBot) 项目主页部署文档插件介绍 | -提交插件 +提交插件 -
-😎高稳定、🧩支持扩展、🦄多模态 - 大模型原生即时通信机器人平台🤖 -
- -
- -[![Discord](https://img.shields.io/discord/1335141740050649118?logo=discord&labelColor=%20%235462eb&logoColor=%20%23f5f5f5&color=%20%235462eb)](https://discord.gg/wdNEHETs87) -[![QQ Group](https://img.shields.io/badge/%E7%A4%BE%E5%8C%BAQQ%E7%BE%A4-966235608-blue)](https://qm.qq.com/q/JLi38whHum) -[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/RockChinQ/LangBot) -[![GitHub release (latest by date)](https://img.shields.io/github/v/release/RockChinQ/LangBot)](https://github.com/RockChinQ/LangBot/releases/latest) -python -[![star](https://gitcode.com/RockChinQ/LangBot/star/badge.svg)](https://gitcode.com/RockChinQ/LangBot) - -简体中文 / [English](README_EN.md) / [日本語](README_JP.md) / (PR for your language)

-## ✨ 特性 - -- 💬 大模型对话、Agent:支持多种大模型,适配群聊和私聊;具有多轮对话、工具调用、多模态能力,并深度适配 [Dify](https://dify.ai)。目前支持 QQ、QQ频道、企业微信、个人微信、飞书、Discord、Telegram 等平台。 -- 🛠️ 高稳定性、功能完备:原生支持访问控制、限速、敏感词过滤等机制;配置简单,支持多种部署方式。支持多流水线配置,不同机器人用于不同应用场景。 -- 🧩 插件扩展、活跃社区:支持事件驱动、组件扩展等插件机制;适配 Anthropic [MCP 协议](https://modelcontextprotocol.io/);目前已有数百个插件。 -- 😻 Web 管理面板:支持通过浏览器管理 LangBot 实例,不再需要手动编写配置文件。 +LangBot 是一个开源的大语言模型原生即时通信机器人开发平台,旨在提供开箱即用的 IM 机器人开发体验,具有 Agent、RAG、MCP 等多种 LLM 应用功能,适配全球主流即时通信平台,并提供丰富的 API 接口,支持自定义开发。 ## 📦 开始使用 #### Docker Compose 部署 ```bash -git clone https://github.com/RockChinQ/LangBot +git clone https://github.com/langbot-app/LangBot cd LangBot docker compose up -d ``` @@ -71,23 +59,25 @@ docker compose up -d 直接使用发行版运行,查看文档[手动部署](https://docs.langbot.app/zh/deploy/langbot/manual.html)。 -## 📸 效果展示 +## 😎 保持更新 -bots +点击仓库右上角 Star 和 Watch 按钮,获取最新动态。 -bots +![star gif](https://docs.langbot.app/star.gif) -bots +## ✨ 特性 -bots +- 💬 大模型对话、Agent:支持多种大模型,适配群聊和私聊;具有多轮对话、工具调用、多模态能力,自带 RAG(知识库)实现,并深度适配 [Dify](https://dify.ai)。 +- 🤖 多平台支持:目前支持 QQ、QQ频道、企业微信、个人微信、飞书、Discord、Telegram 等平台。 +- 🛠️ 高稳定性、功能完备:原生支持访问控制、限速、敏感词过滤等机制;配置简单,支持多种部署方式。支持多流水线配置,不同机器人用于不同应用场景。 +- 🧩 插件扩展、活跃社区:支持事件驱动、组件扩展等插件机制;适配 Anthropic [MCP 协议](https://modelcontextprotocol.io/);目前已有数百个插件。 +- 😻 Web 管理面板:支持通过浏览器管理 LangBot 实例,不再需要手动编写配置文件。 -回复效果(带有联网插件) +详细规格特性请访问[文档](https://docs.langbot.app/zh/insight/features.html)。 -- WebUI Demo: https://demo.langbot.dev/ - - 登录信息:邮箱:`demo@langbot.app` 密码:`langbot123456` - - 注意:仅展示webui效果,公开环境,请不要在其中填入您的任何敏感信息。 - -## 🔌 组件兼容性 +或访问 demo 环境:https://demo.langbot.dev/ + - 登录信息:邮箱:`demo@langbot.app` 密码:`langbot123456` + - 注意:仅展示 WebUI 效果,公开环境,请不要在其中填入您的任何敏感信息。 ### 消息平台 @@ -104,10 +94,6 @@ docker compose up -d | Discord | ✅ | | | Telegram | ✅ | | | Slack | ✅ | | -| LINE | 🚧 | | -| WhatsApp | 🚧 | | - -🚧: 正在开发中 ### 大模型能力 @@ -149,14 +135,8 @@ docker compose up -d ## 😘 社区贡献 -感谢以下[代码贡献者](https://github.com/RockChinQ/LangBot/graphs/contributors)和社区里其他成员对 LangBot 的贡献: 
+感谢以下[代码贡献者](https://github.com/langbot-app/LangBot/graphs/contributors)和社区里其他成员对 LangBot 的贡献: - - + + - -## 😎 保持更新 - -点击仓库右上角 Star 和 Watch 按钮,获取最新动态。 - -![star gif](https://docs.langbot.app/star.gif) diff --git a/README_EN.md b/README_EN.md index 07667f84..b3d2b761 100644 --- a/README_EN.md +++ b/README_EN.md @@ -1,48 +1,34 @@

-LangBot +LangBot

-RockChinQ%2FLangBot | Trendshift +[简体中文](README.md) / English / [日本語](README_JP.md) / (PR for your language) + +[![Discord](https://img.shields.io/discord/1335141740050649118?logo=discord&labelColor=%20%235462eb&logoColor=%20%23f5f5f5&color=%20%235462eb)](https://discord.gg/wdNEHETs87) +[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/langbot-app/LangBot) +[![GitHub release (latest by date)](https://img.shields.io/github/v/release/langbot-app/LangBot)](https://github.com/langbot-app/LangBot/releases/latest) +python HomeDeploymentPlugin | -Submit Plugin - -
-😎High Stability, 🧩Extension Supported, 🦄Multi-modal - LLM Native Instant Messaging Bot Platform🤖 -
- -
- - -[![Discord](https://img.shields.io/discord/1335141740050649118?logo=discord&labelColor=%20%235462eb&logoColor=%20%23f5f5f5&color=%20%235462eb)](https://discord.gg/wdNEHETs87) -[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/RockChinQ/LangBot) -[![GitHub release (latest by date)](https://img.shields.io/github/v/release/RockChinQ/LangBot)](https://github.com/RockChinQ/LangBot/releases/latest) -python - -[简体中文](README.md) / English / [日本語](README_JP.md) / (PR for your language) +Submit Plugin

-## ✨ Features - -- 💬 Chat with LLM / Agent: Supports multiple LLMs, adapt to group chats and private chats; Supports multi-round conversations, tool calls, and multi-modal capabilities. Deeply integrates with [Dify](https://dify.ai). Currently supports QQ, QQ Channel, WeCom, personal WeChat, Lark, DingTalk, Discord, Telegram, etc. -- 🛠️ High Stability, Feature-rich: Native access control, rate limiting, sensitive word filtering, etc. mechanisms; Easy to use, supports multiple deployment methods. Supports multiple pipeline configurations, different bots can be used for different scenarios. -- 🧩 Plugin Extension, Active Community: Support event-driven, component extension, etc. plugin mechanisms; Integrate Anthropic [MCP protocol](https://modelcontextprotocol.io/); Currently has hundreds of plugins. -- 😻 [New] Web UI: Support management LangBot instance through the browser. No need to manually write configuration files. +LangBot is an open-source LLM native instant messaging robot development platform, aiming to provide out-of-the-box IM robot development experience, with Agent, RAG, MCP and other LLM application functions, adapting to global instant messaging platforms, and providing rich API interfaces, supporting custom development. ## 📦 Getting Started #### Docker Compose Deployment ```bash -git clone https://github.com/RockChinQ/LangBot +git clone https://github.com/langbot-app/LangBot cd LangBot docker compose up -d ``` @@ -69,23 +55,25 @@ Community contributed Zeabur template. Directly use the released version to run, see the [Manual Deployment](https://docs.langbot.app/en/deploy/langbot/manual.html) documentation. -## 📸 Demo +## 😎 Stay Ahead -bots +Click the Star and Watch button in the upper right corner of the repository to get the latest updates. 
-bots +![star gif](https://docs.langbot.app/star.gif) -bots +## ✨ Features -bots +- 💬 Chat with LLM / Agent: Supports multiple LLMs, adapt to group chats and private chats; Supports multi-round conversations, tool calls, and multi-modal capabilities. Built-in RAG (knowledge base) implementation, and deeply integrates with [Dify](https://dify.ai). +- 🤖 Multi-platform Support: Currently supports QQ, QQ Channel, WeCom, personal WeChat, Lark, DingTalk, Discord, Telegram, etc. +- 🛠️ High Stability, Feature-rich: Native access control, rate limiting, sensitive word filtering, etc. mechanisms; Easy to use, supports multiple deployment methods. Supports multiple pipeline configurations, different bots can be used for different scenarios. +- 🧩 Plugin Extension, Active Community: Support event-driven, component extension, etc. plugin mechanisms; Integrate Anthropic [MCP protocol](https://modelcontextprotocol.io/); Currently has hundreds of plugins. +- 😻 Web UI: Support management LangBot instance through the browser. No need to manually write configuration files. -Reply Effect (with Internet Plugin) +For more detailed specifications, please refer to the [documentation](https://docs.langbot.app/en/insight/features.html). -- WebUI Demo: https://demo.langbot.dev/ - - Login information: Email: `demo@langbot.app` Password: `langbot123456` - - Note: Only the WebUI effect is shown, please do not fill in any sensitive information in the public environment. - -## 🔌 Component Compatibility +Or visit the demo environment: https://demo.langbot.dev/ + - Login information: Email: `demo@langbot.app` Password: `langbot123456` + - Note: For WebUI demo only, please do not fill in any sensitive information in the public environment. 
### Message Platform @@ -101,10 +89,6 @@ Directly use the released version to run, see the [Manual Deployment](https://do | Discord | ✅ | | | Telegram | ✅ | | | Slack | ✅ | | -| LINE | 🚧 | | -| WhatsApp | 🚧 | | - -🚧: In development ### LLMs @@ -132,14 +116,8 @@ Directly use the released version to run, see the [Manual Deployment](https://do ## 🤝 Community Contribution -Thank you for the following [code contributors](https://github.com/RockChinQ/LangBot/graphs/contributors) and other members in the community for their contributions to LangBot: +Thank you for the following [code contributors](https://github.com/langbot-app/LangBot/graphs/contributors) and other members in the community for their contributions to LangBot: - - + + - -## 😎 Stay Ahead - -Click the Star and Watch button in the upper right corner of the repository to get the latest updates. - -![star gif](https://docs.langbot.app/star.gif) \ No newline at end of file diff --git a/README_JP.md b/README_JP.md index c54ce51b..5cb629d8 100644 --- a/README_JP.md +++ b/README_JP.md @@ -1,47 +1,34 @@

-LangBot +LangBot

-RockChinQ%2FLangBot | Trendshift +[简体中文](README.md) / [English](README_EN.md) / 日本語 / (PR for your language) + +[![Discord](https://img.shields.io/discord/1335141740050649118?logo=discord&labelColor=%20%235462eb&logoColor=%20%23f5f5f5&color=%20%235462eb)](https://discord.gg/wdNEHETs87) +[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/langbot-app/LangBot) +[![GitHub release (latest by date)](https://img.shields.io/github/v/release/langbot-app/LangBot)](https://github.com/langbot-app/LangBot/releases/latest) +python ホームデプロイプラグイン | -プラグインの提出 - -
-😎高い安定性、🧩拡張サポート、🦄マルチモーダル - LLMネイティブインスタントメッセージングボットプラットフォーム🤖 -
- -
- -[![Discord](https://img.shields.io/discord/1335141740050649118?logo=discord&labelColor=%20%235462eb&logoColor=%20%23f5f5f5&color=%20%235462eb)](https://discord.gg/wdNEHETs87) -[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/RockChinQ/LangBot) -[![GitHub release (latest by date)](https://img.shields.io/github/v/release/RockChinQ/LangBot)](https://github.com/RockChinQ/LangBot/releases/latest) -python - -[简体中文](README_CN.md) / [English](README.md) / [日本語](README_JP.md) / (PR for your language) +プラグインの提出

-## ✨ 機能 - -- 💬 LLM / エージェントとのチャット: 複数のLLMをサポートし、グループチャットとプライベートチャットに対応。マルチラウンドの会話、ツールの呼び出し、マルチモーダル機能をサポート。 [Dify](https://dify.ai) と深く統合。現在、QQ、QQ チャンネル、WeChat、個人 WeChat、Lark、DingTalk、Discord、Telegram など、複数のプラットフォームをサポートしています。 -- 🛠️ 高い安定性、豊富な機能: ネイティブのアクセス制御、レート制限、敏感な単語のフィルタリングなどのメカニズムをサポート。使いやすく、複数のデプロイ方法をサポート。複数のパイプライン設定をサポートし、異なるボットを異なる用途に使用できます。 -- 🧩 プラグイン拡張、活発なコミュニティ: イベント駆動、コンポーネント拡張などのプラグインメカニズムをサポート。適配 Anthropic [MCP プロトコル](https://modelcontextprotocol.io/);豊富なエコシステム、現在数百のプラグインが存在。 -- 😻 Web UI: ブラウザを通じてLangBotインスタンスを管理することをサポート。 +LangBot は、エージェント、RAG、MCP などの LLM アプリケーション機能を備えた、オープンソースの LLM ネイティブのインスタントメッセージングロボット開発プラットフォームです。世界中のインスタントメッセージングプラットフォームに適応し、豊富な API インターフェースを提供し、カスタム開発をサポートします。 ## 📦 始め方 #### Docker Compose デプロイ ```bash -git clone https://github.com/RockChinQ/LangBot +git clone https://github.com/langbot-app/LangBot cd LangBot docker compose up -d ``` @@ -50,7 +37,7 @@ http://localhost:5300 にアクセスして使用を開始します。 詳細なドキュメントは[Dockerデプロイ](https://docs.langbot.app/en/deploy/langbot/docker.html)を参照してください。 -#### BTPanelでのワンクリックデプロイ +#### Panelでのワンクリックデプロイ LangBotはBTPanelにリストされています。BTPanelをインストールしている場合は、[ドキュメント](https://docs.langbot.app/en/deploy/langbot/one-click/bt.html)を使用して使用できます。 @@ -68,23 +55,25 @@ LangBotはBTPanelにリストされています。BTPanelをインストール リリースバージョンを直接使用して実行します。[手動デプロイ](https://docs.langbot.app/en/deploy/langbot/manual.html)のドキュメントを参照してください。 -## 📸 デモ +## 😎 最新情報を入手 -bots +リポジトリの右上にある Star と Watch ボタンをクリックして、最新の更新を取得してください。 -bots +![star gif](https://docs.langbot.app/star.gif) -bots +## ✨ 機能 -bots +- 💬 LLM / エージェントとのチャット: 複数のLLMをサポートし、グループチャットとプライベートチャットに対応。マルチラウンドの会話、ツールの呼び出し、マルチモーダル機能をサポート、RAG(知識ベース)を組み込み、[Dify](https://dify.ai) と深く統合。 +- 🤖 多プラットフォーム対応: 現在、QQ、QQ チャンネル、WeChat、個人 WeChat、Lark、DingTalk、Discord、Telegram など、複数のプラットフォームをサポートしています。 +- 🛠️ 高い安定性、豊富な機能: ネイティブのアクセス制御、レート制限、敏感な単語のフィルタリングなどのメカニズムをサポート。使いやすく、複数のデプロイ方法をサポート。複数のパイプライン設定をサポートし、異なるボットを異なる用途に使用できます。 +- 🧩 プラグイン拡張、活発なコミュニティ: イベント駆動、コンポーネント拡張などのプラグインメカニズムをサポート。適配 Anthropic [MCP 
プロトコル](https://modelcontextprotocol.io/);豊富なエコシステム、現在数百のプラグインが存在。 +- 😻 Web UI: ブラウザを通じてLangBotインスタンスを管理することをサポート。 -返信効果(インターネットプラグイン付き) +詳細な仕様については、[ドキュメント](https://docs.langbot.app/en/insight/features.html)を参照してください。 -- WebUIデモ: https://demo.langbot.dev/ - - ログイン情報: メール: `demo@langbot.app` パスワード: `langbot123456` - - 注意: WebUIの効果のみを示しています。公開環境では、機密情報を入力しないでください。 - -## 🔌 コンポーネントの互換性 +または、デモ環境にアクセスしてください: https://demo.langbot.dev/ + - ログイン情報: メール: `demo@langbot.app` パスワード: `langbot123456` + - 注意: WebUI のデモンストレーションのみの場合、公開環境では機密情報を入力しないでください。 ### メッセージプラットフォーム @@ -100,10 +89,6 @@ LangBotはBTPanelにリストされています。BTPanelをインストール | Discord | ✅ | | | Telegram | ✅ | | | Slack | ✅ | | -| LINE | 🚧 | | -| WhatsApp | 🚧 | | - -🚧: 開発中 ### LLMs @@ -131,14 +116,8 @@ LangBotはBTPanelにリストされています。BTPanelをインストール ## 🤝 コミュニティ貢献 -LangBot への貢献に対して、以下の [コード貢献者](https://github.com/RockChinQ/LangBot/graphs/contributors) とコミュニティの他のメンバーに感謝します。 +LangBot への貢献に対して、以下の [コード貢献者](https://github.com/langbot-app/LangBot/graphs/contributors) とコミュニティの他のメンバーに感謝します。 - - + + - -## 😎 最新情報を入手 - -リポジトリの右上にある Star と Watch ボタンをクリックして、最新の更新を取得してください。 - -![star gif](https://docs.langbot.app/star.gif) \ No newline at end of file diff --git a/main.py b/main.py index 19cb32d6..1909e343 100644 --- a/main.py +++ b/main.py @@ -11,7 +11,7 @@ asciiart = r""" |____\__,_|_||_\__, |___/\___/\__| |___/ -⭐️ Open Source 开源地址: https://github.com/RockChinQ/LangBot +⭐️ Open Source 开源地址: https://github.com/langbot-app/LangBot 📖 Documentation 文档地址: https://docs.langbot.app """ diff --git a/pkg/api/http/controller/group.py b/pkg/api/http/controller/group.py index 3f34d79b..16fa1df1 100644 --- a/pkg/api/http/controller/group.py +++ b/pkg/api/http/controller/group.py @@ -11,7 +11,7 @@ from ....core import app preregistered_groups: list[type[RouterGroup]] = [] -"""RouterGroup 的预注册列表""" +"""Pre-registered list of RouterGroup""" def group_class(name: str, path: str) -> typing.Callable[[typing.Type[RouterGroup]], typing.Type[RouterGroup]]: @@ 
-27,7 +27,7 @@ def group_class(name: str, path: str) -> typing.Callable[[typing.Type[RouterGrou class AuthType(enum.Enum): - """认证类型""" + """Authentication type""" NONE = 'none' USER_TOKEN = 'user-token' @@ -56,7 +56,7 @@ class RouterGroup(abc.ABC): auth_type: AuthType = AuthType.USER_TOKEN, **options: typing.Any, ) -> typing.Callable[[RouteCallable], RouteCallable]: # decorator - """注册一个路由""" + """Register a route""" def decorator(f: RouteCallable) -> RouteCallable: nonlocal rule @@ -64,11 +64,11 @@ class RouterGroup(abc.ABC): async def handler_error(*args, **kwargs): if auth_type == AuthType.USER_TOKEN: - # 从Authorization头中获取token + # get token from Authorization header token = quart.request.headers.get('Authorization', '').replace('Bearer ', '') if not token: - return self.http_status(401, -1, '未提供有效的用户令牌') + return self.http_status(401, -1, 'No valid user token provided') try: user_email = await self.ap.user_service.verify_jwt_token(token) @@ -76,9 +76,9 @@ class RouterGroup(abc.ABC): # check if this account exists user = await self.ap.user_service.get_user_by_email(user_email) if not user: - return self.http_status(401, -1, '用户不存在') + return self.http_status(401, -1, 'User not found') - # 检查f是否接受user_email参数 + # check if f accepts user_email parameter if 'user_email' in f.__code__.co_varnames: kwargs['user_email'] = user_email except Exception as e: @@ -86,10 +86,11 @@ class RouterGroup(abc.ABC): try: return await f(*args, **kwargs) - except Exception: # 自动 500 + + except Exception as e: # 自动 500 traceback.print_exc() # return self.http_status(500, -2, str(e)) - return self.http_status(500, -2, 'internal server error') + return self.http_status(500, -2, str(e)) new_f = handler_error new_f.__name__ = (self.name + rule).replace('/', '__') @@ -101,7 +102,7 @@ class RouterGroup(abc.ABC): return decorator def success(self, data: typing.Any = None) -> quart.Response: - """返回一个 200 响应""" + """Return a 200 response""" return quart.jsonify( { 'code': 0, @@ -111,7 
+112,7 @@ class RouterGroup(abc.ABC): ) def fail(self, code: int, msg: str) -> quart.Response: - """返回一个异常响应""" + """Return an error response""" return quart.jsonify( { @@ -122,4 +123,4 @@ class RouterGroup(abc.ABC): def http_status(self, status: int, code: int, msg: str) -> typing.Tuple[quart.Response, int]: """返回一个指定状态码的响应""" - return (self.fail(code, msg), status) + return (self.fail(code, msg), status) \ No newline at end of file diff --git a/pkg/api/http/controller/groups/files.py b/pkg/api/http/controller/groups/files.py index d08cbd71..b3c1a3f1 100644 --- a/pkg/api/http/controller/groups/files.py +++ b/pkg/api/http/controller/groups/files.py @@ -34,8 +34,9 @@ class FilesRouterGroup(group.RouterGroup): file_bytes = await asyncio.to_thread(file.stream.read) extension = file.filename.split('.')[-1] + file_name = file.filename.split('.')[0] - file_key = str(uuid.uuid4()) + '.' + extension + file_key = file_name + '_' + str(uuid.uuid4())[:8] + '.' + extension # save file to storage await self.ap.storage_mgr.storage_provider.save(file_key, file_bytes) return self.success( diff --git a/pkg/api/http/controller/groups/knowledge/base.py b/pkg/api/http/controller/groups/knowledge/base.py index 866b4af2..a5bed5df 100644 --- a/pkg/api/http/controller/groups/knowledge/base.py +++ b/pkg/api/http/controller/groups/knowledge/base.py @@ -20,7 +20,7 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): @self.route( '/', - methods=['GET', 'DELETE'], + methods=['GET', 'DELETE', 'PUT'], ) async def handle_specific_knowledge_base(knowledge_base_uuid: str) -> quart.Response: if quart.request.method == 'GET': @@ -34,6 +34,12 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): 'base': knowledge_base, } ) + + elif quart.request.method == 'PUT': + json_data = await quart.request.json + await self.ap.knowledge_service.update_knowledge_base(knowledge_base_uuid, json_data) + return self.success({}) + elif quart.request.method == 'DELETE': await 
self.ap.knowledge_service.delete_knowledge_base(knowledge_base_uuid) return self.success({}) @@ -72,3 +78,13 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): async def delete_specific_file_in_kb(file_id: str, knowledge_base_uuid: str) -> str: await self.ap.knowledge_service.delete_file(knowledge_base_uuid, file_id) return self.success({}) + + @self.route( + '//retrieve', + methods=['POST'], + ) + async def retrieve_knowledge_base(knowledge_base_uuid: str) -> str: + json_data = await quart.request.json + query = json_data.get('query') + results = await self.ap.knowledge_service.retrieve_knowledge_base(knowledge_base_uuid, query) + return self.success(data={'results': results}) diff --git a/pkg/api/http/controller/groups/pipelines/pipelines.py b/pkg/api/http/controller/groups/pipelines/pipelines.py index 96ca239a..d056afb4 100644 --- a/pkg/api/http/controller/groups/pipelines/pipelines.py +++ b/pkg/api/http/controller/groups/pipelines/pipelines.py @@ -11,7 +11,9 @@ class PipelinesRouterGroup(group.RouterGroup): @self.route('', methods=['GET', 'POST']) async def _() -> str: if quart.request.method == 'GET': - return self.success(data={'pipelines': await self.ap.pipeline_service.get_pipelines()}) + sort_by = quart.request.args.get('sort_by', 'created_at') + sort_order = quart.request.args.get('sort_order', 'DESC') + return self.success(data={'pipelines': await self.ap.pipeline_service.get_pipelines(sort_by, sort_order)}) elif quart.request.method == 'POST': json_data = await quart.request.json diff --git a/pkg/api/http/controller/groups/pipelines/webchat.py b/pkg/api/http/controller/groups/pipelines/webchat.py index 005738db..c8c8db54 100644 --- a/pkg/api/http/controller/groups/pipelines/webchat.py +++ b/pkg/api/http/controller/groups/pipelines/webchat.py @@ -8,7 +8,7 @@ class WebChatDebugRouterGroup(group.RouterGroup): async def initialize(self) -> None: @self.route('/send', methods=['POST']) async def send_message(pipeline_uuid: str) -> str: - """发送调试消息到流水线""" + 
"""Send a message to the pipeline for debugging""" try: data = await quart.request.get_json() session_type = data.get('session_type', 'person') @@ -38,7 +38,7 @@ class WebChatDebugRouterGroup(group.RouterGroup): @self.route('/messages/', methods=['GET']) async def get_messages(pipeline_uuid: str, session_type: str) -> str: - """获取调试消息历史""" + """Get the message history of the pipeline for debugging""" try: if session_type not in ['person', 'group']: return self.http_status(400, -1, 'session_type must be person or group') @@ -57,7 +57,7 @@ class WebChatDebugRouterGroup(group.RouterGroup): @self.route('/reset/', methods=['POST']) async def reset_session(session_type: str) -> str: - """重置调试会话""" + """Reset the debug session""" try: if session_type not in ['person', 'group']: return self.http_status(400, -1, 'session_type must be person or group') diff --git a/pkg/api/http/controller/groups/plugins.py b/pkg/api/http/controller/groups/plugins.py index daf6ea7d..b7e0a5e9 100644 --- a/pkg/api/http/controller/groups/plugins.py +++ b/pkg/api/http/controller/groups/plugins.py @@ -40,7 +40,7 @@ class PluginsRouterGroup(group.RouterGroup): self.ap.plugin_mgr.update_plugin(plugin_name, task_context=ctx), kind='plugin-operation', name=f'plugin-update-{plugin_name}', - label=f'更新插件 {plugin_name}', + label=f'Updating plugin {plugin_name}', context=ctx, ) return self.success(data={'task_id': wrapper.id}) @@ -62,7 +62,7 @@ class PluginsRouterGroup(group.RouterGroup): self.ap.plugin_mgr.uninstall_plugin(plugin_name, task_context=ctx), kind='plugin-operation', name=f'plugin-remove-{plugin_name}', - label=f'删除插件 {plugin_name}', + label=f'Removing plugin {plugin_name}', context=ctx, ) @@ -102,7 +102,7 @@ class PluginsRouterGroup(group.RouterGroup): self.ap.plugin_mgr.install_plugin(data['source'], task_context=ctx), kind='plugin-operation', name='plugin-install-github', - label=f'安装插件 ...{short_source_str}', + label=f'Installing plugin ...{short_source_str}', context=ctx, ) diff --git 
a/pkg/api/http/controller/groups/user.py b/pkg/api/http/controller/groups/user.py index 3ad1335b..d8024107 100644 --- a/pkg/api/http/controller/groups/user.py +++ b/pkg/api/http/controller/groups/user.py @@ -14,7 +14,7 @@ class UserRouterGroup(group.RouterGroup): return self.success(data={'initialized': await self.ap.user_service.is_initialized()}) if await self.ap.user_service.is_initialized(): - return self.fail(1, '系统已初始化') + return self.fail(1, 'System already initialized') json_data = await quart.request.json @@ -32,7 +32,7 @@ class UserRouterGroup(group.RouterGroup): try: token = await self.ap.user_service.authenticate(json_data['user'], json_data['password']) except argon2.exceptions.VerifyMismatchError: - return self.fail(1, '用户名或密码错误') + return self.fail(1, 'Invalid username or password') return self.success(data={'token': token}) @@ -54,15 +54,15 @@ class UserRouterGroup(group.RouterGroup): await asyncio.sleep(3) if not await self.ap.user_service.is_initialized(): - return self.http_status(400, -1, 'system not initialized') + return self.http_status(400, -1, 'System not initialized') user_obj = await self.ap.user_service.get_user_by_email(user_email) if user_obj is None: - return self.http_status(400, -1, 'user not found') + return self.http_status(400, -1, 'User not found') if recovery_key != self.ap.instance_config.data['system']['recovery_key']: - return self.http_status(403, -1, 'invalid recovery key') + return self.http_status(403, -1, 'Invalid recovery key') await self.ap.user_service.reset_password(user_email, new_password) diff --git a/pkg/api/http/controller/main.py b/pkg/api/http/controller/main.py index 4eec4e1d..e45b461d 100644 --- a/pkg/api/http/controller/main.py +++ b/pkg/api/http/controller/main.py @@ -47,7 +47,7 @@ class HTTPController: try: await self.quart_app.run_task(*args, **kwargs) except Exception as e: - self.ap.logger.error(f'启动 HTTP 服务失败: {e}') + self.ap.logger.error(f'Failed to start HTTP service: {e}') 
self.ap.task_mgr.create_task( exception_handler( diff --git a/pkg/api/http/service/bot.py b/pkg/api/http/service/bot.py index e5010007..adf19d03 100644 --- a/pkg/api/http/service/bot.py +++ b/pkg/api/http/service/bot.py @@ -10,7 +10,7 @@ from ....entity.persistence import pipeline as persistence_pipeline class BotService: - """机器人服务""" + """Bot service""" ap: app.Application @@ -18,7 +18,7 @@ class BotService: self.ap = ap async def get_bots(self) -> list[dict]: - """获取所有机器人""" + """Get all bots""" result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_bot.Bot)) bots = result.all() @@ -26,7 +26,7 @@ class BotService: return [self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot) for bot in bots] async def get_bot(self, bot_uuid: str) -> dict | None: - """获取机器人""" + """Get bot""" result = await self.ap.persistence_mgr.execute_async( sqlalchemy.select(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid) ) @@ -39,7 +39,7 @@ class BotService: return self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot) async def create_bot(self, bot_data: dict) -> str: - """创建机器人""" + """Create bot""" # TODO: 检查配置信息格式 bot_data['uuid'] = str(uuid.uuid4()) @@ -63,7 +63,7 @@ class BotService: return bot_data['uuid'] async def update_bot(self, bot_uuid: str, bot_data: dict) -> None: - """更新机器人""" + """Update bot""" if 'uuid' in bot_data: del bot_data['uuid'] @@ -99,7 +99,7 @@ class BotService: session.using_conversation = None async def delete_bot(self, bot_uuid: str) -> None: - """删除机器人""" + """Delete bot""" await self.ap.platform_mgr.remove_bot(bot_uuid) await self.ap.persistence_mgr.execute_async( sqlalchemy.delete(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid) diff --git a/pkg/api/http/service/knowledge.py b/pkg/api/http/service/knowledge.py index 5d702ba4..27506ec9 100644 --- a/pkg/api/http/service/knowledge.py +++ b/pkg/api/http/service/knowledge.py @@ -47,12 +47,18 @@ class KnowledgeService: async def 
update_knowledge_base(self, kb_uuid: str, kb_data: dict) -> None: """更新知识库""" + if 'uuid' in kb_data: + del kb_data['uuid'] + + if 'embedding_model_uuid' in kb_data: + del kb_data['embedding_model_uuid'] + await self.ap.persistence_mgr.execute_async( sqlalchemy.update(persistence_rag.KnowledgeBase) .values(kb_data) .where(persistence_rag.KnowledgeBase.uuid == kb_uuid) ) - await self.ap.rag_mgr.remove_knowledge_base(kb_uuid) + await self.ap.rag_mgr.remove_knowledge_base_from_runtime(kb_uuid) kb = await self.get_knowledge_base(kb_uuid) @@ -67,6 +73,13 @@ class KnowledgeService: raise Exception('Knowledge base not found') return await runtime_kb.store_file(file_id) + async def retrieve_knowledge_base(self, kb_uuid: str, query: str) -> list[dict]: + """检索知识库""" + runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid) + if runtime_kb is None: + raise Exception('Knowledge base not found') + return [result.model_dump() for result in await runtime_kb.retrieve(query)] + async def get_files_by_knowledge_base(self, kb_uuid: str) -> list[dict]: """获取知识库文件""" result = await self.ap.persistence_mgr.execute_async( @@ -77,14 +90,29 @@ class KnowledgeService: async def delete_file(self, kb_uuid: str, file_id: str) -> None: """删除文件""" - await self.ap.persistence_mgr.execute_async( - sqlalchemy.delete(persistence_rag.File).where(persistence_rag.File.uuid == file_id) - ) - # TODO: remove from memory + runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid) + if runtime_kb is None: + raise Exception('Knowledge base not found') + await runtime_kb.delete_file(file_id) async def delete_knowledge_base(self, kb_uuid: str) -> None: """删除知识库""" + await self.ap.rag_mgr.delete_knowledge_base(kb_uuid) + await self.ap.persistence_mgr.execute_async( sqlalchemy.delete(persistence_rag.KnowledgeBase).where(persistence_rag.KnowledgeBase.uuid == kb_uuid) ) - # TODO: remove from memory + + # delete files + files = await self.ap.persistence_mgr.execute_async( + 
sqlalchemy.select(persistence_rag.File).where(persistence_rag.File.kb_id == kb_uuid) + ) + for file in files: + # delete chunks + await self.ap.persistence_mgr.execute_async( + sqlalchemy.delete(persistence_rag.Chunk).where(persistence_rag.Chunk.file_id == file.uuid) + ) + # delete file + await self.ap.persistence_mgr.execute_async( + sqlalchemy.delete(persistence_rag.File).where(persistence_rag.File.uuid == file.uuid) + ) diff --git a/pkg/api/http/service/model.py b/pkg/api/http/service/model.py index 3a4998e2..d8457da3 100644 --- a/pkg/api/http/service/model.py +++ b/pkg/api/http/service/model.py @@ -186,6 +186,6 @@ class EmbeddingModelsService: await runtime_embedding_model.requester.invoke_embedding( model=runtime_embedding_model, - input_text='Hello, world!', + input_text=['Hello, world!'], extra_args={}, ) diff --git a/pkg/api/http/service/pipeline.py b/pkg/api/http/service/pipeline.py index f0f6c083..96504d61 100644 --- a/pkg/api/http/service/pipeline.py +++ b/pkg/api/http/service/pipeline.py @@ -38,9 +38,21 @@ class PipelineService: self.ap.pipeline_config_meta_output.data, ] - async def get_pipelines(self) -> list[dict]: - result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline)) - + async def get_pipelines(self, sort_by: str = 'created_at', sort_order: str = 'DESC') -> list[dict]: + query = sqlalchemy.select(persistence_pipeline.LegacyPipeline) + + if sort_by == 'created_at': + if sort_order == 'DESC': + query = query.order_by(persistence_pipeline.LegacyPipeline.created_at.desc()) + else: + query = query.order_by(persistence_pipeline.LegacyPipeline.created_at.asc()) + elif sort_by == 'updated_at': + if sort_order == 'DESC': + query = query.order_by(persistence_pipeline.LegacyPipeline.updated_at.desc()) + else: + query = query.order_by(persistence_pipeline.LegacyPipeline.updated_at.asc()) + + result = await self.ap.persistence_mgr.execute_async(query) pipelines = result.all() return [ 
self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline) diff --git a/pkg/config/impls/json.py b/pkg/config/impls/json.py index 07fc533c..44b4843c 100644 --- a/pkg/config/impls/json.py +++ b/pkg/config/impls/json.py @@ -6,7 +6,7 @@ from .. import model as file_model class JSONConfigFile(file_model.ConfigFile): - """JSON配置文件""" + """JSON config file""" def __init__( self, @@ -42,7 +42,7 @@ class JSONConfigFile(file_model.ConfigFile): try: cfg = json.load(f) except json.JSONDecodeError as e: - raise Exception(f'配置文件 {self.config_file_name} 语法错误: {e}') + raise Exception(f'Syntax error in config file {self.config_file_name}: {e}') if completion: for key in self.template_data: diff --git a/pkg/config/impls/pymodule.py b/pkg/config/impls/pymodule.py index 2311992e..c3d04bc8 100644 --- a/pkg/config/impls/pymodule.py +++ b/pkg/config/impls/pymodule.py @@ -7,13 +7,13 @@ from .. import model as file_model class PythonModuleConfigFile(file_model.ConfigFile): - """Python模块配置文件""" + """Python module config file""" config_file_name: str = None - """配置文件名""" + """Config file name""" template_file_name: str = None - """模板文件名""" + """Template file name""" def __init__(self, config_file_name: str, template_file_name: str) -> None: self.config_file_name = config_file_name @@ -42,7 +42,7 @@ class PythonModuleConfigFile(file_model.ConfigFile): cfg[key] = getattr(module, key) - # 从模板模块文件中进行补全 + # complete from template module file if completion: module_name = os.path.splitext(os.path.basename(self.template_file_name))[0] module = importlib.import_module(module_name) @@ -60,7 +60,7 @@ class PythonModuleConfigFile(file_model.ConfigFile): return cfg async def save(self, data: dict): - logging.warning('Python模块配置文件不支持保存') + logging.warning('Python module config file does not support saving') def save_sync(self, data: dict): - logging.warning('Python模块配置文件不支持保存') + logging.warning('Python module config file does not support saving') diff --git 
a/pkg/config/impls/yaml.py b/pkg/config/impls/yaml.py index 55045186..0d69ef9e 100644 --- a/pkg/config/impls/yaml.py +++ b/pkg/config/impls/yaml.py @@ -6,7 +6,7 @@ from .. import model as file_model class YAMLConfigFile(file_model.ConfigFile): - """YAML配置文件""" + """YAML config file""" def __init__( self, @@ -42,7 +42,7 @@ class YAMLConfigFile(file_model.ConfigFile): try: cfg = yaml.load(f, Loader=yaml.FullLoader) except yaml.YAMLError as e: - raise Exception(f'配置文件 {self.config_file_name} 语法错误: {e}') + raise Exception(f'Syntax error in config file {self.config_file_name}: {e}') if completion: for key in self.template_data: diff --git a/pkg/config/manager.py b/pkg/config/manager.py index c2e6bdf4..d552b038 100644 --- a/pkg/config/manager.py +++ b/pkg/config/manager.py @@ -5,27 +5,27 @@ from .impls import pymodule, json as json_file, yaml as yaml_file class ConfigManager: - """配置文件管理器""" + """Config file manager""" name: str = None - """配置管理器名""" + """Config manager name""" description: str = None - """配置管理器描述""" + """Config manager description""" schema: dict = None - """配置文件 schema - 需要符合 JSON Schema Draft 7 规范 + """Config file schema + Must conform to JSON Schema Draft 7 specification """ file: file_model.ConfigFile = None - """配置文件实例""" + """Config file instance""" data: dict = None - """配置数据""" + """Config data""" doc_link: str = None - """配置文件文档链接""" + """Config file documentation link""" def __init__(self, cfg_file: file_model.ConfigFile) -> None: self.file = cfg_file @@ -42,15 +42,15 @@ class ConfigManager: async def load_python_module_config(config_name: str, template_name: str, completion: bool = True) -> ConfigManager: - """加载Python模块配置文件 + """Load Python module config file Args: - config_name (str): 配置文件名 - template_name (str): 模板文件名 - completion (bool): 是否自动补全内存中的配置文件 + config_name (str): Config file name + template_name (str): Template file name + completion (bool): Whether to automatically complete the config file in memory Returns: - ConfigManager: 
配置文件管理器 + ConfigManager: Config file manager """ cfg_inst = pymodule.PythonModuleConfigFile(config_name, template_name) @@ -66,13 +66,13 @@ async def load_json_config( template_data: dict = None, completion: bool = True, ) -> ConfigManager: - """加载JSON配置文件 + """Load JSON config file Args: - config_name (str): 配置文件名 - template_name (str): 模板文件名 - template_data (dict): 模板数据 - completion (bool): 是否自动补全内存中的配置文件 + config_name (str): Config file name + template_name (str): Template file name + template_data (dict): Template data + completion (bool): Whether to automatically complete the config file in memory """ cfg_inst = json_file.JSONConfigFile(config_name, template_name, template_data) @@ -88,16 +88,16 @@ async def load_yaml_config( template_data: dict = None, completion: bool = True, ) -> ConfigManager: - """加载YAML配置文件 + """Load YAML config file Args: - config_name (str): 配置文件名 - template_name (str): 模板文件名 - template_data (dict): 模板数据 - completion (bool): 是否自动补全内存中的配置文件 + config_name (str): Config file name + template_name (str): Template file name + template_data (dict): Template data + completion (bool): Whether to automatically complete the config file in memory Returns: - ConfigManager: 配置文件管理器 + ConfigManager: Config file manager """ cfg_inst = yaml_file.YAMLConfigFile(config_name, template_name, template_data) diff --git a/pkg/config/model.py b/pkg/config/model.py index f3536804..8b040f05 100644 --- a/pkg/config/model.py +++ b/pkg/config/model.py @@ -2,16 +2,16 @@ import abc class ConfigFile(metaclass=abc.ABCMeta): - """配置文件抽象类""" + """Config file abstract class""" config_file_name: str = None - """配置文件名""" + """Config file name""" template_file_name: str = None - """模板文件名""" + """Template file name""" template_data: dict = None - """模板数据""" + """Template data""" @abc.abstractmethod def exists(self) -> bool: diff --git a/pkg/core/app.py b/pkg/core/app.py index ca2c5c1c..21816cfc 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -28,11 +28,12 @@ from 
..storage import mgr as storagemgr from ..utils import logcache from . import taskmgr from . import entities as core_entities -from ..rag.knowledge import mgr as rag_mgr +from ..rag.knowledge import kbmgr as rag_mgr +from ..vector import mgr as vectordb_mgr class Application: - """运行时应用对象和上下文""" + """Runtime application object and context""" event_loop: asyncio.AbstractEventLoop = None @@ -51,10 +52,10 @@ class Application: rag_mgr: rag_mgr.RAGManager = None - # TODO 移动到 pipeline 里 + # TODO move to pipeline tool_mgr: llm_tool_mgr.ToolManager = None - # ======= 配置管理器 ======= + # ======= Config manager ======= command_cfg: config_mgr.ConfigManager = None # deprecated @@ -68,7 +69,7 @@ class Application: instance_config: config_mgr.ConfigManager = None - # ======= 元数据配置管理器 ======= + # ======= Metadata config manager ======= sensitive_meta: config_mgr.ConfigManager = None @@ -97,6 +98,8 @@ class Application: persistence_mgr: persistencemgr.PersistenceManager = None + vector_db_mgr: vectordb_mgr.VectorDBManager = None + http_ctrl: http_controller.HTTPController = None log_cache: logcache.LogCache = None @@ -163,11 +166,11 @@ class Application: except asyncio.CancelledError: pass except Exception as e: - self.logger.error(f'应用运行致命异常: {e}') + self.logger.error(f'Application runtime fatal exception: {e}') self.logger.debug(f'Traceback: {traceback.format_exc()}') async def print_web_access_info(self): - """打印访问 webui 的提示""" + """Print access webui tips""" if not os.path.exists(os.path.join('.', 'web/out')): self.logger.warning('WebUI 文件缺失,请根据文档部署:https://docs.langbot.app/zh') @@ -199,7 +202,7 @@ class Application: ): match scope: case core_entities.LifecycleControlScope.PLATFORM.value: - self.logger.info('执行热重载 scope=' + scope) + self.logger.info('Hot reload scope=' + scope) await self.platform_mgr.shutdown() self.platform_mgr = im_mgr.PlatformManager(self) @@ -215,7 +218,7 @@ class Application: ], ) case core_entities.LifecycleControlScope.PLUGIN.value: - 
self.logger.info('执行热重载 scope=' + scope) + self.logger.info('Hot reload scope=' + scope) await self.plugin_mgr.destroy_plugins() # 删除 sys.module 中所有的 plugins/* 下的模块 @@ -231,7 +234,7 @@ class Application: await self.plugin_mgr.load_plugins() await self.plugin_mgr.initialize_plugins() case core_entities.LifecycleControlScope.PROVIDER.value: - self.logger.info('执行热重载 scope=' + scope) + self.logger.info('Hot reload scope=' + scope) await self.tool_mgr.shutdown() diff --git a/pkg/core/boot.py b/pkg/core/boot.py index aff117e6..b8243d4a 100644 --- a/pkg/core/boot.py +++ b/pkg/core/boot.py @@ -8,7 +8,7 @@ from . import app from . import stage from ..utils import constants, importutil -# 引入启动阶段实现以便注册 +# Import startup stage implementation to register from . import stages importutil.import_modules_in_pkg(stages) @@ -25,7 +25,7 @@ stage_order = [ async def make_app(loop: asyncio.AbstractEventLoop) -> app.Application: - # 确定是否为调试模式 + # Determine if it is debug mode if 'DEBUG' in os.environ and os.environ['DEBUG'] in ['true', '1']: constants.debug_mode = True @@ -33,7 +33,7 @@ async def make_app(loop: asyncio.AbstractEventLoop) -> app.Application: ap.event_loop = loop - # 执行启动阶段 + # Execute startup stage for stage_name in stage_order: stage_cls = stage.preregistered_stages[stage_name] stage_inst = stage_cls() @@ -47,11 +47,11 @@ async def make_app(loop: asyncio.AbstractEventLoop) -> app.Application: async def main(loop: asyncio.AbstractEventLoop): try: - # 挂系统信号处理 + # Hang system signal processing import signal def signal_handler(sig, frame): - print('[Signal] 程序退出.') + print('[Signal] Program exit.') # ap.shutdown() os._exit(0) diff --git a/pkg/core/bootutils/deps.py b/pkg/core/bootutils/deps.py index b403bf8d..1a439af8 100644 --- a/pkg/core/bootutils/deps.py +++ b/pkg/core/bootutils/deps.py @@ -2,8 +2,8 @@ import pip import os from ...utils import pkgmgr -# 检查依赖,防止用户未安装 -# 左边为引入名称,右边为依赖名称 +# Check dependencies to prevent users from not installing +# Left is the import name, 
right is the dependency name required_deps = { 'requests': 'requests', 'openai': 'openai', @@ -65,7 +65,7 @@ async def install_deps(deps: list[str]): async def precheck_plugin_deps(): print('[Startup] Prechecking plugin dependencies...') - # 只有在plugins目录存在时才执行插件依赖安装 + # Only execute plugin dependency installation when the plugins directory exists if os.path.exists('plugins'): for dir in os.listdir('plugins'): subdir = os.path.join('plugins', dir) diff --git a/pkg/core/bootutils/log.py b/pkg/core/bootutils/log.py index eb6806fa..631b05e2 100644 --- a/pkg/core/bootutils/log.py +++ b/pkg/core/bootutils/log.py @@ -17,7 +17,7 @@ log_colors_config = { async def init_logging(extra_handlers: list[logging.Handler] = None) -> logging.Logger: - # 删除所有现有的logger + # Remove all existing loggers for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) @@ -54,13 +54,13 @@ async def init_logging(extra_handlers: list[logging.Handler] = None) -> logging. handler.setFormatter(color_formatter) qcg_logger.addHandler(handler) - qcg_logger.debug('日志初始化完成,日志级别:%s' % level) + qcg_logger.debug('Logging initialized, log level: %s' % level) logging.basicConfig( - level=logging.CRITICAL, # 设置日志输出格式 + level=logging.CRITICAL, # Set log output format format='[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s', - # 日志输出的格式 - # -8表示占位符,让输出左对齐,输出长度都为8位 - datefmt='%Y-%m-%d %H:%M:%S', # 时间输出的格式 + # Log output format + # -8 is a placeholder, left-align the output, and output length is 8 + datefmt='%Y-%m-%d %H:%M:%S', # Time output format handlers=[logging.NullHandler()], ) diff --git a/pkg/core/migration.py b/pkg/core/migration.py index e97c0cf3..a921e6c7 100644 --- a/pkg/core/migration.py +++ b/pkg/core/migration.py @@ -7,11 +7,11 @@ from . 
import app preregistered_migrations: list[typing.Type[Migration]] = [] -"""当前阶段暂不支持扩展""" +"""Currently not supported for extension""" def migration_class(name: str, number: int): - """注册一个迁移""" + """Register a migration""" def decorator(cls: typing.Type[Migration]) -> typing.Type[Migration]: cls.name = name @@ -23,7 +23,7 @@ def migration_class(name: str, number: int): class Migration(abc.ABC): - """一个版本的迁移""" + """A version migration""" name: str @@ -36,10 +36,10 @@ class Migration(abc.ABC): @abc.abstractmethod async def need_migrate(self) -> bool: - """判断当前环境是否需要运行此迁移""" + """Determine if the current environment needs to run this migration""" pass @abc.abstractmethod async def run(self): - """执行迁移""" + """Run migration""" pass diff --git a/pkg/core/note.py b/pkg/core/note.py index 07171581..b4c37ce1 100644 --- a/pkg/core/note.py +++ b/pkg/core/note.py @@ -9,7 +9,7 @@ preregistered_notes: list[typing.Type[LaunchNote]] = [] def note_class(name: str, number: int): - """注册一个启动信息""" + """Register a launch information""" def decorator(cls: typing.Type[LaunchNote]) -> typing.Type[LaunchNote]: cls.name = name @@ -21,7 +21,7 @@ def note_class(name: str, number: int): class LaunchNote(abc.ABC): - """启动信息""" + """Launch information""" name: str @@ -34,10 +34,10 @@ class LaunchNote(abc.ABC): @abc.abstractmethod async def need_show(self) -> bool: - """判断当前环境是否需要显示此启动信息""" + """Determine if the current environment needs to display this launch information""" pass @abc.abstractmethod async def yield_note(self) -> typing.AsyncGenerator[typing.Tuple[str, int], None]: - """生成启动信息""" + """Generate launch information""" pass diff --git a/pkg/core/notes/n001_classic_msgs.py b/pkg/core/notes/n001_classic_msgs.py index 3f3bd8e0..265ddbe9 100644 --- a/pkg/core/notes/n001_classic_msgs.py +++ b/pkg/core/notes/n001_classic_msgs.py @@ -7,7 +7,7 @@ from .. 
import note @note.note_class('ClassicNotes', 1) class ClassicNotes(note.LaunchNote): - """经典启动信息""" + """Classic launch information""" async def need_show(self) -> bool: return True diff --git a/pkg/core/notes/n002_selection_mode_on_windows.py b/pkg/core/notes/n002_selection_mode_on_windows.py index 23bff24a..16028de1 100644 --- a/pkg/core/notes/n002_selection_mode_on_windows.py +++ b/pkg/core/notes/n002_selection_mode_on_windows.py @@ -9,7 +9,7 @@ from .. import note @note.note_class('SelectionModeOnWindows', 2) class SelectionModeOnWindows(note.LaunchNote): - """Windows 上的选择模式提示信息""" + """Selection mode prompt information on Windows""" async def need_show(self) -> bool: return os.name == 'nt' @@ -19,3 +19,8 @@ class SelectionModeOnWindows(note.LaunchNote): """您正在使用 Windows 系统,若窗口左上角显示处于”选择“模式,程序将被暂停运行,此时请右键窗口中空白区域退出选择模式。""", logging.INFO, ) + + yield ( + """You are using Windows system, if the top left corner of the window displays "Selection" mode, the program will be paused running, please right-click on the blank area in the window to exit the selection mode.""", + logging.INFO, + ) diff --git a/pkg/core/stage.py b/pkg/core/stage.py index 220c474d..1483e23a 100644 --- a/pkg/core/stage.py +++ b/pkg/core/stage.py @@ -7,9 +7,9 @@ from . import app preregistered_stages: dict[str, typing.Type[BootingStage]] = {} -"""预注册的请求处理阶段。在初始化时,所有请求处理阶段类会被注册到此字典中。 +"""Pre-registered request processing stages. All request processing stage classes are registered in this dictionary during initialization. 
-当前阶段暂不支持扩展 +Currently not supported for extension """ @@ -22,11 +22,11 @@ def stage_class(name: str): class BootingStage(abc.ABC): - """启动阶段""" + """Booting stage""" name: str = None @abc.abstractmethod async def run(self, ap: app.Application): - """启动""" + """Run""" pass diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index 18240962..0f28f0c8 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -9,7 +9,7 @@ from ...command import cmdmgr from ...provider.session import sessionmgr as llm_session_mgr from ...provider.modelmgr import modelmgr as llm_model_mgr from ...provider.tools import toolmgr as llm_tool_mgr -from ...rag.knowledge import mgr as rag_mgr +from ...rag.knowledge import kbmgr as rag_mgr from ...platform import botmgr as im_mgr from ...persistence import mgr as persistencemgr from ...api.http.controller import main as http_controller @@ -21,15 +21,16 @@ from ...api.http.service import knowledge as knowledge_service from ...discover import engine as discover_engine from ...storage import mgr as storagemgr from ...utils import logcache +from ...vector import mgr as vectordb_mgr from .. 
import taskmgr @stage.stage_class('BuildAppStage') class BuildAppStage(stage.BootingStage): - """构建应用阶段""" + """Build LangBot application""" async def run(self, ap: app.Application): - """构建app对象的各个组件对象并初始化""" + """Build LangBot application""" ap.task_mgr = taskmgr.AsyncTaskManager(ap) discover = discover_engine.ComponentDiscoveryEngine(ap) @@ -44,7 +45,7 @@ class BuildAppStage(stage.BootingStage): await ver_mgr.initialize() ap.ver_mgr = ver_mgr - # 发送公告 + # Send announcement ann_mgr = announce.AnnouncementManager(ap) ap.ann_mgr = ann_mgr @@ -91,9 +92,14 @@ class BuildAppStage(stage.BootingStage): ap.pipeline_mgr = pipeline_mgr rag_mgr_inst = rag_mgr.RAGManager(ap) - await rag_mgr_inst.initialize_rag_system() + await rag_mgr_inst.initialize() ap.rag_mgr = rag_mgr_inst + # 初始化向量数据库管理器 + vectordb_mgr_inst = vectordb_mgr.VectorDBManager(ap) + await vectordb_mgr_inst.initialize() + ap.vector_db_mgr = vectordb_mgr_inst + http_ctrl = http_controller.HTTPController(ap) await http_ctrl.initialize() ap.http_ctrl = http_ctrl diff --git a/pkg/core/stages/genkeys.py b/pkg/core/stages/genkeys.py index 50e7cf7b..f0412b9d 100644 --- a/pkg/core/stages/genkeys.py +++ b/pkg/core/stages/genkeys.py @@ -7,10 +7,10 @@ from .. 
import stage, app @stage.stage_class('GenKeysStage') class GenKeysStage(stage.BootingStage): - """生成密钥阶段""" + """Generate keys stage""" async def run(self, ap: app.Application): - """启动""" + """Generate keys""" if not ap.instance_config.data['system']['jwt']['secret']: ap.instance_config.data['system']['jwt']['secret'] = secrets.token_hex(16) diff --git a/pkg/core/stages/load_config.py b/pkg/core/stages/load_config.py index ef5f611b..0474b33a 100644 --- a/pkg/core/stages/load_config.py +++ b/pkg/core/stages/load_config.py @@ -8,10 +8,10 @@ from ..bootutils import config @stage.stage_class('LoadConfigStage') class LoadConfigStage(stage.BootingStage): - """加载配置文件阶段""" + """Load config file stage""" async def run(self, ap: app.Application): - """启动""" + """Load config file""" # ======= deprecated ======= if os.path.exists('data/config/command.json'): diff --git a/pkg/core/stages/migrate.py b/pkg/core/stages/migrate.py index 02b03256..229e0060 100644 --- a/pkg/core/stages/migrate.py +++ b/pkg/core/stages/migrate.py @@ -11,10 +11,13 @@ importutil.import_modules_in_pkg(migrations) @stage.stage_class('MigrationStage') class MigrationStage(stage.BootingStage): - """迁移阶段""" + """Migration stage + + These migrations are legacy, only performed in version 3.x + """ async def run(self, ap: app.Application): - """启动""" + """Run migration""" if any( [ @@ -29,7 +32,7 @@ class MigrationStage(stage.BootingStage): migrations = migration.preregistered_migrations - # 按照迁移号排序 + # Sort by migration number migrations.sort(key=lambda x: x.number) for migration_cls in migrations: @@ -37,4 +40,4 @@ class MigrationStage(stage.BootingStage): if await migration_instance.need_migrate(): await migration_instance.run() - print(f'已执行迁移 {migration_instance.name}') + print(f'Migration {migration_instance.name} executed') diff --git a/pkg/core/stages/setup_logger.py b/pkg/core/stages/setup_logger.py index 0c630175..1f7c81ac 100644 --- a/pkg/core/stages/setup_logger.py +++ 
b/pkg/core/stages/setup_logger.py @@ -8,7 +8,7 @@ from ..bootutils import log class PersistenceHandler(logging.Handler, object): """ - 保存日志到数据库 + Save logs to database """ ap: app.Application @@ -19,9 +19,9 @@ class PersistenceHandler(logging.Handler, object): def emit(self, record): """ - emit函数为自定义handler类时必重写的函数,这里可以根据需要对日志消息做一些处理,比如发送日志到服务器 + emit function is a required function for custom handler classes, here you can process the log messages as needed, such as sending logs to the server - 发出记录(Emit a record) + Emit a record """ try: msg = self.format(record) @@ -34,10 +34,10 @@ class PersistenceHandler(logging.Handler, object): @stage.stage_class('SetupLoggerStage') class SetupLoggerStage(stage.BootingStage): - """设置日志器阶段""" + """Setup logger stage""" async def run(self, ap: app.Application): - """启动""" + """Setup logger""" persistence_handler = PersistenceHandler('LoggerHandler', ap) extra_handlers = [] diff --git a/pkg/core/stages/show_notes.py b/pkg/core/stages/show_notes.py index 5fa7ff08..d0f861ba 100644 --- a/pkg/core/stages/show_notes.py +++ b/pkg/core/stages/show_notes.py @@ -12,10 +12,10 @@ importutil.import_modules_in_pkg(notes) @stage.stage_class('ShowNotesStage') class ShowNotesStage(stage.BootingStage): - """显示启动信息阶段""" + """Show notes stage""" async def run(self, ap: app.Application): - # 排序 + # Sort note.preregistered_notes.sort(key=lambda x: x.number) for note_cls in note.preregistered_notes: diff --git a/pkg/core/taskmgr.py b/pkg/core/taskmgr.py index 0f756118..ca6eb029 100644 --- a/pkg/core/taskmgr.py +++ b/pkg/core/taskmgr.py @@ -9,13 +9,13 @@ from . 
import entities as core_entities class TaskContext: - """任务跟踪上下文""" + """Task tracking context""" current_action: str - """当前正在执行的动作""" + """Current action being executed""" log: str - """记录日志""" + """Log""" def __init__(self): self.current_action = 'default' @@ -58,40 +58,40 @@ placeholder_context: TaskContext | None = None class TaskWrapper: - """任务包装器""" + """Task wrapper""" _id_index: int = 0 - """任务ID索引""" + """Task ID index""" id: int - """任务ID""" + """Task ID""" - task_type: str = 'system' # 任务类型: system 或 user - """任务类型""" + task_type: str = 'system' # Task type: system or user + """Task type""" - kind: str = 'system_task' # 由发起者确定任务种类,通常同质化的任务种类相同 - """任务种类""" + kind: str = 'system_task' # Task type determined by the initiator, usually the same task type + """Task type""" name: str = '' - """任务唯一名称""" + """Task unique name""" label: str = '' - """任务显示名称""" + """Task display name""" task_context: TaskContext - """任务上下文""" + """Task context""" task: asyncio.Task - """任务""" + """Task""" task_stack: list = None - """任务堆栈""" + """Task stack""" ap: app.Application - """应用实例""" + """Application instance""" scopes: list[core_entities.LifecycleControlScope] - """任务所属生命周期控制范围""" + """Task scope""" def __init__( self, @@ -165,13 +165,13 @@ class TaskWrapper: class AsyncTaskManager: - """保存app中的所有异步任务 - 包含系统级的和用户级(插件安装、更新等由用户直接发起的)的""" + """Save all asynchronous tasks in the app + Include system-level and user-level (plugin installation, update, etc. 
initiated by users directly)""" ap: app.Application tasks: list[TaskWrapper] - """所有任务""" + """All tasks""" def __init__(self, ap: app.Application): self.ap = ap diff --git a/pkg/entity/persistence/bot.py b/pkg/entity/persistence/bot.py index 3c08f4ec..08eda478 100644 --- a/pkg/entity/persistence/bot.py +++ b/pkg/entity/persistence/bot.py @@ -4,7 +4,7 @@ from .base import Base class Bot(Base): - """机器人""" + """Bot""" __tablename__ = 'bots' diff --git a/pkg/entity/persistence/metadata.py b/pkg/entity/persistence/metadata.py index d9e03663..4db732b9 100644 --- a/pkg/entity/persistence/metadata.py +++ b/pkg/entity/persistence/metadata.py @@ -12,7 +12,7 @@ initial_metadata = [ class Metadata(Base): - """数据库元数据""" + """Database metadata""" __tablename__ = 'metadata' diff --git a/pkg/entity/persistence/model.py b/pkg/entity/persistence/model.py index 418cab70..e9a104c4 100644 --- a/pkg/entity/persistence/model.py +++ b/pkg/entity/persistence/model.py @@ -4,7 +4,7 @@ from .base import Base class LLMModel(Base): - """LLM 模型""" + """LLM model""" __tablename__ = 'llm_models' diff --git a/pkg/entity/persistence/pipeline.py b/pkg/entity/persistence/pipeline.py index 56e2cae9..3a21dbf2 100644 --- a/pkg/entity/persistence/pipeline.py +++ b/pkg/entity/persistence/pipeline.py @@ -4,7 +4,7 @@ from .base import Base class LegacyPipeline(Base): - """旧版流水线""" + """Legacy pipeline""" __tablename__ = 'legacy_pipelines' @@ -20,13 +20,12 @@ class LegacyPipeline(Base): ) for_version = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) is_default = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=False) - stages = sqlalchemy.Column(sqlalchemy.JSON, nullable=False) config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False) class PipelineRunRecord(Base): - """流水线运行记录""" + """Pipeline run record""" __tablename__ = 'pipeline_run_records' @@ -43,3 +42,4 @@ class PipelineRunRecord(Base): started_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False) finished_at = 
sqlalchemy.Column(sqlalchemy.DateTime, nullable=False) result = sqlalchemy.Column(sqlalchemy.JSON, nullable=False) + knowledge_base_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) diff --git a/pkg/entity/persistence/plugin.py b/pkg/entity/persistence/plugin.py index 30db6bd6..e777441f 100644 --- a/pkg/entity/persistence/plugin.py +++ b/pkg/entity/persistence/plugin.py @@ -4,7 +4,7 @@ from .base import Base class PluginSetting(Base): - """插件配置""" + """Plugin setting""" __tablename__ = 'plugin_settings' diff --git a/pkg/rag/knowledge/utils/crawler.py b/pkg/entity/rag/__init__.py similarity index 100% rename from pkg/rag/knowledge/utils/crawler.py rename to pkg/entity/rag/__init__.py diff --git a/pkg/entity/rag/retriever.py b/pkg/entity/rag/retriever.py new file mode 100644 index 00000000..becaf8db --- /dev/null +++ b/pkg/entity/rag/retriever.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +import pydantic + +from typing import Any + + +class RetrieveResultEntry(pydantic.BaseModel): + id: str + + metadata: dict[str, Any] + + distance: float diff --git a/pkg/persistence/database.py b/pkg/persistence/database.py index 528c6a34..4debb03d 100644 --- a/pkg/persistence/database.py +++ b/pkg/persistence/database.py @@ -11,7 +11,7 @@ preregistered_managers: list[type[BaseDatabaseManager]] = [] def manager_class(name: str) -> None: - """注册一个数据库管理类""" + """Register a database manager class""" def decorator(cls: type[BaseDatabaseManager]) -> type[BaseDatabaseManager]: cls.name = name @@ -22,7 +22,7 @@ def manager_class(name: str) -> None: class BaseDatabaseManager(abc.ABC): - """基础数据库管理类""" + """Base database manager class""" name: str diff --git a/pkg/persistence/databases/sqlite.py b/pkg/persistence/databases/sqlite.py index 7b095e61..c1337459 100644 --- a/pkg/persistence/databases/sqlite.py +++ b/pkg/persistence/databases/sqlite.py @@ -7,7 +7,7 @@ from .. 
import database @database.manager_class('sqlite') class SQLiteDatabaseManager(database.BaseDatabaseManager): - """SQLite 数据库管理类""" + """SQLite database manager""" async def initialize(self) -> None: sqlite_path = 'data/langbot.db' diff --git a/pkg/persistence/mgr.py b/pkg/persistence/mgr.py index 606aa9fd..3aa21ad2 100644 --- a/pkg/persistence/mgr.py +++ b/pkg/persistence/mgr.py @@ -22,12 +22,12 @@ importutil.import_modules_in_pkg(persistence) class PersistenceManager: - """持久化模块管理器""" + """Persistence module manager""" ap: app.Application db: database.BaseDatabaseManager - """数据库管理器""" + """Database manager""" meta: sqlalchemy.MetaData @@ -79,7 +79,7 @@ class PersistenceManager: 'stages': pipeline_service.default_stage_order, 'is_default': True, 'name': 'ChatPipeline', - 'description': '默认提供的流水线,您配置的机器人、第一个模型将自动绑定到此流水线', + 'description': 'Default pipeline, new bots will be bound to this pipeline | 默认提供的流水线,您配置的机器人将自动绑定到此流水线', 'config': pipeline_config, } diff --git a/pkg/persistence/migration.py b/pkg/persistence/migration.py index c191b686..294e30ca 100644 --- a/pkg/persistence/migration.py +++ b/pkg/persistence/migration.py @@ -10,7 +10,7 @@ preregistered_db_migrations: list[typing.Type[DBMigration]] = [] def migration_class(number: int): - """迁移类装饰器""" + """Migration class decorator""" def wrapper(cls: typing.Type[DBMigration]) -> typing.Type[DBMigration]: cls.number = number @@ -21,20 +21,20 @@ def migration_class(number: int): class DBMigration(abc.ABC): - """数据库迁移""" + """Database migration""" number: int - """迁移号""" + """Migration number""" def __init__(self, ap: app.Application): self.ap = ap @abc.abstractmethod async def upgrade(self): - """升级""" + """Upgrade""" pass @abc.abstractmethod async def downgrade(self): - """降级""" + """Downgrade""" pass diff --git a/pkg/persistence/migrations/dbm001_migrate_v3_config.py b/pkg/persistence/migrations/dbm001_migrate_v3_config.py index a1145527..58f05e04 100644 --- 
a/pkg/persistence/migrations/dbm001_migrate_v3_config.py +++ b/pkg/persistence/migrations/dbm001_migrate_v3_config.py @@ -15,21 +15,21 @@ from ...entity.persistence import ( @migration.migration_class(1) class DBMigrateV3Config(migration.DBMigration): - """从 v3 的配置迁移到 v4 的数据库""" + """Migrate v3 config to v4 database""" async def upgrade(self): - """升级""" + """Upgrade""" """ - 将 data/config 下的所有配置文件进行迁移。 - 迁移后,之前的配置文件都保存到 data/legacy/config 下。 - 迁移后,data/metadata/ 下的所有配置文件都保存到 data/legacy/metadata 下。 + Migrate all config files under data/config. + After migration, all previous config files are saved under data/legacy/config. + After migration, all config files under data/metadata/ are saved under data/legacy/metadata. """ if self.ap.provider_cfg is None: return - # ======= 迁移模型 ======= - # 只迁移当前选中的模型 + # ======= Migrate model ======= + # Only migrate the currently selected model model_name = self.ap.provider_cfg.data.get('model', 'gpt-4o') model_requester = 'openai-chat-completions' @@ -91,8 +91,8 @@ class DBMigrateV3Config(migration.DBMigration): sqlalchemy.insert(persistence_model.LLMModel).values(**llm_model_data) ) - # ======= 迁移流水线配置 ======= - # 修改到默认流水线 + # ======= Migrate pipeline config ======= + # Modify to default pipeline default_pipeline = [ self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline) for pipeline in ( @@ -184,8 +184,8 @@ class DBMigrateV3Config(migration.DBMigration): .where(persistence_pipeline.LegacyPipeline.uuid == default_pipeline['uuid']) ) - # ======= 迁移机器人 ======= - # 只迁移启用的机器人 + # ======= Migrate bot ======= + # Only migrate enabled bots for adapter in self.ap.platform_cfg.data.get('platform-adapters', []): if not adapter.get('enable'): continue @@ -207,7 +207,7 @@ class DBMigrateV3Config(migration.DBMigration): await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_bot.Bot).values(**bot_data)) - # ======= 迁移系统设置 ======= + # ======= Migrate system settings ======= 
self.ap.instance_config.data['admins'] = self.ap.system_cfg.data['admin-sessions'] self.ap.instance_config.data['api']['port'] = self.ap.system_cfg.data['http-api']['port'] self.ap.instance_config.data['command'] = { @@ -223,7 +223,7 @@ class DBMigrateV3Config(migration.DBMigration): await self.ap.instance_config.dump_config() # ======= move files ======= - # 迁移 data/config 下的所有配置文件 + # Migrate all config files under data/config all_legacy_dir_name = [ 'config', # 'metadata', @@ -246,4 +246,4 @@ class DBMigrateV3Config(migration.DBMigration): move_legacy_files(dir_name) async def downgrade(self): - """降级""" + """Downgrade""" diff --git a/pkg/persistence/migrations/dbm002_combine_quote_msg_config.py b/pkg/persistence/migrations/dbm002_combine_quote_msg_config.py index cebf403b..349bb0c2 100644 --- a/pkg/persistence/migrations/dbm002_combine_quote_msg_config.py +++ b/pkg/persistence/migrations/dbm002_combine_quote_msg_config.py @@ -7,10 +7,10 @@ from ...entity.persistence import pipeline as persistence_pipeline @migration.migration_class(2) class DBMigrateCombineQuoteMsgConfig(migration.DBMigration): - """引用消息合并配置""" + """Combine quote message config""" async def upgrade(self): - """升级""" + """Upgrade""" # read all pipelines pipelines = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline)) @@ -37,5 +37,5 @@ class DBMigrateCombineQuoteMsgConfig(migration.DBMigration): ) async def downgrade(self): - """降级""" + """Downgrade""" pass diff --git a/pkg/persistence/migrations/dbm003_n8n_config.py b/pkg/persistence/migrations/dbm003_n8n_config.py index 8705040b..15484f22 100644 --- a/pkg/persistence/migrations/dbm003_n8n_config.py +++ b/pkg/persistence/migrations/dbm003_n8n_config.py @@ -7,10 +7,10 @@ from ...entity.persistence import pipeline as persistence_pipeline @migration.migration_class(3) class DBMigrateN8nConfig(migration.DBMigration): - """N8n配置""" + """N8n config""" async def upgrade(self): - """升级""" + """Upgrade""" # 
read all pipelines pipelines = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline)) @@ -45,5 +45,5 @@ class DBMigrateN8nConfig(migration.DBMigration): ) async def downgrade(self): - """降级""" + """Downgrade""" pass diff --git a/pkg/persistence/migrations/dbm004_rag_kb_uuid.py b/pkg/persistence/migrations/dbm004_rag_kb_uuid.py new file mode 100644 index 00000000..b45cfa78 --- /dev/null +++ b/pkg/persistence/migrations/dbm004_rag_kb_uuid.py @@ -0,0 +1,38 @@ +from .. import migration + +import sqlalchemy + +from ...entity.persistence import pipeline as persistence_pipeline + + +@migration.migration_class(4) +class DBMigrateRAGKBUUID(migration.DBMigration): + """RAG知识库UUID""" + + async def upgrade(self): + """升级""" + # read all pipelines + pipelines = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline)) + + for pipeline in pipelines: + serialized_pipeline = self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline) + + config = serialized_pipeline['config'] + + if 'knowledge-base' not in config['ai']['local-agent']: + config['ai']['local-agent']['knowledge-base'] = '' + + await self.ap.persistence_mgr.execute_async( + sqlalchemy.update(persistence_pipeline.LegacyPipeline) + .where(persistence_pipeline.LegacyPipeline.uuid == serialized_pipeline['uuid']) + .values( + { + 'config': config, + 'for_version': self.ap.ver_mgr.get_current_version(), + } + ) + ) + + async def downgrade(self): + """降级""" + pass diff --git a/pkg/pipeline/bansess/bansess.py b/pkg/pipeline/bansess/bansess.py index 3b927a55..c88a1aa2 100644 --- a/pkg/pipeline/bansess/bansess.py +++ b/pkg/pipeline/bansess/bansess.py @@ -6,9 +6,9 @@ from ...core import entities as core_entities @stage.stage_class('BanSessionCheckStage') class BanSessionCheckStage(stage.PipelineStage): - """访问控制处理阶段 + """Access control processing stage - 仅检查query中群号或个人号是否在访问控制列表中。 + Only check if the group or 
personal number in the query is in the access control list. """ async def initialize(self, pipeline_config: dict): @@ -41,5 +41,7 @@ class BanSessionCheckStage(stage.PipelineStage): return entities.StageProcessResult( result_type=entities.ResultType.CONTINUE if ctn else entities.ResultType.INTERRUPT, new_query=query, - console_notice=f'根据访问控制忽略消息: {query.launcher_type.value}_{query.launcher_id}' if not ctn else '', + console_notice=f'Ignore message according to access control: {query.launcher_type.value}_{query.launcher_id}' + if not ctn + else '', ) diff --git a/pkg/pipeline/cntfilter/filter.py b/pkg/pipeline/cntfilter/filter.py index 0a3ceaae..36d8a7f4 100644 --- a/pkg/pipeline/cntfilter/filter.py +++ b/pkg/pipeline/cntfilter/filter.py @@ -13,13 +13,13 @@ preregistered_filters: list[typing.Type[ContentFilter]] = [] def filter_class( name: str, ) -> typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]: - """内容过滤器类装饰器 + """Content filter class decorator Args: - name (str): 过滤器名称 + name (str): Filter name Returns: - typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]: 装饰器 + typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]: Decorator """ def decorator(cls: typing.Type[ContentFilter]) -> typing.Type[ContentFilter]: @@ -35,7 +35,7 @@ def filter_class( class ContentFilter(metaclass=abc.ABCMeta): - """内容过滤器抽象类""" + """Content filter abstract class""" name: str @@ -46,31 +46,31 @@ class ContentFilter(metaclass=abc.ABCMeta): @property def enable_stages(self): - """启用的阶段 + """Enabled stages - 默认为消息请求AI前后的两个阶段。 + Default is the two stages before and after the message request to AI. - entity.EnableStage.PRE: 消息请求AI前,此时需要检查的内容是用户的输入消息。 - entity.EnableStage.POST: 消息请求AI后,此时需要检查的内容是AI的回复消息。 + entity.EnableStage.PRE: Before message request to AI, the content to check is the user's input message. + entity.EnableStage.POST: After message request to AI, the content to check is the AI's reply message. 
""" return [entities.EnableStage.PRE, entities.EnableStage.POST] async def initialize(self): - """初始化过滤器""" + """Initialize filter""" pass @abc.abstractmethod async def process(self, query: core_entities.Query, message: str = None, image_url=None) -> entities.FilterResult: - """处理消息 + """Process message - 分为前后阶段,具体取决于 enable_stages 的值。 - 对于内容过滤器来说,不需要考虑消息所处的阶段,只需要检查消息内容即可。 + It is divided into two stages, depending on the value of enable_stages. + For content filters, you do not need to consider the stage of the message, you only need to check the message content. Args: - message (str): 需要检查的内容 - image_url (str): 要检查的图片的 URL + message (str): Content to check + image_url (str): URL of the image to check Returns: - entities.FilterResult: 过滤结果,具体内容请查看 entities.FilterResult 类的文档 + entities.FilterResult: Filter result, please refer to the documentation of entities.FilterResult class """ raise NotImplementedError diff --git a/pkg/pipeline/cntfilter/filters/banwords.py b/pkg/pipeline/cntfilter/filters/banwords.py index 916a1bc1..b03e79a9 100644 --- a/pkg/pipeline/cntfilter/filters/banwords.py +++ b/pkg/pipeline/cntfilter/filters/banwords.py @@ -8,7 +8,7 @@ from ....core import entities as core_entities @filter_model.filter_class('ban-word-filter') class BanWordFilter(filter_model.ContentFilter): - """根据内容过滤""" + """Filter content""" async def initialize(self): pass diff --git a/pkg/pipeline/cntfilter/filters/cntignore.py b/pkg/pipeline/cntfilter/filters/cntignore.py index 5e410e31..b80d90eb 100644 --- a/pkg/pipeline/cntfilter/filters/cntignore.py +++ b/pkg/pipeline/cntfilter/filters/cntignore.py @@ -8,7 +8,7 @@ from ....core import entities as core_entities @filter_model.filter_class('content-ignore') class ContentIgnore(filter_model.ContentFilter): - """根据内容忽略消息""" + """Ignore message according to content""" @property def enable_stages(self): @@ -24,7 +24,7 @@ class ContentIgnore(filter_model.ContentFilter): level=entities.ResultLevel.BLOCK, replacement='', 
user_notice='', - console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息', + console_notice='Ignore message according to prefix rule in ignore_rules', ) if 'regexp' in query.pipeline_config['trigger']['ignore-rules']: @@ -34,7 +34,7 @@ class ContentIgnore(filter_model.ContentFilter): level=entities.ResultLevel.BLOCK, replacement='', user_notice='', - console_notice='根据 ignore_rules 中的 regexp 规则,忽略消息', + console_notice='Ignore message according to regexp rule in ignore_rules', ) return entities.FilterResult( diff --git a/pkg/pipeline/longtext/longtext.py b/pkg/pipeline/longtext/longtext.py index 5be20650..03457212 100644 --- a/pkg/pipeline/longtext/longtext.py +++ b/pkg/pipeline/longtext/longtext.py @@ -16,9 +16,9 @@ importutil.import_modules_in_pkg(strategies) @stage.stage_class('LongTextProcessStage') class LongTextProcessStage(stage.PipelineStage): - """长消息处理阶段 + """Long message processing stage - 改写: + Rewrite: - resp_message_chain """ @@ -36,22 +36,22 @@ class LongTextProcessStage(stage.PipelineStage): use_font = 'C:/Windows/Fonts/msyh.ttc' if not os.path.exists(use_font): self.ap.logger.warn( - '未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。' + 'Font file not found, and Windows system font cannot be used, switch to forward message component to send long messages, you can adjust the related settings in the configuration file.' ) config['blob_message_strategy'] = 'forward' else: - self.ap.logger.info('使用Windows自带字体:' + use_font) + self.ap.logger.info('Using Windows system font: ' + use_font) config['font-path'] = use_font else: self.ap.logger.warn( - '未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。' + 'Font file not found, and system font cannot be used, switch to forward message component to send long messages, you can adjust the related settings in the configuration file.' 
) pipeline_config['output']['long-text-processing']['strategy'] = 'forward' except Exception: traceback.print_exc() self.ap.logger.error( - '加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。'.format( + 'Failed to load font file ({}), switch to forward message component to send long messages, you can adjust the related settings in the configuration file.'.format( use_font ) ) @@ -63,12 +63,12 @@ class LongTextProcessStage(stage.PipelineStage): self.strategy_impl = strategy_cls(self.ap) break else: - raise ValueError(f'未找到名为 {config["strategy"]} 的长消息处理策略') + raise ValueError(f'Long message processing strategy not found: {config["strategy"]}') await self.strategy_impl.initialize() async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: - # 检查是否包含非 Plain 组件 + # Check if it contains non-Plain components contains_non_plain = False for msg in query.resp_message_chain[-1]: @@ -77,7 +77,7 @@ class LongTextProcessStage(stage.PipelineStage): break if contains_non_plain: - self.ap.logger.debug('消息中包含非 Plain 组件,跳过长消息处理。') + self.ap.logger.debug('Message contains non-Plain components, skip long message processing.') elif ( len(str(query.resp_message_chain[-1])) > query.pipeline_config['output']['long-text-processing']['threshold'] diff --git a/pkg/pipeline/longtext/strategies/forward.py b/pkg/pipeline/longtext/strategies/forward.py index 6228d580..cb772339 100644 --- a/pkg/pipeline/longtext/strategies/forward.py +++ b/pkg/pipeline/longtext/strategies/forward.py @@ -15,17 +15,17 @@ Forward = platform_message.Forward class ForwardComponentStrategy(strategy_model.LongTextStrategy): async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]: display = ForwardMessageDiaplay( - title='群聊的聊天记录', - brief='[聊天记录]', - source='聊天记录', - preview=['QQ用户: ' + message], - summary='查看1条转发消息', + title='Group chat history', + brief='[Chat history]', + source='Chat history', + preview=['User: ' + 
message], + summary='View 1 forwarded message', ) node_list = [ platform_message.ForwardMessageNode( sender_id=query.adapter.bot_account_id, - sender_name='QQ用户', + sender_name='User', message_chain=platform_message.MessageChain([message]), ) ] diff --git a/pkg/pipeline/longtext/strategy.py b/pkg/pipeline/longtext/strategy.py index 0ddec0c6..5b521067 100644 --- a/pkg/pipeline/longtext/strategy.py +++ b/pkg/pipeline/longtext/strategy.py @@ -14,13 +14,13 @@ preregistered_strategies: list[typing.Type[LongTextStrategy]] = [] def strategy_class( name: str, ) -> typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]: - """长文本处理策略类装饰器 + """Long text processing strategy class decorator Args: - name (str): 策略名称 + name (str): Strategy name Returns: - typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]: 装饰器 + typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]: Decorator """ def decorator(cls: typing.Type[LongTextStrategy]) -> typing.Type[LongTextStrategy]: @@ -36,7 +36,7 @@ def strategy_class( class LongTextStrategy(metaclass=abc.ABCMeta): - """长文本处理策略抽象类""" + """Long text processing strategy abstract class""" name: str @@ -50,15 +50,15 @@ class LongTextStrategy(metaclass=abc.ABCMeta): @abc.abstractmethod async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]: - """处理长文本 + """Process long text - 在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法 + If the text length exceeds the threshold, this method will be called. 
Args: - message (str): 消息 - query (core_entities.Query): 此次请求的上下文对象 + message (str): Message + query (core_entities.Query): Query object Returns: - list[platform_message.MessageComponent]: 转换后的 平台 消息组件列表 + list[platform_message.MessageComponent]: Converted platform message components """ return [] diff --git a/pkg/pipeline/msgtrun/msgtrun.py b/pkg/pipeline/msgtrun/msgtrun.py index c64f67fc..1c5ee17d 100644 --- a/pkg/pipeline/msgtrun/msgtrun.py +++ b/pkg/pipeline/msgtrun/msgtrun.py @@ -12,9 +12,9 @@ importutil.import_modules_in_pkg(truncators) @stage.stage_class('ConversationMessageTruncator') class ConversationMessageTruncator(stage.PipelineStage): - """会话消息截断器 + """Conversation message truncator - 用于截断会话消息链,以适应平台消息长度限制。 + Used to truncate the conversation message chain to adapt to the LLM message length limit. """ trun: truncator.Truncator @@ -27,10 +27,10 @@ class ConversationMessageTruncator(stage.PipelineStage): self.trun = trun(self.ap) break else: - raise ValueError(f'未知的截断器: {use_method}') + raise ValueError(f'Unknown truncator: {use_method}') async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: - """处理""" + """Process""" query = await self.trun.truncate(query) return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) diff --git a/pkg/pipeline/msgtrun/truncators/round.py b/pkg/pipeline/msgtrun/truncators/round.py index fa72a0e1..2acb1d8c 100644 --- a/pkg/pipeline/msgtrun/truncators/round.py +++ b/pkg/pipeline/msgtrun/truncators/round.py @@ -6,17 +6,17 @@ from ....core import entities as core_entities @truncator.truncator_class('round') class RoundTruncator(truncator.Truncator): - """前文回合数阶段器""" + """Truncate the conversation message chain to adapt to the LLM message length limit.""" async def truncate(self, query: core_entities.Query) -> core_entities.Query: - """截断""" + """Truncate""" max_round = query.pipeline_config['ai']['local-agent']['max-round'] temp_messages = [] 
current_round = 0 - # 从后往前遍历 + # Traverse from back to front for msg in query.messages[::-1]: if current_round < max_round: temp_messages.append(msg) diff --git a/pkg/pipeline/pipelinemgr.py b/pkg/pipeline/pipelinemgr.py index b61e34ad..77df09dc 100644 --- a/pkg/pipeline/pipelinemgr.py +++ b/pkg/pipeline/pipelinemgr.py @@ -144,23 +144,27 @@ class RuntimePipeline: result = await result if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果 - self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query} res {result}') + self.ap.logger.debug( + f'Stage {stage_container.inst_name} processed query {query.query_id} res {result.result_type}' + ) await self._check_output(query, result) if result.result_type == pipeline_entities.ResultType.INTERRUPT: - self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query}') + self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query.query_id}') break elif result.result_type == pipeline_entities.ResultType.CONTINUE: query = result.new_query elif isinstance(result, typing.AsyncGenerator): # 生成器 - self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query} gen') + self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query.query_id} gen') async for sub_result in result: - self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query} res {sub_result}') + self.ap.logger.debug( + f'Stage {stage_container.inst_name} processed query {query.query_id} res {sub_result.result_type}' + ) await self._check_output(query, sub_result) if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT: - self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query}') + self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query.query_id}') break elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE: query = sub_result.new_query @@ -192,7 +196,7 @@ class 
RuntimePipeline: if event_ctx.is_prevented_default(): return - self.ap.logger.debug(f'Processing query {query}') + self.ap.logger.debug(f'Processing query {query.query_id}') await self._execute_from_stage(0, query) except Exception as e: @@ -200,7 +204,7 @@ class RuntimePipeline: self.ap.logger.error(f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}') self.ap.logger.error(f'Traceback: {traceback.format_exc()}') finally: - self.ap.logger.debug(f'Query {query} processed') + self.ap.logger.debug(f'Query {query.query_id} processed') class PipelineManager: diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py index 19478200..1aada6b3 100644 --- a/pkg/pipeline/preproc/preproc.py +++ b/pkg/pipeline/preproc/preproc.py @@ -11,11 +11,11 @@ from ...platform.types import message as platform_message @stage.stage_class('PreProcessor') class PreProcessor(stage.PipelineStage): - """请求预处理阶段 + """Request pre-processing stage - 签出会话、prompt、上文、模型、内容函数。 + Check out session, prompt, context, model, and content functions. 
- 改写: + Rewrite: - session - prompt - messages @@ -29,12 +29,12 @@ class PreProcessor(stage.PipelineStage): query: core_entities.Query, stage_inst_name: str, ) -> entities.StageProcessResult: - """处理""" + """Process""" selected_runner = query.pipeline_config['ai']['runner']['runner'] session = await self.ap.sess_mgr.get_session(query) - # 非 local-agent 时,llm_model 为 None + # When not local-agent, llm_model is None llm_model = ( await self.ap.model_mgr.get_model_by_uuid(query.pipeline_config['ai']['local-agent']['model']) if selected_runner == 'local-agent' @@ -51,7 +51,7 @@ class PreProcessor(stage.PipelineStage): conversation.use_llm_model = llm_model - # 设置query + # Set query query.session = session query.prompt = conversation.prompt.copy() query.messages = conversation.messages.copy() @@ -80,14 +80,15 @@ class PreProcessor(stage.PipelineStage): if me.type == 'image_url': msg.content.remove(me) - content_list = [] + content_list: list[llm_entities.ContentElement] = [] plain_text = '' qoute_msg = query.pipeline_config['trigger'].get('misc', '').get('combine-quote-message') + # tidy the content_list + # combine all text content into one, and put it in the first position for me in query.message_chain: if isinstance(me, platform_message.Plain): - content_list.append(llm_entities.ContentElement.from_text(me.text)) plain_text += me.text elif isinstance(me, platform_message.Image): if selected_runner != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__( @@ -106,10 +107,12 @@ class PreProcessor(stage.PipelineStage): if msg.base64 is not None: content_list.append(llm_entities.ContentElement.from_image_base64(msg.base64)) + content_list.insert(0, llm_entities.ContentElement.from_text(plain_text)) + query.variables['user_message_text'] = plain_text query.user_message = llm_entities.Message(role='user', content=content_list) - # =========== 触发事件 PromptPreProcessing + # =========== Trigger event PromptPreProcessing event_ctx = await 
self.ap.plugin_mgr.emit_event( event=events.PromptPreProcessing( diff --git a/pkg/pipeline/process/handler.py b/pkg/pipeline/process/handler.py index 8a32bcfb..837b72e2 100644 --- a/pkg/pipeline/process/handler.py +++ b/pkg/pipeline/process/handler.py @@ -25,7 +25,7 @@ class MessageHandler(metaclass=abc.ABCMeta): def cut_str(self, s: str) -> str: """ - 取字符串第一行,最多20个字符,若有多行,或超过20个字符,则加省略号 + Take the first line of the string, up to 20 characters, if there are multiple lines, or more than 20 characters, add an ellipsis """ s0 = s.split('\n')[0] if len(s0) > 20 or '\n' in s: diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index 35fa1611..2aa08e17 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -22,11 +22,11 @@ class ChatMessageHandler(handler.MessageHandler): self, query: core_entities.Query, ) -> typing.AsyncGenerator[entities.StageProcessResult, None]: - """处理""" - # 调API - # 生成器 + """Process""" + # Call API + # generator - # 触发插件事件 + # Trigger plugin event event_class = ( events.PersonNormalMessageReceived if query.launcher_type == core_entities.LauncherTypes.PERSON @@ -54,7 +54,7 @@ class ChatMessageHandler(handler.MessageHandler): yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query) else: if event_ctx.event.alter is not None: - # if isinstance(event_ctx.event, str): # 现在暂时不考虑多模态alter + # if isinstance(event_ctx.event, str): # Currently not considering multi-modal alter query.user_message.content = event_ctx.event.alter text_length = 0 @@ -65,12 +65,12 @@ class ChatMessageHandler(handler.MessageHandler): runner = r(self.ap, query.pipeline_config) break else: - raise ValueError(f'未找到请求运行器: {query.pipeline_config["ai"]["runner"]["runner"]}') + raise ValueError(f'Request runner not found: {query.pipeline_config["ai"]["runner"]["runner"]}') async for result in runner.run(query): query.resp_messages.append(result) - 
self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}') + self.ap.logger.info(f'Response({query.query_id}): {self.cut_str(result.readable_str())}') if result.content is not None: text_length += len(result.content) @@ -80,7 +80,7 @@ class ChatMessageHandler(handler.MessageHandler): query.session.using_conversation.messages.append(query.user_message) query.session.using_conversation.messages.extend(query.resp_messages) except Exception as e: - self.ap.logger.error(f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}') + self.ap.logger.error(f'Request failed({query.query_id}): {type(e).__name__} {str(e)}') hide_exception_info = query.pipeline_config['output']['misc']['hide-exception'] diff --git a/pkg/pipeline/process/handlers/command.py b/pkg/pipeline/process/handlers/command.py index cc0e9314..7348d6b8 100644 --- a/pkg/pipeline/process/handlers/command.py +++ b/pkg/pipeline/process/handlers/command.py @@ -15,7 +15,7 @@ class CommandHandler(handler.MessageHandler): self, query: core_entities.Query, ) -> typing.AsyncGenerator[entities.StageProcessResult, None]: - """处理""" + """Process""" command_text = str(query.message_chain).strip()[1:] @@ -70,7 +70,7 @@ class CommandHandler(handler.MessageHandler): ) ) - self.ap.logger.info(f'命令({query.query_id})报错: {self.cut_str(str(ret.error))}') + self.ap.logger.info(f'Command({query.query_id}) error: {self.cut_str(str(ret.error))}') yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) elif ret.text is not None or ret.image_url is not None: @@ -89,7 +89,7 @@ class CommandHandler(handler.MessageHandler): ) ) - self.ap.logger.info(f'命令返回: {self.cut_str(str(content[0]))}') + self.ap.logger.info(f'Command returned: {self.cut_str(str(content[0]))}') yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) else: diff --git a/pkg/pipeline/process/process.py b/pkg/pipeline/process/process.py index 64903552..db66135c 100644 --- 
a/pkg/pipeline/process/process.py +++ b/pkg/pipeline/process/process.py @@ -33,11 +33,11 @@ class Processor(stage.PipelineStage): query: core_entities.Query, stage_inst_name: str, ) -> entities.StageProcessResult: - """处理""" + """Process""" message_text = str(query.message_chain).strip() self.ap.logger.info( - f'处理 {query.launcher_type.value}_{query.launcher_id} 的请求({query.query_id}): {message_text}' + f'Processing request from {query.launcher_type.value}_{query.launcher_id} ({query.query_id}): {message_text}' ) async def generator(): diff --git a/pkg/platform/sources/aiocqhttp.py b/pkg/platform/sources/aiocqhttp.py index 2730874f..357eb48a 100644 --- a/pkg/platform/sources/aiocqhttp.py +++ b/pkg/platform/sources/aiocqhttp.py @@ -76,8 +76,7 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): return msg_list, msg_id, msg_time @staticmethod - async def target2yiri(message: str, message_id: int = -1, bot=None): - print(message) + async def target2yiri(message: str, message_id: int = -1, bot: aiocqhttp.CQHttp = None): message = aiocqhttp.Message(message) def get_face_name(face_id): @@ -271,16 +270,16 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): ) yiri_msg_list.append(reply_msg) - elif msg.type == 'file': - # file_name = msg.data['file'] - file_id = msg.data['file_id'] - file_data = await bot.get_file(file_id=file_id) - file_name = file_data.get('file_name') - file_path = file_data.get('file') - _ = file_path - file_url = file_data.get('file_url') - file_size = file_data.get('file_size') - yiri_msg_list.append(platform_message.File(id=file_id, name=file_name, url=file_url, size=file_size)) + # 这里下载所有文件会导致下载文件过多,暂时不下载 + # elif msg.type == 'file': + # # file_name = msg.data['file'] + # file_id = msg.data['file_id'] + # file_data = await bot.get_file(file_id=file_id) + # file_name = file_data.get('file_name') + # file_path = file_data.get('file') + # file_url = file_data.get('file_url') + # file_size = file_data.get('file_size') + # 
yiri_msg_list.append(platform_message.File(id=file_id, name=file_name,url=file_url,size=file_size)) elif msg.type == 'face': face_id = msg.data['id'] face_name = msg.data['raw']['faceText'] diff --git a/pkg/platform/sources/dingtalk.py b/pkg/platform/sources/dingtalk.py index 3147c984..a40b0f9b 100644 --- a/pkg/platform/sources/dingtalk.py +++ b/pkg/platform/sources/dingtalk.py @@ -116,6 +116,15 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter): self.bot_account_id = self.config['robot_name'] + self.bot = DingTalkClient( + client_id=config['client_id'], + client_secret=config['client_secret'], + robot_name=config['robot_name'], + robot_code=config['robot_code'], + markdown_card=config['markdown_card'], + logger=self.logger, + ) + async def reply_message( self, message_source: platform_events.MessageEvent, @@ -157,15 +166,6 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter): self.bot.on_message('GroupMessage')(on_message) async def run_async(self): - config = self.config - self.bot = DingTalkClient( - client_id=config['client_id'], - client_secret=config['client_secret'], - robot_name=config['robot_name'], - robot_code=config['robot_code'], - markdown_card=config['markdown_card'], - logger=self.logger, - ) await self.bot.start() async def kill(self) -> bool: diff --git a/pkg/platform/sources/discord.py b/pkg/platform/sources/discord.py index 6cc09a72..c279e714 100644 --- a/pkg/platform/sources/discord.py +++ b/pkg/platform/sources/discord.py @@ -8,15 +8,592 @@ import base64 import uuid import os import datetime +import io +import asyncio +from enum import Enum import aiohttp from .. 
import adapter from ...core import app +from ..logger import EventLogger from ..types import message as platform_message from ..types import events as platform_events from ..types import entities as platform_entities -from ..logger import EventLogger + +# Exception definitions for the voice feature +class VoiceConnectionError(Exception): + """Base exception for voice connections""" + def __init__(self, message: str, error_code: str = None, guild_id: int = None): + super().__init__(message) + self.error_code = error_code + self.guild_id = guild_id + self.timestamp = datetime.datetime.now() + + +class VoicePermissionError(VoiceConnectionError): + """Voice permission exception""" + def __init__(self, message: str, missing_permissions: list = None, user_id: int = None, channel_id: int = None): + super().__init__(message, "PERMISSION_ERROR") + self.missing_permissions = missing_permissions or [] + self.user_id = user_id + self.channel_id = channel_id + + +class VoiceNetworkError(VoiceConnectionError): + """Voice network exception""" + def __init__(self, message: str, retry_count: int = 0): + super().__init__(message, "NETWORK_ERROR") + self.retry_count = retry_count + self.last_attempt = datetime.datetime.now() + + +class VoiceConnectionStatus(Enum): + """Voice connection status enum""" + IDLE = "idle" + CONNECTING = "connecting" + CONNECTED = "connected" + PLAYING = "playing" + RECONNECTING = "reconnecting" + FAILED = "failed" + + +class VoiceConnectionInfo: + """ + Voice connection info class + + Stores and manages the details of a single voice connection, including connection status, timestamps, + and channel information. Provides a standardized data structure for connection info. + + @author: @ydzat + @version: 1.0 + @since: 2025-07-04 + """ + + def __init__(self, guild_id: int, channel_id: int, channel_name: str = None): + """ + Initialize voice connection info + + @author: @ydzat + + Args: + guild_id (int): Server ID + channel_id (int): Voice channel ID + channel_name (str, optional): Voice channel name + """ + self.guild_id = guild_id + self.channel_id = channel_id + self.channel_name = channel_name or f"Channel-{channel_id}" + self.connected = False + self.connection_time: datetime.datetime = None + self.last_activity = datetime.datetime.now() + self.status = 
VoiceConnectionStatus.IDLE + self.user_count = 0 + self.latency = 0.0 + self.connection_health = "unknown" + self.voice_client = None + + def update_status(self, status: VoiceConnectionStatus): + """ + Update the connection status + + @author: @ydzat + + Args: + status (VoiceConnectionStatus): New connection status + """ + self.status = status + self.last_activity = datetime.datetime.now() + + if status == VoiceConnectionStatus.CONNECTED: + self.connected = True + if self.connection_time is None: + self.connection_time = datetime.datetime.now() + elif status in [VoiceConnectionStatus.IDLE, VoiceConnectionStatus.FAILED]: + self.connected = False + self.connection_time = None + self.voice_client = None + + def to_dict(self) -> dict: + """ + Convert to dict format + + @author: @ydzat + + Returns: + dict: Dict representation of the connection info + """ + return { + "guild_id": self.guild_id, + "channel_id": self.channel_id, + "channel_name": self.channel_name, + "connected": self.connected, + "connection_time": self.connection_time.isoformat() if self.connection_time else None, + "last_activity": self.last_activity.isoformat(), + "status": self.status.value, + "user_count": self.user_count, + "latency": self.latency, + "connection_health": self.connection_health + } + + +class VoiceConnectionManager: + """ + Voice connection manager + + Manages voice connections across multiple servers, providing connect, disconnect, and status query functions. + Uses the singleton pattern to ensure there is only one connection manager instance globally. + + @author: @ydzat + @version: 1.0 + @since: 2025-07-04 + """ + + def __init__(self, bot: discord.Client, logger: EventLogger): + """ + Initialize the voice connection manager + + @author: @ydzat + + Args: + bot (discord.Client): Discord client instance + logger (EventLogger): Event logger + """ + self.bot = bot + self.logger = logger + self.connections: typing.Dict[int, VoiceConnectionInfo] = {} + self._connection_lock = asyncio.Lock() + self._cleanup_task = None + self._monitoring_enabled = True + + async def join_voice_channel(self, guild_id: int, channel_id: int, + user_id: int = None) -> discord.VoiceClient: + """ + Join a voice channel + + After validating user permissions and channel state, establish a connection to the specified voice channel. + Supports connection reuse and automatic reconnection. + + @author: @ydzat + + Args: + guild_id 
(int): Server ID + channel_id (int): Voice channel ID + user_id (int, optional): ID of the requesting user, used for permission validation + + Returns: + discord.VoiceClient: Voice client instance + + Raises: + VoicePermissionError: Raised when permissions are insufficient + VoiceNetworkError: Raised when the network connection fails + VoiceConnectionError: Raised on other connection errors + """ + async with self._connection_lock: + try: + # Get the server and channel objects + guild = self.bot.get_guild(guild_id) + if not guild: + raise VoiceConnectionError( + f"Server not found: {guild_id}", + "GUILD_NOT_FOUND", + guild_id + ) + + channel = guild.get_channel(channel_id) + if not channel or not isinstance(channel, discord.VoiceChannel): + raise VoiceConnectionError( + f"Voice channel not found: {channel_id}", + "CHANNEL_NOT_FOUND", + guild_id + ) + + # Validate that the user is in the voice channel (if a user ID was provided) + if user_id: + await self._validate_user_in_channel(guild, channel, user_id) + + # Validate bot permissions + await self._validate_bot_permissions(channel) + + # Check whether a connection already exists + if guild_id in self.connections: + existing_conn = self.connections[guild_id] + if existing_conn.connected and existing_conn.voice_client: + if existing_conn.channel_id == channel_id: + # Already connected to the same channel, return the existing connection + await self.logger.info(f"Reusing existing voice connection: {guild.name} -> {channel.name}") + return existing_conn.voice_client + else: + # Connected to a different channel, disconnect the old connection first + await self._disconnect_internal(guild_id) + + # Establish a new connection + voice_client = await channel.connect() + + # Update connection info + conn_info = VoiceConnectionInfo(guild_id, channel_id, channel.name) + conn_info.voice_client = voice_client + conn_info.update_status(VoiceConnectionStatus.CONNECTED) + conn_info.user_count = len(channel.members) + self.connections[guild_id] = conn_info + + await self.logger.info(f"Connected to voice channel: {guild.name} -> {channel.name}") + return voice_client + + except discord.ClientException as e: + raise VoiceNetworkError(f"Discord client error: {str(e)}") + except discord.opus.OpusNotLoaded as e: + raise VoiceConnectionError(f"Opus encoder not loaded: {str(e)}", "OPUS_NOT_LOADED", guild_id) + except Exception as e: + await self.logger.error(f"Unknown error while connecting to voice channel: {str(e)}") + raise VoiceConnectionError(f"Connection failed: {str(e)}", 
"UNKNOWN_ERROR", guild_id) + + async def leave_voice_channel(self, guild_id: int) -> bool: + """ + Leave a voice channel + + Disconnect the voice connection of the specified server and clean up related resources and state. + Ensures audio playback has stopped before disconnecting. + + @author: @ydzat + + Args: + guild_id (int): Server ID + + Returns: + bool: Whether the disconnect succeeded + """ + async with self._connection_lock: + return await self._disconnect_internal(guild_id) + + async def _disconnect_internal(self, guild_id: int) -> bool: + """ + Internal disconnect method + + @author: @ydzat + + Args: + guild_id (int): Server ID + + Returns: + bool: Whether the disconnect succeeded + """ + if guild_id not in self.connections: + return True + + conn_info = self.connections[guild_id] + + try: + if conn_info.voice_client and conn_info.voice_client.is_connected(): + # Stop current playback + if conn_info.voice_client.is_playing(): + conn_info.voice_client.stop() + + # Wait for playback to stop completely + await asyncio.sleep(0.1) + + # Disconnect + await conn_info.voice_client.disconnect() + + conn_info.update_status(VoiceConnectionStatus.IDLE) + del self.connections[guild_id] + + await self.logger.info(f"Voice connection disconnected: Guild {guild_id}") + return True + + except Exception as e: + await self.logger.error(f"Error while disconnecting voice connection: {str(e)}") + # Clean up the connection record even if an error occurred + conn_info.update_status(VoiceConnectionStatus.FAILED) + if guild_id in self.connections: + del self.connections[guild_id] + return False + + async def get_voice_client(self, guild_id: int) -> typing.Optional[discord.VoiceClient]: + """ + Get the voice client + + Return the voice client instance of the specified server, or None if not connected. + Validates the connection and automatically cleans up invalid connections. + + @author: @ydzat + + Args: + guild_id (int): Server ID + + Returns: + Optional[discord.VoiceClient]: Voice client instance or None + """ + if guild_id not in self.connections: + return None + + conn_info = self.connections[guild_id] + + # Verify that the connection is still valid + if conn_info.voice_client and not conn_info.voice_client.is_connected(): + # Connection is no longer valid, clean up state + await self._disconnect_internal(guild_id) + return None + + return conn_info.voice_client if conn_info.connected else None + + async def is_connected_to_voice(self, guild_id: int) -> bool: + """ + Check whether connected to a voice channel + + @author: @ydzat + + Args: + guild_id 
服务器ID + + Returns: + bool: 是否已连接 + """ + if guild_id not in self.connections: + return False + + conn_info = self.connections[guild_id] + + # 检查实际连接状态 + if conn_info.voice_client and not conn_info.voice_client.is_connected(): + # 连接已失效,清理状态 + await self._disconnect_internal(guild_id) + return False + + return conn_info.connected + + async def get_connection_status(self, guild_id: int) -> typing.Optional[dict]: + """ + 获取连接状态信息 + + @author: @ydzat + + Args: + guild_id (int): 服务器ID + + Returns: + Optional[dict]: 连接状态信息字典或 None + """ + if guild_id not in self.connections: + return None + + conn_info = self.connections[guild_id] + + # 更新实时信息 + if conn_info.voice_client and conn_info.voice_client.is_connected(): + conn_info.latency = conn_info.voice_client.latency * 1000 # 转换为毫秒 + conn_info.connection_health = "good" if conn_info.latency < 100 else "poor" + + # 更新频道用户数 + guild = self.bot.get_guild(guild_id) + if guild: + channel = guild.get_channel(conn_info.channel_id) + if channel and isinstance(channel, discord.VoiceChannel): + conn_info.user_count = len(channel.members) + + return conn_info.to_dict() + + async def list_active_connections(self) -> typing.List[dict]: + """ + 列出所有活跃连接 + + @author: @ydzat + + Returns: + List[dict]: 活跃连接列表 + """ + active_connections = [] + + for guild_id, conn_info in self.connections.items(): + if conn_info.connected: + status = await self.get_connection_status(guild_id) + if status: + active_connections.append(status) + + return active_connections + + async def get_voice_channel_info(self, guild_id: int, channel_id: int) -> typing.Optional[dict]: + """ + 获取语音频道信息 + + @author: @ydzat + + Args: + guild_id (int): 服务器ID + channel_id (int): 频道ID + + Returns: + Optional[dict]: 频道信息字典或 None + """ + guild = self.bot.get_guild(guild_id) + if not guild: + return None + + channel = guild.get_channel(channel_id) + if not channel or not isinstance(channel, discord.VoiceChannel): + return None + + # 获取用户信息 + users = [] + for member in 
channel.members: + users.append({ + "id": member.id, + "name": member.display_name, + "status": str(member.status), + "is_bot": member.bot + }) + + # 获取权限信息 + bot_member = guild.me + permissions = channel.permissions_for(bot_member) + + return { + "channel_id": channel_id, + "channel_name": channel.name, + "guild_id": guild_id, + "guild_name": guild.name, + "user_limit": channel.user_limit, + "current_users": users, + "user_count": len(users), + "bitrate": channel.bitrate, + "permissions": { + "connect": permissions.connect, + "speak": permissions.speak, + "use_voice_activation": permissions.use_voice_activation, + "priority_speaker": permissions.priority_speaker + } + } + + async def _validate_user_in_channel(self, guild: discord.Guild, + channel: discord.VoiceChannel, user_id: int): + """ + 验证用户是否在语音频道中 + + @author: @ydzat + + Args: + guild: Discord 服务器对象 + channel: 语音频道对象 + user_id: 用户ID + + Raises: + VoicePermissionError: 用户不在频道中时抛出 + """ + member = guild.get_member(user_id) + if not member: + raise VoicePermissionError( + f"无法找到用户 {user_id}", + ["member_not_found"], + user_id, + channel.id + ) + + if not member.voice or member.voice.channel != channel: + raise VoicePermissionError( + f"用户 {member.display_name} 不在语音频道 {channel.name} 中", + ["user_not_in_channel"], + user_id, + channel.id + ) + + async def _validate_bot_permissions(self, channel: discord.VoiceChannel): + """ + 验证机器人权限 + + @author: @ydzat + + Args: + channel: 语音频道对象 + + Raises: + VoicePermissionError: 权限不足时抛出 + """ + bot_member = channel.guild.me + permissions = channel.permissions_for(bot_member) + + missing_permissions = [] + + if not permissions.connect: + missing_permissions.append("connect") + if not permissions.speak: + missing_permissions.append("speak") + + if missing_permissions: + raise VoicePermissionError( + f"机器人在频道 {channel.name} 中缺少权限: {', '.join(missing_permissions)}", + missing_permissions, + channel_id=channel.id + ) + + async def cleanup_inactive_connections(self): + """ + 
清理无效连接 + + 定期检查并清理已断开或无效的语音连接,释放资源。 + + @author: @ydzat + """ + cleanup_guilds = [] + + for guild_id, conn_info in self.connections.items(): + if not conn_info.voice_client or not conn_info.voice_client.is_connected(): + cleanup_guilds.append(guild_id) + + for guild_id in cleanup_guilds: + await self._disconnect_internal(guild_id) + + if cleanup_guilds: + await self.logger.info(f"清理了 {len(cleanup_guilds)} 个无效的语音连接") + + async def start_monitoring(self): + """ + 开始连接监控 + + @author: @ydzat + """ + if self._cleanup_task is None and self._monitoring_enabled: + self._cleanup_task = asyncio.create_task(self._monitoring_loop()) + + async def stop_monitoring(self): + """ + 停止连接监控 + + @author: @ydzat + """ + self._monitoring_enabled = False + if self._cleanup_task: + self._cleanup_task.cancel() + try: + await self._cleanup_task + except asyncio.CancelledError: + pass + self._cleanup_task = None + + async def _monitoring_loop(self): + """ + 监控循环 + + @author: @ydzat + """ + try: + while self._monitoring_enabled: + await asyncio.sleep(60) # 每分钟检查一次 + await self.cleanup_inactive_connections() + except asyncio.CancelledError: + pass + + async def disconnect_all(self): + """ + 断开所有连接 + + @author: @ydzat + """ + async with self._connection_lock: + guild_ids = list(self.connections.keys()) + for guild_id in guild_ids: + await self._disconnect_internal(guild_id) + + await self.stop_monitoring() class DiscordMessageConverter(adapter.MessageConverter): @@ -238,6 +815,9 @@ class DiscordAdapter(adapter.MessagePlatformAdapter): self.logger = logger self.bot_account_id = self.config['client_id'] + + # 初始化语音连接管理器 + self.voice_manager: VoiceConnectionManager = None adapter_self = self @@ -258,6 +838,169 @@ class DiscordAdapter(adapter.MessagePlatformAdapter): args['proxy'] = os.getenv('http_proxy') self.bot = MyClient(intents=intents, **args) + + # Voice functionality methods + async def join_voice_channel(self, guild_id: int, channel_id: int, + user_id: int = None) -> discord.VoiceClient: 
+ """ + 加入语音频道 + + 为指定服务器的语音频道建立连接,支持用户权限验证和连接复用。 + + @author: @ydzat + @version: 1.0 + @since: 2025-07-04 + + Args: + guild_id (int): Discord 服务器ID + channel_id (int): 语音频道ID + user_id (int, optional): 请求用户ID,用于权限验证 + + Returns: + discord.VoiceClient: 语音客户端实例 + + Raises: + VoicePermissionError: 权限不足 + VoiceNetworkError: 网络连接失败 + VoiceConnectionError: 其他连接错误 + """ + if not self.voice_manager: + raise VoiceConnectionError("语音管理器未初始化", "MANAGER_NOT_READY") + + return await self.voice_manager.join_voice_channel(guild_id, channel_id, user_id) + + async def leave_voice_channel(self, guild_id: int) -> bool: + """ + 离开语音频道 + + 断开指定服务器的语音连接,清理相关资源。 + + @author: @ydzat + @version: 1.0 + @since: 2025-07-04 + + Args: + guild_id (int): Discord 服务器ID + + Returns: + bool: 是否成功断开连接 + """ + if not self.voice_manager: + return False + + return await self.voice_manager.leave_voice_channel(guild_id) + + async def get_voice_client(self, guild_id: int) -> typing.Optional[discord.VoiceClient]: + """ + 获取语音客户端 + + 返回指定服务器的语音客户端实例,用于音频播放控制。 + + @author: @ydzat + @version: 1.0 + @since: 2025-07-04 + + Args: + guild_id (int): Discord 服务器ID + + Returns: + Optional[discord.VoiceClient]: 语音客户端实例或 None + """ + if not self.voice_manager: + return None + + return await self.voice_manager.get_voice_client(guild_id) + + async def is_connected_to_voice(self, guild_id: int) -> bool: + """ + 检查语音连接状态 + + @author: @ydzat + @version: 1.0 + @since: 2025-07-04 + + Args: + guild_id (int): Discord 服务器ID + + Returns: + bool: 是否已连接到语音频道 + """ + if not self.voice_manager: + return False + + return await self.voice_manager.is_connected_to_voice(guild_id) + + async def get_voice_connection_status(self, guild_id: int) -> typing.Optional[dict]: + """ + 获取语音连接详细状态 + + 返回包含连接时间、延迟、用户数等详细信息的状态字典。 + + @author: @ydzat + @version: 1.0 + @since: 2025-07-04 + + Args: + guild_id (int): Discord 服务器ID + + Returns: + Optional[dict]: 连接状态信息或 None + """ + if not self.voice_manager: + return None + + return await 
self.voice_manager.get_connection_status(guild_id) + + async def list_active_voice_connections(self) -> typing.List[dict]: + """ + 列出所有活跃的语音连接 + + @author: @ydzat + @version: 1.0 + @since: 2025-07-04 + + Returns: + List[dict]: 活跃语音连接列表 + """ + if not self.voice_manager: + return [] + + return await self.voice_manager.list_active_connections() + + async def get_voice_channel_info(self, guild_id: int, channel_id: int) -> typing.Optional[dict]: + """ + 获取语音频道详细信息 + + 包括频道名称、用户列表、权限信息等。 + + @author: @ydzat + @version: 1.0 + @since: 2025-07-04 + + Args: + guild_id (int): Discord 服务器ID + channel_id (int): 语音频道ID + + Returns: + Optional[dict]: 频道信息字典或 None + """ + if not self.voice_manager: + return None + + return await self.voice_manager.get_voice_channel_info(guild_id, channel_id) + + async def cleanup_voice_connections(self): + """ + 清理无效的语音连接 + + 手动触发语音连接清理,移除已断开或无效的连接。 + + @author: @ydzat + @version: 1.0 + @since: 2025-07-04 + """ + if self.voice_manager: + await self.voice_manager.cleanup_inactive_connections() async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): msg_to_send, image_files = await self.message_converter.yiri2target(message) @@ -324,9 +1067,32 @@ class DiscordAdapter(adapter.MessagePlatformAdapter): self.listeners.pop(event_type) async def run_async(self): + """ + 启动 Discord 适配器 + + 初始化语音管理器并启动 Discord 客户端连接。 + + @author: @ydzat (修改) + """ async with self.bot: + # 初始化语音管理器 + self.voice_manager = VoiceConnectionManager(self.bot, self.logger) + await self.voice_manager.start_monitoring() + + await self.logger.info("Discord 适配器语音功能已启用") await self.bot.start(self.config['token'], reconnect=True) async def kill(self) -> bool: + """ + 关闭 Discord 适配器 + + 清理语音连接并关闭 Discord 客户端。 + + @author: @ydzat (修改) + """ + if self.voice_manager: + await self.voice_manager.disconnect_all() + await self.bot.close() return True + diff --git a/pkg/platform/sources/wechatpad.py b/pkg/platform/sources/wechatpad.py index 
75cad727..9bbb471d 100644 --- a/pkg/platform/sources/wechatpad.py +++ b/pkg/platform/sources/wechatpad.py @@ -29,18 +29,21 @@ import logging class WeChatPadMessageConverter(adapter.MessageConverter): - def __init__(self, config: dict): + + def __init__(self, config: dict, logger: logging.Logger): self.config = config - self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token']) - self.logger = logging.getLogger('WeChatPadMessageConverter') + self.bot = WeChatPadClient(self.config["wechatpad_url"],self.config["token"]) + self.logger = logger @staticmethod async def yiri2target(message_chain: platform_message.MessageChain) -> list[dict]: content_list = [] for component in message_chain: - if isinstance(component, platform_message.At): - content_list.append({'type': 'at', 'target': component.target}) + if isinstance(component, platform_message.AtAll): + content_list.append({"type": "at", "target": "all"}) + elif isinstance(component, platform_message.At): + content_list.append({"type": "at", "target": component.target}) elif isinstance(component, platform_message.Plain): content_list.append({'type': 'text', 'content': component.text}) elif isinstance(component, platform_message.Image): @@ -73,20 +76,34 @@ class WeChatPadMessageConverter(adapter.MessageConverter): return content_list - async def target2yiri(self, message: dict, bot_account_id: str) -> platform_message.MessageChain: + async def target2yiri( + self, + message: dict, + bot_account_id: str, + ) -> platform_message.MessageChain: """外部消息转平台消息""" # 数据预处理 message_list = [] + bot_wxid = self.config['wxid'] ats_bot = False # 是否被@ content = message['content']['str'] content_no_preifx = content # 群消息则去掉前缀 is_group_message = self._is_group_message(message) if is_group_message: ats_bot = self._ats_bot(message, bot_account_id) - if '@所有人' in content: + + self.logger.info(f"ats_bot: {ats_bot}; bot_account_id: {bot_account_id}; bot_wxid: {bot_wxid}") + if "@所有人" in content: 
message_list.append(platform_message.AtAll()) - elif ats_bot: + if ats_bot: message_list.append(platform_message.At(target=bot_account_id)) + + # 解析@信息并生成At组件 + at_targets = self._extract_at_targets(message) + for target_id in at_targets: + if target_id != bot_wxid: # 避免重复添加机器人的At + message_list.append(platform_message.At(target=target_id)) + content_no_preifx, _ = self._extract_content_and_sender(content) msg_type = message['msg_type'] @@ -395,6 +412,23 @@ class WeChatPadMessageConverter(adapter.MessageConverter): finally: return ats_bot + # 提取一下at的wxid列表 + def _extract_at_targets(self, message: dict) -> list[str]: + """从消息中提取被@用户的ID列表""" + at_targets = [] + try: + # 从msg_source中解析atuserlist + msg_source = message.get('msg_source', '') or '' + if len(msg_source) > 0: + msg_source_data = ET.fromstring(msg_source) + at_user_list = msg_source_data.findtext("atuserlist") or "" + if at_user_list: + # atuserlist格式通常是逗号分隔的用户ID列表 + at_targets = [user_id.strip() for user_id in at_user_list.split(',') if user_id.strip()] + except Exception as e: + self.logger.error(f"_extract_at_targets got except: {e}") + return at_targets + # 提取一下content前面的sender_id, 和去掉前缀的内容 def _extract_content_and_sender(self, raw_content: str) -> Tuple[str, Optional[str]]: try: @@ -418,16 +452,22 @@ class WeChatPadMessageConverter(adapter.MessageConverter): class WeChatPadEventConverter(adapter.EventConverter): - def __init__(self, config: dict): - self.config = config - self.message_converter = WeChatPadMessageConverter(config) - self.logger = logging.getLogger('WeChatPadEventConverter') + def __init__(self, config: dict, logger: logging.Logger): + self.config = config + self.message_converter = WeChatPadMessageConverter(config, logger) + self.logger = logger + @staticmethod async def yiri2target(event: platform_events.MessageEvent) -> dict: pass - async def target2yiri(self, event: dict, bot_account_id: str) -> platform_events.MessageEvent: + async def target2yiri( + self, + event: dict, + 
bot_account_id: str, + ) -> platform_events.MessageEvent: + # 排除公众号以及微信团队消息 if ( event['from_user_name']['str'].startswith('gh_') @@ -503,8 +543,8 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): self.logger = logger self.quart_app = quart.Quart(__name__) - self.message_converter = WeChatPadMessageConverter(config) - self.event_converter = WeChatPadEventConverter(config) + self.message_converter = WeChatPadMessageConverter(config, ap.logger) + self.event_converter = WeChatPadEventConverter(config, ap.logger) async def ws_message(self, data): """处理接收到的消息""" @@ -539,19 +579,26 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): for msg in content_list: # 文本消息处理@ if msg['type'] == 'text' and at_targets: - at_nick_name_list = [] - for member in member_info: - if member['user_name'] in at_targets: - at_nick_name_list.append(f'@{member["nick_name"]}') - msg['content'] = f'{" ".join(at_nick_name_list)} {msg["content"]}' + if "all" in at_targets: + msg['content'] = f'@所有人 {msg["content"]}' + else: + at_nick_name_list = [] + for member in member_info: + if member["user_name"] in at_targets: + at_nick_name_list.append(f'@{member["nick_name"]}') + msg['content'] = f'{" ".join(at_nick_name_list)} {msg["content"]}' # 统一消息派发 handler_map = { 'text': lambda msg: self.bot.send_text_message( - to_wxid=target_id, message=msg['content'], ats=at_targets + to_wxid=target_id, + message=msg['content'], + ats= ["notify@all"] if "all" in at_targets else at_targets ), 'image': lambda msg: self.bot.send_image_message( - to_wxid=target_id, img_url=msg['image'], ats=at_targets + to_wxid=target_id, + img_url=msg["image"], + ats = ["notify@all"] if "all" in at_targets else at_targets ), 'WeChatEmoji': lambda msg: self.bot.send_emoji_message( to_wxid=target_id, emoji_md5=msg['emoji_md5'], emoji_size=msg['emoji_size'] diff --git a/pkg/provider/modelmgr/requester.py b/pkg/provider/modelmgr/requester.py index 0abebfa5..17697cdb 100644 --- a/pkg/provider/modelmgr/requester.py +++ 
b/pkg/provider/modelmgr/requester.py
@@ -101,18 +101,18 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta):
     async def invoke_embedding(
         self,
         model: RuntimeEmbeddingModel,
-        input_text: str,
+        input_text: list[str],
         extra_args: dict[str, typing.Any] = {},
-    ) -> list[float]:
+    ) -> list[list[float]]:
         """调用 Embedding API
         Args:
             query (core_entities.Query): 请求上下文
             model (RuntimeEmbeddingModel): 使用的模型信息
-            input_text (str): 输入文本
+            input_text (list[str]): 输入文本
             extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}.
         Returns:
-            list[float]: 返回的 embedding 向量
+            list[list[float]]: 返回的 embedding 向量
         """
         pass
diff --git a/pkg/provider/modelmgr/requesters/302aichatcmpl.py b/pkg/provider/modelmgr/requesters/302aichatcmpl.py
index bd9aaccd..40a41718 100644
--- a/pkg/provider/modelmgr/requesters/302aichatcmpl.py
+++ b/pkg/provider/modelmgr/requesters/302aichatcmpl.py
@@ -7,7 +7,7 @@ from . import chatcmpl
 class AI302ChatCompletions(chatcmpl.OpenAIChatCompletions):
-    """302 AI ChatCompletion API 请求器"""
+    """302.AI ChatCompletion API 请求器"""
     client: openai.AsyncClient
diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.py b/pkg/provider/modelmgr/requesters/chatcmpl.py
index 5dadab7d..aaaf3751 100644
--- a/pkg/provider/modelmgr/requesters/chatcmpl.py
+++ b/pkg/provider/modelmgr/requesters/chatcmpl.py
@@ -145,9 +145,9 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
     async def invoke_embedding(
         self,
         model: requester.RuntimeEmbeddingModel,
-        input_text: str,
+        input_text: list[str],
         extra_args: dict[str, typing.Any] = {},
-    ) -> list[float]:
+    ) -> list[list[float]]:
         """调用 Embedding API"""
         self.client.api_key = model.token_mgr.get_token()
@@ -163,7 +163,8 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
         try:
             resp = await self.client.embeddings.create(**args)
-            return resp.data[0].embedding
+
+            return [d.embedding for d in resp.data]
         except asyncio.TimeoutError:
             raise errors.RequesterError('请求超时')
         except openai.BadRequestError as e:
diff --git
a/pkg/provider/runners/localagent.py b/pkg/provider/runners/localagent.py index 7d5e04c5..1d3e88ac 100644 --- a/pkg/provider/runners/localagent.py +++ b/pkg/provider/runners/localagent.py @@ -1,13 +1,28 @@ from __future__ import annotations import json +import copy import typing - from .. import runner from ...core import entities as core_entities from .. import entities as llm_entities +rag_combined_prompt_template = """ +The following are relevant context entries retrieved from the knowledge base. +Please use them to answer the user's message. +Respond in the same language as the user's input. + + +{rag_context} + + + +{user_message} + +""" + + @runner.runner_class('local-agent') class LocalAgentRunner(runner.RequestRunner): """本地Agent请求运行器""" @@ -16,7 +31,54 @@ class LocalAgentRunner(runner.RequestRunner): """运行请求""" pending_tool_calls = [] - req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message] + kb_uuid = query.pipeline_config['ai']['local-agent']['knowledge-base'] + + if kb_uuid == '__none__': + kb_uuid = None + + user_message = copy.deepcopy(query.user_message) + + user_message_text = '' + + if isinstance(user_message.content, str): + user_message_text = user_message.content + elif isinstance(user_message.content, list): + for ce in user_message.content: + if ce.type == 'text': + user_message_text += ce.text + break + + if kb_uuid and user_message_text: + # only support text for now + kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid) + + if not kb: + self.ap.logger.warning(f'Knowledge base {kb_uuid} not found') + raise ValueError(f'Knowledge base {kb_uuid} not found') + + result = await kb.retrieve(user_message_text) + + final_user_message_text = '' + + if result: + rag_context = '\n\n'.join( + f'[{i + 1}] {entry.metadata.get("text", "")}' for i, entry in enumerate(result) + ) + final_user_message_text = rag_combined_prompt_template.format( + rag_context=rag_context, user_message=user_message_text + ) + + 
else: + final_user_message_text = user_message_text + + self.ap.logger.debug(f'Final user message text: {final_user_message_text}') + + for ce in user_message.content: + if ce.type == 'text': + ce.text = final_user_message_text + break + + req_messages = query.prompt.messages.copy() + query.messages.copy() + [user_message] # 首次请求 msg = await query.use_llm_model.requester.invoke_llm( diff --git a/pkg/rag/knowledge/kbmgr.py b/pkg/rag/knowledge/kbmgr.py new file mode 100644 index 00000000..a9e7e57a --- /dev/null +++ b/pkg/rag/knowledge/kbmgr.py @@ -0,0 +1,212 @@ +from __future__ import annotations +import traceback +import uuid +from .services import parser, chunker +from pkg.core import app +from pkg.rag.knowledge.services.embedder import Embedder +from pkg.rag.knowledge.services.retriever import Retriever +import sqlalchemy +from ...entity.persistence import rag as persistence_rag +from pkg.core import taskmgr +from ...entity.rag import retriever as retriever_entities + + +class RuntimeKnowledgeBase: + ap: app.Application + + knowledge_base_entity: persistence_rag.KnowledgeBase + + parser: parser.FileParser + + chunker: chunker.Chunker + + embedder: Embedder + + retriever: Retriever + + def __init__(self, ap: app.Application, knowledge_base_entity: persistence_rag.KnowledgeBase): + self.ap = ap + self.knowledge_base_entity = knowledge_base_entity + self.parser = parser.FileParser(ap=self.ap) + self.chunker = chunker.Chunker(ap=self.ap) + self.embedder = Embedder(ap=self.ap) + self.retriever = Retriever(ap=self.ap) + # 传递kb_id给retriever + self.retriever.kb_id = knowledge_base_entity.uuid + + async def initialize(self): + pass + + async def _store_file_task(self, file: persistence_rag.File, task_context: taskmgr.TaskContext): + try: + # set file status to processing + await self.ap.persistence_mgr.execute_async( + sqlalchemy.update(persistence_rag.File) + .where(persistence_rag.File.uuid == file.uuid) + .values(status='processing') + ) + + 
task_context.set_current_action('Parsing file') + # parse file + text = await self.parser.parse(file.file_name, file.extension) + if not text: + raise Exception(f'No text extracted from file {file.file_name}') + + task_context.set_current_action('Chunking file') + # chunk file + chunks_texts = await self.chunker.chunk(text) + if not chunks_texts: + raise Exception(f'No chunks extracted from file {file.file_name}') + + task_context.set_current_action('Embedding chunks') + + embedding_model = await self.ap.model_mgr.get_embedding_model_by_uuid( + self.knowledge_base_entity.embedding_model_uuid + ) + # embed chunks + await self.embedder.embed_and_store( + kb_id=self.knowledge_base_entity.uuid, + file_id=file.uuid, + chunks=chunks_texts, + embedding_model=embedding_model, + ) + + # set file status to completed + await self.ap.persistence_mgr.execute_async( + sqlalchemy.update(persistence_rag.File) + .where(persistence_rag.File.uuid == file.uuid) + .values(status='completed') + ) + + except Exception as e: + self.ap.logger.error(f'Error storing file {file.uuid}: {e}') + traceback.print_exc() + # set file status to failed + await self.ap.persistence_mgr.execute_async( + sqlalchemy.update(persistence_rag.File) + .where(persistence_rag.File.uuid == file.uuid) + .values(status='failed') + ) + + raise + + async def store_file(self, file_id: str) -> str: + # pre checking + if not await self.ap.storage_mgr.storage_provider.exists(file_id): + raise Exception(f'File {file_id} not found') + + file_uuid = str(uuid.uuid4()) + kb_id = self.knowledge_base_entity.uuid + file_name = file_id + extension = file_name.split('.')[-1] + + file_obj_data = { + 'uuid': file_uuid, + 'kb_id': kb_id, + 'file_name': file_name, + 'extension': extension, + 'status': 'pending', + } + + file_obj = persistence_rag.File(**file_obj_data) + + await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.File).values(file_obj_data)) + + # run background task asynchronously + ctx = 
taskmgr.TaskContext.new() + wrapper = self.ap.task_mgr.create_user_task( + self._store_file_task(file_obj, task_context=ctx), + kind='knowledge-operation', + name=f'knowledge-store-file-{file_id}', + label=f'Store file {file_id}', + context=ctx, + ) + return wrapper.id + + async def retrieve(self, query: str) -> list[retriever_entities.RetrieveResultEntry]: + embedding_model = await self.ap.model_mgr.get_embedding_model_by_uuid( + self.knowledge_base_entity.embedding_model_uuid + ) + return await self.retriever.retrieve(self.knowledge_base_entity.uuid, query, embedding_model) + + async def delete_file(self, file_id: str): + # delete vector + await self.ap.vector_db_mgr.vector_db.delete_by_file_id(self.knowledge_base_entity.uuid, file_id) + + # delete chunk + await self.ap.persistence_mgr.execute_async( + sqlalchemy.delete(persistence_rag.Chunk).where(persistence_rag.Chunk.file_id == file_id) + ) + + await self.ap.persistence_mgr.execute_async( + sqlalchemy.delete(persistence_rag.File).where(persistence_rag.File.uuid == file_id) + ) + + async def dispose(self): + await self.ap.vector_db_mgr.vector_db.delete_collection(self.knowledge_base_entity.uuid) + + +class RAGManager: + ap: app.Application + + knowledge_bases: list[RuntimeKnowledgeBase] + + def __init__(self, ap: app.Application): + self.ap = ap + self.knowledge_bases = [] + + async def initialize(self): + await self.load_knowledge_bases_from_db() + + async def load_knowledge_bases_from_db(self): + self.ap.logger.info('Loading knowledge bases from db...') + + self.knowledge_bases = [] + + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_rag.KnowledgeBase)) + + knowledge_bases = result.all() + + for knowledge_base in knowledge_bases: + try: + await self.load_knowledge_base(knowledge_base) + except Exception as e: + self.ap.logger.error( + f'Error loading knowledge base {knowledge_base.uuid}: {e}\n{traceback.format_exc()}' + ) + + async def load_knowledge_base( + self, + 
knowledge_base_entity: persistence_rag.KnowledgeBase | sqlalchemy.Row | dict, + ) -> RuntimeKnowledgeBase: + if isinstance(knowledge_base_entity, sqlalchemy.Row): + knowledge_base_entity = persistence_rag.KnowledgeBase(**knowledge_base_entity._mapping) + elif isinstance(knowledge_base_entity, dict): + knowledge_base_entity = persistence_rag.KnowledgeBase(**knowledge_base_entity) + + runtime_knowledge_base = RuntimeKnowledgeBase(ap=self.ap, knowledge_base_entity=knowledge_base_entity) + + await runtime_knowledge_base.initialize() + + self.knowledge_bases.append(runtime_knowledge_base) + + return runtime_knowledge_base + + async def get_knowledge_base_by_uuid(self, kb_uuid: str) -> RuntimeKnowledgeBase | None: + for kb in self.knowledge_bases: + if kb.knowledge_base_entity.uuid == kb_uuid: + return kb + return None + + async def remove_knowledge_base_from_runtime(self, kb_uuid: str): + for kb in self.knowledge_bases: + if kb.knowledge_base_entity.uuid == kb_uuid: + self.knowledge_bases.remove(kb) + return + + async def delete_knowledge_base(self, kb_uuid: str): + for kb in self.knowledge_bases: + if kb.knowledge_base_entity.uuid == kb_uuid: + await kb.dispose() + self.knowledge_bases.remove(kb) + return diff --git a/pkg/rag/knowledge/mgr.py b/pkg/rag/knowledge/mgr.py deleted file mode 100644 index 28b8d666..00000000 --- a/pkg/rag/knowledge/mgr.py +++ /dev/null @@ -1,408 +0,0 @@ -from __future__ import annotations -import os -import asyncio -import traceback -import uuid -from pkg.rag.knowledge.services.parser import FileParser -from pkg.rag.knowledge.services.chunker import Chunker -from pkg.rag.knowledge.services.database import ( - KnowledgeBase, - File, - Chunk, -) -from pkg.core import app -from pkg.rag.knowledge.services.embedder import Embedder -from pkg.rag.knowledge.services.retriever import Retriever -from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager -from pkg.core import taskmgr -from ...entity.persistence import rag as 
persistence_rag -import sqlalchemy - - -class RuntimeKnowledgeBase: - ap: app.Application - - knowledge_base_entity: persistence_rag.KnowledgeBase - - chroma_manager: ChromaIndexManager - - parser: FileParser - - chunker: Chunker - - embedder: Embedder - - retriever: Retriever - - def __init__(self, ap: app.Application, knowledge_base_entity: persistence_rag.KnowledgeBase): - self.ap = ap - self.knowledge_base_entity = knowledge_base_entity - self.chroma_manager = ChromaIndexManager(ap=self.ap) - self.parser = FileParser(ap=self.ap) - self.chunker = Chunker(ap=self.ap) - self.embedder = Embedder(ap=self.ap, chroma_manager=self.chroma_manager) - self.retriever = Retriever(ap=self.ap, chroma_manager=self.chroma_manager) - - async def initialize(self): - pass - - async def _store_file_task(self, file: persistence_rag.File, task_context: taskmgr.TaskContext): - try: - # set file status to processing - await self.ap.persistence_mgr.execute_async( - sqlalchemy.update(persistence_rag.File) - .where(persistence_rag.File.uuid == file.uuid) - .values(status='processing') - ) - - task_context.set_current_action('Parsing file') - # parse file - text = await self.parser.parse(file.file_name, file.extension) - if not text: - raise Exception(f'No text extracted from file {file.file_name}') - - task_context.set_current_action('Chunking file') - # chunk file - chunks_texts = await self.chunker.chunk(text) - if not chunks_texts: - raise Exception(f'No chunks extracted from file {file.file_name}') - - task_context.set_current_action('Embedding chunks') - # embed chunks - await self.embedder.embed_and_store( - file_id=file.uuid, chunks=chunks_texts, embedding_model=self.knowledge_base_entity.embedding_model_uuid - ) - - # set file status to completed - await self.ap.persistence_mgr.execute_async( - sqlalchemy.update(persistence_rag.File) - .where(persistence_rag.File.uuid == file.uuid) - .values(status='completed') - ) - - except Exception as e: - self.ap.logger.error(f'Error storing 
file {file.file_id}: {e}') - # set file status to failed - await self.ap.persistence_mgr.execute_async( - sqlalchemy.update(persistence_rag.File) - .where(persistence_rag.File.uuid == file.uuid) - .values(status='failed') - ) - - raise - - async def store_file(self, file_id: str) -> int: - # pre checking - if not await self.ap.storage_mgr.storage_provider.exists(file_id): - raise Exception(f'File {file_id} not found') - - file_uuid = str(uuid.uuid4()) - kb_id = self.knowledge_base_entity.uuid - file_name = file_id - extension = os.path.splitext(file_id)[1].lstrip('.') - - file = persistence_rag.File( - uuid=file_uuid, - kb_id=kb_id, - file_name=file_name, - extension=extension, - status='pending', - ) - - await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.File).values(**file.to_dict())) - - # run background task asynchronously - ctx = taskmgr.TaskContext.new() - wrapper = self.ap.task_mgr.create_user_task( - self._store_file_task(file, task_context=ctx), - kind='knowledge-operation', - name=f'knowledge-store-file-{file_id}', - label=f'Store file {file_id}', - context=ctx, - ) - return wrapper.id - - async def dispose(self): - pass - - -class RAGManager: - ap: app.Application - - knowledge_bases: list[RuntimeKnowledgeBase] - - def __init__(self, ap: app.Application): - self.ap = ap - self.knowledge_bases = [] - - async def initialize(self): - pass - - async def load_knowledge_bases_from_db(self): - self.ap.logger.info('Loading knowledge bases from db...') - - self.knowledge_bases = [] - - result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_rag.KnowledgeBase)) - - knowledge_bases = result.all() - - for knowledge_base in knowledge_bases: - try: - await self.load_knowledge_base(knowledge_base) - except Exception as e: - self.ap.logger.error( - f'Error loading knowledge base {knowledge_base.uuid}: {e}\n{traceback.format_exc()}' - ) - - async def load_knowledge_base( - self, - knowledge_base_entity: 
persistence_rag.KnowledgeBase | sqlalchemy.Row | dict, - ) -> RuntimeKnowledgeBase: - if isinstance(knowledge_base_entity, sqlalchemy.Row): - knowledge_base_entity = persistence_rag.KnowledgeBase(**knowledge_base_entity._mapping) - elif isinstance(knowledge_base_entity, dict): - knowledge_base_entity = persistence_rag.KnowledgeBase(**knowledge_base_entity) - - runtime_knowledge_base = RuntimeKnowledgeBase(ap=self.ap, knowledge_base_entity=knowledge_base_entity) - - await runtime_knowledge_base.initialize() - - self.knowledge_bases.append(runtime_knowledge_base) - - return runtime_knowledge_base - - async def get_knowledge_base_by_uuid(self, kb_uuid: str) -> RuntimeKnowledgeBase | None: - for kb in self.knowledge_bases: - if kb.knowledge_base_entity.uuid == kb_uuid: - return kb - return None - - async def remove_knowledge_base(self, kb_uuid: str): - for kb in self.knowledge_bases: - if kb.knowledge_base_entity.uuid == kb_uuid: - await kb.dispose() - self.knowledge_bases.remove(kb) - return - - async def store_data(self, file_path: str, kb_id: str, file_type: str, file_id: str = None): - """ - Parses, chunks, embeds, and stores data from a given file into the RAG system. - Associates the file with a knowledge base using kb_id in the File table. - """ - self.ap.logger.info(f'Starting data storage process for file: {file_path}') - session = SessionLocal() - file_obj = None - - try: - kb = session.query(KnowledgeBase).filter_by(id=kb_id).first() - if not kb: - self.ap.logger.info(f'Knowledge Base "{kb_id}" does not exist. ') - return - # get embedding model - embedding_model = await self.ap.model_mgr.get_embedding_model_by_uuid(kb.embedding_model_uuid) - file_name = os.path.basename(file_path) - text = await self.parser.parse(file_path) - if not text: - self.ap.logger.warning(f'No text extracted from file {file_path}. 
') - return - - chunks_texts = await self.chunker.chunk(text) - self.ap.logger.info(f"Chunked file '{file_name}' into {len(chunks_texts)} chunks.") - await self.embedder.embed_and_store(file_id=file_id, chunks=chunks_texts, embedding_model=embedding_model) - self.ap.logger.info(f'Data storage process completed for file: {file_path}') - - except Exception as e: - session.rollback() - self.ap.logger.error(f'Error in store_data for file {file_path}: {str(e)}', exc_info=True) - raise - finally: - if file_id: - file_obj = session.query(File).filter_by(id=file_id).first() - if file_obj: - file_obj.status = 1 - session.close() - - async def retrieve_data(self, query: str): - """ - Retrieves relevant data chunks based on a given query using the configured retriever. - """ - self.ap.logger.info(f"Starting data retrieval process for query: '{query}'") - try: - retrieved_chunks = await self.retriever.retrieve(query) - self.ap.logger.info(f'Successfully retrieved {len(retrieved_chunks)} chunks for query.') - return retrieved_chunks - except Exception as e: - self.ap.logger.error(f"Error in retrieve_data for query '{query}': {str(e)}", exc_info=True) - return [] - - async def delete_data_by_file_id(self, file_id: str): - """ - Deletes all data associated with a specific file ID, including its chunks and vectors, - and the file record itself. 
- """ - self.ap.logger.info(f'Starting data deletion process for file_id: {file_id}') - session = SessionLocal() - try: - # delete vectors - await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_id) - self.ap.logger.info(f'Deleted embeddings from ChromaDB for file_id: {file_id}') - - chunks_to_delete = session.query(Chunk).filter_by(file_id=file_id).all() - for chunk in chunks_to_delete: - session.delete(chunk) - self.ap.logger.info(f'Deleted {len(chunks_to_delete)} chunk records for file_id: {file_id}') - - file_to_delete = session.query(File).filter_by(id=file_id).first() - if file_to_delete: - session.delete(file_to_delete) - try: - await self.ap.storage_mgr.storage_provider.delete(file_id) - except Exception as e: - self.ap.logger.error( - f'Error deleting file from storage for file_id {file_id}: {str(e)}', - exc_info=True, - ) - self.ap.logger.info(f'Deleted file record for file_id: {file_id}') - else: - self.ap.logger.warning( - f'File with ID {file_id} not found in database. Skipping deletion of file record.' - ) - session.commit() - self.ap.logger.info(f'Successfully completed data deletion for file_id: {file_id}') - except Exception as e: - session.rollback() - self.ap.logger.error(f'Error deleting data for file_id {file_id}: {str(e)}', exc_info=True) - raise - finally: - session.close() - - async def delete_kb_by_id(self, kb_id: str): - """ - Deletes a knowledge base and all associated files, chunks, and vectors. - This involves querying for associated files and then deleting them. 
- """ - self.ap.logger.info(f'Starting deletion of knowledge base with ID: {kb_id}') - session = SessionLocal() - - try: - kb_to_delete = session.query(KnowledgeBase).filter_by(id=kb_id).first() - if not kb_to_delete: - self.ap.logger.warning(f'Knowledge Base with ID {kb_id} not found.') - return - - files_to_delete = session.query(File).filter_by(kb_id=kb_id).all() - - session.close() - - for file_obj in files_to_delete: - try: - await self.delete_data_by_file_id(file_obj.id) - except Exception as file_del_e: - self.ap.logger.error(f'Failed to delete file ID {file_obj.id} during KB deletion: {file_del_e}') - - session = SessionLocal() - try: - kb_final_delete = session.query(KnowledgeBase).filter_by(id=kb_id).first() - if kb_final_delete: - session.delete(kb_final_delete) - session.commit() - self.ap.logger.info(f'Successfully deleted knowledge base with ID: {kb_id}') - else: - self.ap.logger.warning( - f'Knowledge Base with ID {kb_id} not found after file deletion, skipping KB deletion.' 
- ) - except Exception as kb_del_e: - session.rollback() - self.ap.logger.error( - f'Error deleting KnowledgeBase record for ID {kb_id}: {kb_del_e}', - exc_info=True, - ) - raise - finally: - session.close() - - except Exception as e: - # 如果在最初获取 KB 或文件列表时出错 - if session.is_active: - session.rollback() - self.ap.logger.error( - f'Error during overall knowledge base deletion for ID {kb_id}: {str(e)}', - exc_info=True, - ) - raise - finally: - if session.is_active: - session.close() - - async def get_file_content_by_file_id(self, file_id: str) -> str: - file_bytes = await self.ap.storage_mgr.storage_provider.load(file_id) - - _, ext = os.path.splitext(file_id.lower()) - ext = ext.lstrip('.') - - try: - text = file_bytes.decode('utf-8') - except UnicodeDecodeError: - return '[非文本文件或编码无法识别]' - - if ext in ['txt', 'md', 'csv', 'log', 'py', 'html']: - return text - else: - return f'[未知类型: .{ext}]' - - async def relate_file_id_with_kb(self, knowledge_base_uuid: str, file_id: str) -> None: - """ - Associates a file with a knowledge base by updating the kb_id in the File table. 
- """ - self.ap.logger.info(f'Associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}') - session = SessionLocal() - try: - # 查询知识库是否存在 - kb = session.query(KnowledgeBase).filter_by(id=knowledge_base_uuid).first() - if not kb: - self.ap.logger.error(f'Knowledge Base with UUID {knowledge_base_uuid} not found.') - return - - if not await self.ap.storage_mgr.storage_provider.exists(file_id): - self.ap.logger.error(f'File with ID {file_id} does not exist.') - return - self.ap.logger.info(f'File with ID {file_id} exists, proceeding with association.') - # add new file record - file_to_update = File( - id=file_id, - kb_id=kb.id, - file_name=file_id, - path=os.path.join('data', 'storage', file_id), - file_type=os.path.splitext(file_id)[1].lstrip('.'), - status=0, - ) - session.add(file_to_update) - session.commit() - self.ap.logger.info( - f'Successfully associated file ID {file_id} with knowledge base UUID {knowledge_base_uuid}' - ) - except Exception as e: - session.rollback() - self.ap.logger.error( - f'Error associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}: {str(e)}', - exc_info=True, - ) - finally: - # 进行文件解析 - try: - await self.store_data( - file_path=os.path.join('data', 'storage', file_id), - kb_id=knowledge_base_uuid, - file_type=os.path.splitext(file_id)[1].lstrip('.'), - file_id=file_id, - ) - except Exception: - # 如果存储数据时出错,更新文件状态为失败 - file_obj = session.query(File).filter_by(id=file_id).first() - if file_obj: - file_obj.status = 2 - session.commit() - self.ap.logger.error(f'Error storing data for file ID {file_id}', exc_info=True) - - session.close() diff --git a/pkg/rag/knowledge/services/base_service.py b/pkg/rag/knowledge/services/base_service.py index 4ff1ce39..0f71a508 100644 --- a/pkg/rag/knowledge/services/base_service.py +++ b/pkg/rag/knowledge/services/base_service.py @@ -1,26 +1,15 @@ # 封装异步操作 import asyncio -import logging -from pkg.rag.knowledge.services.database import SessionLocal + class 
BaseService: def __init__(self): - self.logger = logging.getLogger(self.__class__.__name__) - self.db_session_factory = SessionLocal + pass async def _run_sync(self, func, *args, **kwargs): """ 在单独的线程中运行同步函数。 如果第一个参数是 session,则在 to_thread 中获取新的 session。 """ - - if getattr(func, '__name__', '').startswith('_db_'): - session = await asyncio.to_thread(self.db_session_factory) - try: - result = await asyncio.to_thread(func, session, *args, **kwargs) - return result - finally: - session.close() - else: - # 否则,直接运行同步函数 - return await asyncio.to_thread(func, *args, **kwargs) \ No newline at end of file + + return await asyncio.to_thread(func, *args, **kwargs) diff --git a/pkg/rag/knowledge/services/chroma_manager.py b/pkg/rag/knowledge/services/chroma_manager.py deleted file mode 100644 index 17757b47..00000000 --- a/pkg/rag/knowledge/services/chroma_manager.py +++ /dev/null @@ -1,67 +0,0 @@ -import numpy as np -import logging -from chromadb import PersistentClient -from pkg.core import app - -logger = logging.getLogger(__name__) - - -class ChromaIndexManager: - def __init__(self, ap: app.Application, collection_name: str = 'default_collection'): - self.ap = ap - chroma_data_path = './data/chroma' - self.client = PersistentClient(path=chroma_data_path) - self._collection_name = collection_name - self._collection = None - - self.ap.logger.info(f'ChromaIndexManager initialized. 
Collection name: {self._collection_name}') - - @property - def collection(self): - if self._collection is None: - self._collection = self.client.get_or_create_collection(name=self._collection_name) - self.ap.logger.info(f"Chroma collection '{self._collection_name}' accessed/created.") - return self._collection - - def add_embeddings_sync( - self, file_ids: list[int], chunk_ids: list[int], embeddings: np.ndarray, documents: list[str] - ): - if ( - embeddings.shape[0] != len(chunk_ids) - or embeddings.shape[0] != len(file_ids) - or embeddings.shape[0] != len(documents) - ): - raise ValueError('Embedding, file_id, chunk_id, and document count mismatch.') - - chroma_ids = [f'{file_id}_{chunk_id}' for file_id, chunk_id in zip(file_ids, chunk_ids)] - metadatas = [{'file_id': fid, 'chunk_id': cid} for fid, cid in zip(file_ids, chunk_ids)] - - self.logger.debug(f"Adding {len(embeddings)} embeddings to Chroma collection '{self._collection_name}'.") - self.collection.add(embeddings=embeddings.tolist(), ids=chroma_ids, metadatas=metadatas, documents=documents) - self.logger.info(f"Added {len(embeddings)} embeddings to Chroma collection '{self._collection_name}'.") - - def search_sync(self, query_embedding: np.ndarray, k: int = 5): - """ - Searches the Chroma collection for the top-k nearest neighbors. - Args: - query_embedding: A numpy array of the query embedding. - k: The number of results to return. - Returns: - A dictionary containing query results from Chroma. - """ - self.logger.debug(f"Searching Chroma collection '{self._collection_name}' with k={k}.") - results = self.collection.query( - query_embeddings=query_embedding.tolist(), - n_results=k, - # REMOVE 'ids' from the include list. It's returned by default. 
- include=['metadatas', 'distances', 'documents'], - ) - self.logger.debug(f'Chroma search returned {len(results.get("ids", [[]])[0])} results.') - return results - - def delete_by_file_id_sync(self, file_id: int): - self.logger.info( - f"Deleting embeddings for file_id: {file_id} from Chroma collection '{self._collection_name}'." - ) - self.collection.delete(where={'file_id': file_id}) - self.logger.info(f'Deleted embeddings for file_id: {file_id} from Chroma.') diff --git a/pkg/rag/knowledge/services/chunker.py b/pkg/rag/knowledge/services/chunker.py index 93b10a55..f169d5f1 100644 --- a/pkg/rag/knowledge/services/chunker.py +++ b/pkg/rag/knowledge/services/chunker.py @@ -1,24 +1,22 @@ -# services/chunker.py -import logging +from __future__ import annotations + +import json from typing import List -from pkg.rag.knowledge.services.base_service import BaseService # Assuming BaseService provides _run_sync +from pkg.rag.knowledge.services import base_service from pkg.core import app -logger = logging.getLogger(__name__) - -class Chunker(BaseService): +class Chunker(base_service.BaseService): """ A class for splitting long texts into smaller, overlapping chunks. """ def __init__(self, ap: app.Application, chunk_size: int = 500, chunk_overlap: int = 50): - super().__init__(ap) # Initialize BaseService self.ap = ap self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap if self.chunk_overlap >= self.chunk_size: - self.logger.warning( + self.ap.logger.warning( 'Chunk overlap is greater than or equal to chunk size. This may lead to empty or malformed chunks.' 
) @@ -61,4 +59,5 @@ class Chunker(BaseService): # Run the synchronous splitting logic in a separate thread chunks = await self._run_sync(self._split_text_sync, text) self.ap.logger.info(f'Text chunked into {len(chunks)} pieces.') + self.ap.logger.debug(f'Chunks: {json.dumps(chunks, indent=4, ensure_ascii=False)}') return chunks diff --git a/pkg/rag/knowledge/services/database.py b/pkg/rag/knowledge/services/database.py deleted file mode 100644 index bc5caa10..00000000 --- a/pkg/rag/knowledge/services/database.py +++ /dev/null @@ -1,23 +0,0 @@ -# 全部迁移过去 - -from pkg.entity.persistence.rag import ( - create_db_and_tables, - SessionLocal, - Base, - engine, - KnowledgeBase, - File, - Chunk, - Vector, -) - -__all__ = [ - "create_db_and_tables", - "SessionLocal", - "Base", - "engine", - "KnowledgeBase", - "File", - "Chunk", - "Vector", -] diff --git a/pkg/rag/knowledge/services/embedder.py b/pkg/rag/knowledge/services/embedder.py index 34165eab..a0ae3d49 100644 --- a/pkg/rag/knowledge/services/embedder.py +++ b/pkg/rag/knowledge/services/embedder.py @@ -1,89 +1,47 @@ from __future__ import annotations -import asyncio -import logging -import numpy as np +import uuid from typing import List -from sqlalchemy.orm import Session from pkg.rag.knowledge.services.base_service import BaseService -from pkg.rag.knowledge.services.database import Chunk, SessionLocal -from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager +from ....entity.persistence import rag as persistence_rag from ....core import app from ....provider.modelmgr.requester import RuntimeEmbeddingModel +import sqlalchemy class Embedder(BaseService): - def __init__(self, ap: app.Application, chroma_manager: ChromaIndexManager = None) -> None: + def __init__(self, ap: app.Application) -> None: super().__init__() - self.logger = logging.getLogger(self.__class__.__name__) - self.chroma_manager = chroma_manager self.ap = ap - def _db_save_chunks_sync(self, session: Session, file_id: int, chunks_texts: 
List[str]): - """ - Saves chunks to the relational database and returns the created Chunk objects. - This function assumes it's called within a context where the session - will be committed/rolled back and closed by the caller. - """ - self.logger.debug(f'Saving {len(chunks_texts)} chunks for file_id {file_id} to DB (sync).') - chunk_objects = [] - for text in chunks_texts: - chunk = Chunk(file_id=file_id, text=text) - session.add(chunk) - chunk_objects.append(chunk) - session.flush() # This populates the .id attribute for each new chunk object - self.logger.debug(f'Successfully added {len(chunk_objects)} chunk entries to DB.') - return chunk_objects - async def embed_and_store( - self, file_id: int, chunks: List[str], embedding_model: RuntimeEmbeddingModel - ) -> List[Chunk]: - if not embedding_model: - raise RuntimeError('Embedding model not loaded. Please check Embedder initialization.') + self, kb_id: str, file_id: str, chunks: List[str], embedding_model: RuntimeEmbeddingModel + ) -> list[persistence_rag.Chunk]: + # save chunk to db + chunk_entities: list[persistence_rag.Chunk] = [] + chunk_ids: list[str] = [] - session = SessionLocal() # Start a session that will live for the whole operation - chunk_objects = [] - try: - # 1. Save chunks to the relational database first to get their IDs - # We call _db_save_chunks_sync directly without _run_sync's session management - # because we manage the session here across multiple async calls. - chunk_objects = await asyncio.to_thread(self._db_save_chunks_sync, session, file_id, chunks) - session.commit() # Commit chunks to make their IDs permanent and accessible + for chunk_text in chunks: + chunk_uuid = str(uuid.uuid4()) + chunk_ids.append(chunk_uuid) + chunk_entity = persistence_rag.Chunk(uuid=chunk_uuid, file_id=file_id, text=chunk_text) + chunk_entities.append(chunk_entity) - if not chunk_objects: - self.logger.warning( - f'No chunk objects created for file_id {file_id}. Skipping embedding and Chroma storage.' 
- ) - return [] + chunk_dicts = [ + self.ap.persistence_mgr.serialize_model(persistence_rag.Chunk, chunk) for chunk in chunk_entities + ] - # get the embeddings for the chunks - embeddings: list[list[float]] = [] + await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.Chunk).values(chunk_dicts)) - for chunk in chunks: - result = await embedding_model.requester.invoke_embedding( - model=embedding_model, - input_text=chunk, - ) - embeddings.append(result) + # get embeddings + embeddings_list: list[list[float]] = await embedding_model.requester.invoke_embedding( + model=embedding_model, + input_text=chunks, + extra_args={}, # TODO: add extra args + ) - embeddings_np = np.array(embeddings, dtype=np.float32) + # save embeddings to vdb + await self.ap.vector_db_mgr.vector_db.add_embeddings(kb_id, chunk_ids, embeddings_list, chunk_dicts) - self.logger.info('Saving embeddings to Chroma...') - chunk_ids = [c.id for c in chunk_objects] - file_ids_for_chroma = [file_id] * len(chunk_ids) + self.ap.logger.info(f'Successfully saved {len(chunk_entities)} embeddings to Knowledge Base.') - await self._run_sync( # Use _run_sync for the Chroma operation, as it's a sync call - self.chroma_manager.add_embeddings_sync, - file_ids_for_chroma, - chunk_ids, - embeddings_np, - chunks, # Pass original chunks texts for documents - ) - self.logger.info(f'Successfully saved {len(chunk_objects)} embeddings to Chroma.') - return chunk_objects - - except Exception as e: - session.rollback() # Rollback on any error - self.logger.error(f'Failed to process and store data for file_id {file_id}: {e}', exc_info=True) - raise # Re-raise the exception to propagate it - finally: - session.close() # Ensure the session is always closed + return chunk_entities diff --git a/pkg/rag/knowledge/services/parser.py b/pkg/rag/knowledge/services/parser.py index 91b4f9ff..004dbdaa 100644 --- a/pkg/rag/knowledge/services/parser.py +++ b/pkg/rag/knowledge/services/parser.py @@ -1,21 +1,16 @@ 
+from __future__ import annotations + import PyPDF2 import io from docx import Document -import pandas as pd import chardet from typing import Union, Callable, Any -import logging import markdown from bs4 import BeautifulSoup -import ebooklib -from ebooklib import epub import re import asyncio # Import asyncio for async operations from pkg.core import app -# Configure logging -logger = logging.getLogger(__name__) - class FileParser: """ @@ -144,45 +139,45 @@ class FileParser: self.ap.logger.warning(f'Direct .doc parsing is not supported for {file_name}. Please convert to .docx first.') raise NotImplementedError('Direct .doc parsing not supported. Please convert to .docx first.') - async def _parse_xlsx(self, file_name: str) -> str: - """Parses an XLSX file, returning text from all sheets.""" - self.ap.logger.info(f'Parsing XLSX file: {file_name}') + # async def _parse_xlsx(self, file_name: str) -> str: + # """Parses an XLSX file, returning text from all sheets.""" + # self.ap.logger.info(f'Parsing XLSX file: {file_name}') - xlsx_bytes = await self.ap.storage_mgr.storage_provider.load(file_name) + # xlsx_bytes = await self.ap.storage_mgr.storage_provider.load(file_name) - def _parse_xlsx_sync(): - excel_file = pd.ExcelFile(io.BytesIO(xlsx_bytes)) - all_sheet_content = [] - for sheet_name in excel_file.sheet_names: - df = pd.read_excel(io.BytesIO(xlsx_bytes), sheet_name=sheet_name) - sheet_text = f'--- Sheet: {sheet_name} ---\n{df.to_string(index=False)}\n' - all_sheet_content.append(sheet_text) - return '\n'.join(all_sheet_content) + # def _parse_xlsx_sync(): + # excel_file = pd.ExcelFile(io.BytesIO(xlsx_bytes)) + # all_sheet_content = [] + # for sheet_name in excel_file.sheet_names: + # df = pd.read_excel(io.BytesIO(xlsx_bytes), sheet_name=sheet_name) + # sheet_text = f'--- Sheet: {sheet_name} ---\n{df.to_string(index=False)}\n' + # all_sheet_content.append(sheet_text) + # return '\n'.join(all_sheet_content) - return await self._run_sync(_parse_xlsx_sync) + # 
return await self._run_sync(_parse_xlsx_sync) - async def _parse_csv(self, file_name: str) -> str: - """Parses a CSV file and returns its content as a string.""" - self.ap.logger.info(f'Parsing CSV file: {file_name}') + # async def _parse_csv(self, file_name: str) -> str: + # """Parses a CSV file and returns its content as a string.""" + # self.ap.logger.info(f'Parsing CSV file: {file_name}') - csv_bytes = await self.ap.storage_mgr.storage_provider.load(file_name) + # csv_bytes = await self.ap.storage_mgr.storage_provider.load(file_name) - def _parse_csv_sync(): - # pd.read_csv can often detect encoding, but explicit detection is safer - # raw_data = self._read_file_content( - # file_name, mode='rb' - # ) # Note: this will need to be await outside this sync function - # _ = raw_data - # For simplicity, we'll let pandas handle encoding internally after a raw read. - # A more robust solution might pass encoding directly to pd.read_csv after detection. - detected = chardet.detect(io.BytesIO(csv_bytes)) - encoding = detected['encoding'] or 'utf-8' - df = pd.read_csv(io.BytesIO(csv_bytes), encoding=encoding) - return df.to_string(index=False) + # def _parse_csv_sync(): + # # pd.read_csv can often detect encoding, but explicit detection is safer + # # raw_data = self._read_file_content( + # # file_name, mode='rb' + # # ) # Note: this will need to be await outside this sync function + # # _ = raw_data + # # For simplicity, we'll let pandas handle encoding internally after a raw read. + # # A more robust solution might pass encoding directly to pd.read_csv after detection. 
+ # detected = chardet.detect(io.BytesIO(csv_bytes)) + # encoding = detected['encoding'] or 'utf-8' + # df = pd.read_csv(io.BytesIO(csv_bytes), encoding=encoding) + # return df.to_string(index=False) - return await self._run_sync(_parse_csv_sync) + # return await self._run_sync(_parse_csv_sync) - async def _parse_markdown(self, file_name: str) -> str: + async def _parse_md(self, file_name: str) -> str: """Parses a Markdown file, converting it to structured plain text.""" self.ap.logger.info(f'Parsing Markdown file: {file_name}') @@ -261,43 +256,6 @@ class FileParser: return await self._run_sync(_parse_html_sync) - async def _parse_epub(self, file_name: str) -> str: - """Parses an EPUB file, extracting metadata and content.""" - self.ap.logger.info(f'Parsing EPUB file: {file_name}') - - epub_bytes = await self.ap.storage_mgr.storage_provider.load(file_name) - - def _parse_epub_sync(): - book = epub.read_epub(io.BytesIO(epub_bytes)) - text_content = [] - title_meta = book.get_metadata('DC', 'title') - if title_meta: - text_content.append(f'Title: {title_meta[0][0]}') - creator_meta = book.get_metadata('DC', 'creator') - if creator_meta: - text_content.append(f'Author: {creator_meta[0][0]}') - date_meta = book.get_metadata('DC', 'date') - if date_meta: - text_content.append(f'Publish Date: {date_meta[0][0]}') - toc = book.get_toc() - if toc: - text_content.append('\n--- Table of Contents ---') - self._add_toc_items_sync(toc, text_content, level=0) # Call sync helper - text_content.append('--- End of Table of Contents ---\n') - for item in book.get_items(): - if item.get_type() == ebooklib.ITEM_DOCUMENT: - html_content = item.get_content().decode('utf-8', errors='ignore') - soup = BeautifulSoup(html_content, 'html.parser') - for junk in soup(['script', 'style', 'nav', 'header', 'footer']): - junk.decompose() - text = soup.get_text(separator='\n', strip=True) - text = re.sub(r'\n\s*\n', '\n\n', text) - if text: - text_content.append(text) - return re.sub(r'\n\s*\n', 
'\n\n', '\n'.join(text_content)).strip() - - return await self._run_sync(_parse_epub_sync) - def _add_toc_items_sync(self, toc_list: list, text_content: list, level: int): """Recursively adds TOC items to text_content (synchronous helper).""" indent = ' ' * level diff --git a/pkg/rag/knowledge/services/retriever.py b/pkg/rag/knowledge/services/retriever.py index d330747c..73c7edaa 100644 --- a/pkg/rag/knowledge/services/retriever.py +++ b/pkg/rag/knowledge/services/retriever.py @@ -1,94 +1,48 @@ from __future__ import annotations -import logging -import numpy as np # Make sure numpy is imported -from typing import List, Dict, Any -from sqlalchemy.orm import Session -from pkg.rag.knowledge.services.base_service import BaseService -from pkg.rag.knowledge.services.database import Chunk, SessionLocal -from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager + +from . import base_service from ....core import app - -logger = logging.getLogger(__name__) +from ....provider.modelmgr.requester import RuntimeEmbeddingModel +from ....entity.rag import retriever as retriever_entities -class Retriever(BaseService): - def __init__(self, ap:app.Application, chroma_manager: ChromaIndexManager): +class Retriever(base_service.BaseService): + def __init__(self, ap: app.Application): super().__init__() - self.logger = logging.getLogger(self.__class__.__name__) - self.chroma_manager = chroma_manager self.ap = ap - async def retrieve(self, query: str, k: int = 5) -> List[Dict[str, Any]]: - if not self.embedding_model: - raise RuntimeError('Retriever embedding model not loaded. 
Please check Retriever initialization.') + async def retrieve( + self, kb_id: str, query: str, embedding_model: RuntimeEmbeddingModel, k: int = 5 + ) -> list[retriever_entities.RetrieveResultEntry]: + self.ap.logger.info( + f"Retrieving for query: '{query[:10]}' with k={k} using {embedding_model.model_entity.uuid}" + ) - self.logger.info(f"Retrieving for query: '{query}' with k={k} using {self.model_name_key}") + query_embedding: list[float] = await embedding_model.requester.invoke_embedding( + model=embedding_model, + input_text=[query], + extra_args={}, # TODO: add extra args + ) - query_embedding: List[float] = await self.embedding_model.embed_query(query) - query_embedding_np = np.array([query_embedding], dtype=np.float32) - - chroma_results = await self._run_sync(self.chroma_manager.search_sync, query_embedding_np, k) + chroma_results = await self.ap.vector_db_mgr.vector_db.search(kb_id, query_embedding[0], k) # 'ids' is always returned by ChromaDB, even if not explicitly in 'include' matched_chroma_ids = chroma_results.get('ids', [[]])[0] distances = chroma_results.get('distances', [[]])[0] chroma_metadatas = chroma_results.get('metadatas', [[]])[0] - chroma_documents = chroma_results.get('documents', [[]])[0] if not matched_chroma_ids: - self.logger.info('No relevant chunks found in Chroma.') + self.ap.logger.info('No relevant chunks found in Chroma.') return [] - db_chunk_ids = [] - for metadata in chroma_metadatas: - if 'chunk_id' in metadata: - db_chunk_ids.append(metadata['chunk_id']) - else: - self.logger.warning(f"Metadata missing 'chunk_id': {metadata}. 
Skipping this entry.") + result: list[retriever_entities.RetrieveResultEntry] = [] - if not db_chunk_ids: - self.logger.warning('No valid chunk_ids extracted from Chroma results metadata.') - return [] - - self.logger.info(f'Fetching {len(db_chunk_ids)} chunk details from relational database...') - chunks_from_db = await self._run_sync( - lambda cids: self._db_get_chunks_sync( - SessionLocal(), cids - ), # Ensure SessionLocal is passed correctly for _db_get_chunks_sync - db_chunk_ids, - ) - - chunk_map = {chunk.id: chunk for chunk in chunks_from_db} - results_list: List[Dict[str, Any]] = [] - - for i, chroma_id in enumerate(matched_chroma_ids): - try: - # Ensure original_chunk_id is int for DB lookup - original_chunk_id = int(chroma_id.split('_')[-1]) - except (ValueError, IndexError): - self.logger.warning(f'Could not parse chunk_id from Chroma ID: {chroma_id}. Skipping.') - continue - - chunk_text_from_chroma = chroma_documents[i] - distance = float(distances[i]) - file_id_from_chroma = chroma_metadatas[i].get('file_id') - - chunk_from_db = chunk_map.get(original_chunk_id) - - results_list.append( - { - 'chunk_id': original_chunk_id, - 'text': chunk_from_db.text if chunk_from_db else chunk_text_from_chroma, - 'distance': distance, - 'file_id': file_id_from_chroma, - } + for i, id in enumerate(matched_chroma_ids): + entry = retriever_entities.RetrieveResultEntry( + id=id, + metadata=chroma_metadatas[i], + distance=distances[i], ) + result.append(entry) - self.logger.info(f'Retrieved {len(results_list)} chunks.') - return results_list - - def _db_get_chunks_sync(self, session: Session, chunk_ids: List[int]) -> List[Chunk]: - self.logger.debug(f'Fetching {len(chunk_ids)} chunk details from database (sync).') - chunks = session.query(Chunk).filter(Chunk.id.in_(chunk_ids)).all() - session.close() - return chunks + return result diff --git a/pkg/utils/announce.py b/pkg/utils/announce.py index 7108a08c..a6b8539a 100644 --- a/pkg/utils/announce.py +++ 
b/pkg/utils/announce.py @@ -46,7 +46,7 @@ class AnnouncementManager: async def fetch_all(self) -> list[Announcement]: """获取所有公告""" resp = requests.get( - url='https://api.github.com/repos/RockChinQ/LangBot/contents/res/announcement.json', + url='https://api.github.com/repos/langbot-app/LangBot/contents/res/announcement.json', proxies=self.ap.proxy_mgr.get_forward_proxies(), timeout=5, ) diff --git a/pkg/utils/constants.py b/pkg/utils/constants.py index e8193839..bc96b45c 100644 --- a/pkg/utils/constants.py +++ b/pkg/utils/constants.py @@ -1,6 +1,6 @@ -semantic_version = 'v4.0.8' +semantic_version = 'v4.1.0' -required_database_version = 3 +required_database_version = 4 """Tag the version of the database schema, used to check if the database needs to be migrated""" debug_mode = False diff --git a/pkg/utils/version.py b/pkg/utils/version.py index ec0683c3..b26b1e33 100644 --- a/pkg/utils/version.py +++ b/pkg/utils/version.py @@ -29,7 +29,7 @@ class VersionManager: async def get_release_list(self) -> list: """获取发行列表""" rls_list_resp = requests.get( - url='https://api.github.com/repos/RockChinQ/LangBot/releases', + url='https://api.github.com/repos/langbot-app/LangBot/releases', proxies=self.ap.proxy_mgr.get_forward_proxies(), timeout=5, ) diff --git a/pkg/vector/mgr.py b/pkg/vector/mgr.py index b2f47d61..ea198ac2 100644 --- a/pkg/vector/mgr.py +++ b/pkg/vector/mgr.py @@ -1,13 +1,18 @@ from __future__ import annotations from ..core import app +from .vdb import VectorDatabase +from .vdbs.chroma import ChromaVectorDatabase class VectorDBManager: ap: app.Application + vector_db: VectorDatabase = None def __init__(self, ap: app.Application): self.ap = ap async def initialize(self): - pass + # 初始化 Chroma 向量数据库(可扩展为多种实现) + if self.vector_db is None: + self.vector_db = ChromaVectorDatabase(self.ap) diff --git a/pkg/vector/vdb.py b/pkg/vector/vdb.py index 100ded93..73a3cc0e 100644 --- a/pkg/vector/vdb.py +++ b/pkg/vector/vdb.py @@ -1,7 +1,37 @@ from __future__ import 
annotations - import abc +from typing import Any, Dict +import numpy as np class VectorDatabase(abc.ABC): - pass + @abc.abstractmethod + async def add_embeddings( + self, + collection: str, + ids: list[str], + embeddings_list: list[list[float]], + metadatas: list[dict[str, Any]], + documents: list[str], + ) -> None: + """向指定 collection 添加向量数据。""" + pass + + @abc.abstractmethod + async def search(self, collection: str, query_embedding: np.ndarray, k: int = 5) -> Dict[str, Any]: + """在指定 collection 中检索最相似的向量。""" + pass + + @abc.abstractmethod + async def delete_by_file_id(self, collection: str, file_id: str) -> None: + """根据 file_id 删除指定 collection 中的向量。""" + pass + + @abc.abstractmethod + async def get_or_create_collection(self, collection: str): + """获取或创建 collection。""" + pass + + @abc.abstractmethod + async def delete_collection(self, collection: str): + pass diff --git a/pkg/vector/vdbs/__init__.py b/pkg/vector/vdbs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/vector/vdbs/chroma.py b/pkg/vector/vdbs/chroma.py new file mode 100644 index 00000000..41ab7d36 --- /dev/null +++ b/pkg/vector/vdbs/chroma.py @@ -0,0 +1,61 @@ +from __future__ import annotations +import asyncio +from typing import Any +from chromadb import PersistentClient +from pkg.vector.vdb import VectorDatabase +from pkg.core import app +import chromadb +import chromadb.errors + + +class ChromaVectorDatabase(VectorDatabase): + def __init__(self, ap: app.Application, base_path: str = './data/chroma'): + self.ap = ap + self.client = PersistentClient(path=base_path) + self._collections = {} + + async def get_or_create_collection(self, collection: str) -> chromadb.Collection: + if collection not in self._collections: + self._collections[collection] = await asyncio.to_thread( + self.client.get_or_create_collection, name=collection + ) + self.ap.logger.info(f"Chroma collection '{collection}' accessed/created.") + return self._collections[collection] + + async def add_embeddings( 
+ self, + collection: str, + ids: list[str], + embeddings_list: list[list[float]], + metadatas: list[dict[str, Any]], + ) -> None: + col = await self.get_or_create_collection(collection) + await asyncio.to_thread(col.add, embeddings=embeddings_list, ids=ids, metadatas=metadatas) + self.ap.logger.info(f"Added {len(ids)} embeddings to Chroma collection '{collection}'.") + + async def search(self, collection: str, query_embedding: list[float], k: int = 5) -> dict[str, Any]: + col = await self.get_or_create_collection(collection) + results = await asyncio.to_thread( + col.query, + query_embeddings=query_embedding, + n_results=k, + include=['metadatas', 'distances', 'documents'], + ) + self.ap.logger.info(f"Chroma search in '{collection}' returned {len(results.get('ids', [[]])[0])} results.") + return results + + async def delete_by_file_id(self, collection: str, file_id: str) -> None: + col = await self.get_or_create_collection(collection) + await asyncio.to_thread(col.delete, where={'file_id': file_id}) + self.ap.logger.info(f"Deleted embeddings from Chroma collection '{collection}' with file_id: {file_id}") + + async def delete_collection(self, collection: str): + if collection in self._collections: + del self._collections[collection] + + try: + await asyncio.to_thread(self.client.delete_collection, name=collection) + except chromadb.errors.NotFoundError: + self.ap.logger.warning(f"Chroma collection '{collection}' not found.") + return + self.ap.logger.info(f"Chroma collection '{collection}' deleted.") diff --git a/pyproject.toml b/pyproject.toml index 27a03a92..b12db3e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "langbot" -version = "4.0.7" +version = "4.1.0" description = "高稳定、支持扩展、多模态 - 大模型原生即时通信机器人平台" readme = "README.md" requires-python = ">=3.10.1" @@ -19,6 +19,7 @@ dependencies = [ "dashscope>=1.23.2", "dingtalk-stream>=0.24.0", "discord-py>=2.5.2", + "pynacl>=1.5.0", # Required for Discord voice support 
"gewechat-client>=0.1.5", "lark-oapi>=1.4.15", "mcp>=1.8.1", @@ -59,7 +60,6 @@ dependencies = [ "html2text>=2024.2.26", "langchain>=0.2.0", "chromadb>=0.4.24", - "sentence-transformers>=2.6.1", ] keywords = [ "bot", @@ -90,11 +90,13 @@ classifiers = [ [project.urls] Homepage = "https://langbot.app" Documentation = "https://docs.langbot.app" -Repository = "https://github.com/RockChinQ/langbot" +Repository = "https://github.com/langbot-app/LangBot" [dependency-groups] dev = [ "pre-commit>=4.2.0", + "pytest>=8.4.1", + "pytest-asyncio>=1.0.0", "ruff>=0.11.9", ] diff --git a/templates/default-pipeline-config.json b/templates/default-pipeline-config.json index 796c6356..d06e4661 100644 --- a/templates/default-pipeline-config.json +++ b/templates/default-pipeline-config.json @@ -44,7 +44,8 @@ "role": "system", "content": "You are a helpful assistant." } - ] + ], + "knowledge-base": "" }, "dify-service-api": { "base-url": "https://api.dify.ai/v1", diff --git a/templates/metadata/pipeline/ai.yaml b/templates/metadata/pipeline/ai.yaml index 90732dc8..ffbefe63 100644 --- a/templates/metadata/pipeline/ai.yaml +++ b/templates/metadata/pipeline/ai.yaml @@ -68,6 +68,16 @@ stages: zh_Hans: 除非您了解消息结构,否则请只使用 system 单提示词 type: prompt-editor required: true + - name: knowledge-base + label: + en_US: Knowledge Base + zh_Hans: 知识库 + description: + en_US: Configure the knowledge base to use for the agent, if not selected, the agent will directly use the LLM to reply + zh_Hans: 配置用于提升回复质量的知识库,若不选择,则直接使用大模型回复 + type: knowledge-base-selector + required: false + default: '' - name: dify-service-api label: en_US: Dify Service API @@ -298,3 +308,4 @@ stages: type: string required: false default: 'response' + diff --git a/web/package-lock.json b/web/package-lock.json index fcc17852..cb2b05d8 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -13,6 +13,7 @@ "@hookform/resolvers": "^5.0.1", "@radix-ui/react-checkbox": "^1.3.1", "@radix-ui/react-dialog": "^1.1.14", + 
"@radix-ui/react-dropdown-menu": "^2.1.15", "@radix-ui/react-hover-card": "^1.1.13", "@radix-ui/react-label": "^2.1.6", "@radix-ui/react-popover": "^1.1.14", @@ -1152,31 +1153,6 @@ } } }, - "node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-focus-scope": { - "version": "1.1.7", - "resolved": "https://registry.npmjs.org/@radix-ui/react-focus-scope/-/react-focus-scope-1.1.7.tgz", - "integrity": "sha512-t2ODlkXBQyn7jkl6TNaw/MtVEVvIGelJDCG41Okq/KwUsJBwQ4XVZsHAVUkK4mBv3ewiAS3PGuUWuY2BoK4ZUw==", - "license": "MIT", - "dependencies": { - "@radix-ui/react-compose-refs": "1.1.2", - "@radix-ui/react-primitive": "2.1.3", - "@radix-ui/react-use-callback-ref": "1.1.1" - }, - "peerDependencies": { - "@types/react": "*", - "@types/react-dom": "*", - "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", - "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - }, - "@types/react-dom": { - "optional": true - } - } - }, "node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-portal": { "version": "1.1.9", "resolved": "https://registry.npmjs.org/@radix-ui/react-portal/-/react-portal-1.1.9.tgz", @@ -1266,6 +1242,58 @@ } } }, + "node_modules/@radix-ui/react-dropdown-menu": { + "version": "2.1.15", + "resolved": "https://registry.npmjs.org/@radix-ui/react-dropdown-menu/-/react-dropdown-menu-2.1.15.tgz", + "integrity": "sha512-mIBnOjgwo9AH3FyKaSWoSu/dYj6VdhJ7frEPiGTeXCdUFHjl9h3mFh2wwhEtINOmYXWhdpf1rY2minFsmaNgVQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/primitive": "1.1.2", + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-context": "1.1.2", + "@radix-ui/react-id": "1.1.1", + "@radix-ui/react-menu": "2.1.15", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-controllable-state": "1.2.2" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": 
"^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-dropdown-menu/node_modules/@radix-ui/react-primitive": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-primitive/-/react-primitive-2.1.3.tgz", + "integrity": "sha512-m9gTwRkhy2lvCPe6QJp4d3G1TYEUHn/FzJUtq9MjH46an1wJU+GdoGC5VLof8RX8Ft/DlpshApkhswDLZzHIcQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-slot": "1.2.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-focus-guards": { "version": "1.1.2", "resolved": "https://registry.npmjs.org/@radix-ui/react-focus-guards/-/react-focus-guards-1.1.2.tgz", @@ -1282,13 +1310,13 @@ } }, "node_modules/@radix-ui/react-focus-scope": { - "version": "1.1.6", - "resolved": "https://registry.npmjs.org/@radix-ui/react-focus-scope/-/react-focus-scope-1.1.6.tgz", - "integrity": "sha512-r9zpYNUQY+2jWHWZGyddQLL9YHkM/XvSFHVcWs7bdVuxMAnCwTAuy6Pf47Z4nw7dYcUou1vg/VgjjrrH03VeBw==", + "version": "1.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-focus-scope/-/react-focus-scope-1.1.7.tgz", + "integrity": "sha512-t2ODlkXBQyn7jkl6TNaw/MtVEVvIGelJDCG41Okq/KwUsJBwQ4XVZsHAVUkK4mBv3ewiAS3PGuUWuY2BoK4ZUw==", "license": "MIT", "dependencies": { "@radix-ui/react-compose-refs": "1.1.2", - "@radix-ui/react-primitive": "2.1.2", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-callback-ref": "1.1.1" }, "peerDependencies": { @@ -1306,6 +1334,29 @@ } } }, + "node_modules/@radix-ui/react-focus-scope/node_modules/@radix-ui/react-primitive": { + "version": 
"2.1.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-primitive/-/react-primitive-2.1.3.tgz", + "integrity": "sha512-m9gTwRkhy2lvCPe6QJp4d3G1TYEUHn/FzJUtq9MjH46an1wJU+GdoGC5VLof8RX8Ft/DlpshApkhswDLZzHIcQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-slot": "1.2.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-hover-card": { "version": "1.1.13", "resolved": "https://registry.npmjs.org/@radix-ui/react-hover-card/-/react-hover-card-1.1.13.tgz", @@ -1378,6 +1429,232 @@ } } }, + "node_modules/@radix-ui/react-menu": { + "version": "2.1.15", + "resolved": "https://registry.npmjs.org/@radix-ui/react-menu/-/react-menu-2.1.15.tgz", + "integrity": "sha512-tVlmA3Vb9n8SZSd+YSbuFR66l87Wiy4du+YE+0hzKQEANA+7cWKH1WgqcEX4pXqxUFQKrWQGHdvEfw00TjFiew==", + "license": "MIT", + "dependencies": { + "@radix-ui/primitive": "1.1.2", + "@radix-ui/react-collection": "1.1.7", + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-context": "1.1.2", + "@radix-ui/react-direction": "1.1.1", + "@radix-ui/react-dismissable-layer": "1.1.10", + "@radix-ui/react-focus-guards": "1.1.2", + "@radix-ui/react-focus-scope": "1.1.7", + "@radix-ui/react-id": "1.1.1", + "@radix-ui/react-popper": "1.2.7", + "@radix-ui/react-portal": "1.1.9", + "@radix-ui/react-presence": "1.1.4", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-roving-focus": "1.1.10", + "@radix-ui/react-slot": "1.2.3", + "@radix-ui/react-use-callback-ref": "1.1.1", + "aria-hidden": "^1.2.4", + "react-remove-scroll": "^2.6.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 
|| ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-menu/node_modules/@radix-ui/react-arrow": { + "version": "1.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-arrow/-/react-arrow-1.1.7.tgz", + "integrity": "sha512-F+M1tLhO+mlQaOWspE8Wstg+z6PwxwRd8oQ8IXceWz92kfAmalTRf0EjrouQeo7QssEPfCn05B4Ihs1K9WQ/7w==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-primitive": "2.1.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-menu/node_modules/@radix-ui/react-collection": { + "version": "1.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-collection/-/react-collection-1.1.7.tgz", + "integrity": "sha512-Fh9rGN0MoI4ZFUNyfFVNU4y9LUz93u9/0K+yLgA2bwRojxM8JU1DyvvMBabnZPBgMWREAJvU2jjVzq+LrFUglw==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-context": "1.1.2", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-slot": "1.2.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-menu/node_modules/@radix-ui/react-dismissable-layer": { + "version": "1.1.10", + "resolved": "https://registry.npmjs.org/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-1.1.10.tgz", + "integrity": 
"sha512-IM1zzRV4W3HtVgftdQiiOmA0AdJlCtMLe00FXaHwgt3rAnNsIyDqshvkIW3hj/iu5hu8ERP7KIYki6NkqDxAwQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/primitive": "1.1.2", + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-callback-ref": "1.1.1", + "@radix-ui/react-use-escape-keydown": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-menu/node_modules/@radix-ui/react-popper": { + "version": "1.2.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-popper/-/react-popper-1.2.7.tgz", + "integrity": "sha512-IUFAccz1JyKcf/RjB552PlWwxjeCJB8/4KxT7EhBHOJM+mN7LdW+B3kacJXILm32xawcMMjb2i0cIZpo+f9kiQ==", + "license": "MIT", + "dependencies": { + "@floating-ui/react-dom": "^2.0.0", + "@radix-ui/react-arrow": "1.1.7", + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-context": "1.1.2", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-callback-ref": "1.1.1", + "@radix-ui/react-use-layout-effect": "1.1.1", + "@radix-ui/react-use-rect": "1.1.1", + "@radix-ui/react-use-size": "1.1.1", + "@radix-ui/rect": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-menu/node_modules/@radix-ui/react-portal": { + "version": "1.1.9", + "resolved": "https://registry.npmjs.org/@radix-ui/react-portal/-/react-portal-1.1.9.tgz", + "integrity": 
"sha512-bpIxvq03if6UNwXZ+HTK71JLh4APvnXntDc6XOX8UVq4XQOVl7lwok0AvIl+b8zgCw3fSaVTZMpAPPagXbKmHQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-layout-effect": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-menu/node_modules/@radix-ui/react-primitive": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-primitive/-/react-primitive-2.1.3.tgz", + "integrity": "sha512-m9gTwRkhy2lvCPe6QJp4d3G1TYEUHn/FzJUtq9MjH46an1wJU+GdoGC5VLof8RX8Ft/DlpshApkhswDLZzHIcQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-slot": "1.2.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-menu/node_modules/@radix-ui/react-roving-focus": { + "version": "1.1.10", + "resolved": "https://registry.npmjs.org/@radix-ui/react-roving-focus/-/react-roving-focus-1.1.10.tgz", + "integrity": "sha512-dT9aOXUen9JSsxnMPv/0VqySQf5eDQ6LCk5Sw28kamz8wSOW2bJdlX2Bg5VUIIcV+6XlHpWTIuTPCf/UNIyq8Q==", + "license": "MIT", + "dependencies": { + "@radix-ui/primitive": "1.1.2", + "@radix-ui/react-collection": "1.1.7", + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-context": "1.1.2", + "@radix-ui/react-direction": "1.1.1", + "@radix-ui/react-id": "1.1.1", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-callback-ref": "1.1.1", + 
"@radix-ui/react-use-controllable-state": "1.2.2" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-popover": { "version": "1.1.14", "resolved": "https://registry.npmjs.org/@radix-ui/react-popover/-/react-popover-1.1.14.tgz", @@ -1465,31 +1742,6 @@ } } }, - "node_modules/@radix-ui/react-popover/node_modules/@radix-ui/react-focus-scope": { - "version": "1.1.7", - "resolved": "https://registry.npmjs.org/@radix-ui/react-focus-scope/-/react-focus-scope-1.1.7.tgz", - "integrity": "sha512-t2ODlkXBQyn7jkl6TNaw/MtVEVvIGelJDCG41Okq/KwUsJBwQ4XVZsHAVUkK4mBv3ewiAS3PGuUWuY2BoK4ZUw==", - "license": "MIT", - "dependencies": { - "@radix-ui/react-compose-refs": "1.1.2", - "@radix-ui/react-primitive": "2.1.3", - "@radix-ui/react-use-callback-ref": "1.1.1" - }, - "peerDependencies": { - "@types/react": "*", - "@types/react-dom": "*", - "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", - "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - }, - "@types/react-dom": { - "optional": true - } - } - }, "node_modules/@radix-ui/react-popover/node_modules/@radix-ui/react-popper": { "version": "1.2.7", "resolved": "https://registry.npmjs.org/@radix-ui/react-popper/-/react-popper-1.2.7.tgz", diff --git a/web/package.json b/web/package.json index 8cd6b8dc..62d62874 100644 --- a/web/package.json +++ b/web/package.json @@ -5,6 +5,7 @@ "scripts": { "dev": "next dev --turbopack", "dev:local": "NEXT_PUBLIC_API_BASE_URL=http://localhost:5300 next dev --turbopack", + "dev:local:win": "set NEXT_PUBLIC_API_BASE_URL=http://localhost:5300&&next dev --turbopack", "build": "next build", "start": "next 
start", "lint": "next lint", @@ -16,6 +17,9 @@ "prettier --write" ] }, + "overrides": { + "@radix-ui/react-focus-scope": "1.1.7" + }, "dependencies": { "@dnd-kit/core": "^6.3.1", "@dnd-kit/sortable": "^10.0.0", diff --git a/web/src/app/home/bots/BotDetailDialog.tsx b/web/src/app/home/bots/BotDetailDialog.tsx index cad04e7b..db19e1d4 100644 --- a/web/src/app/home/bots/BotDetailDialog.tsx +++ b/web/src/app/home/bots/BotDetailDialog.tsx @@ -127,7 +127,6 @@ export default function BotDetailDialog({ @@ -198,7 +197,6 @@ export default function BotDetailDialog({ diff --git a/web/src/app/home/bots/components/bot-form/BotForm.tsx b/web/src/app/home/bots/components/bot-form/BotForm.tsx index e4b6d40e..bd757ae0 100644 --- a/web/src/app/home/bots/components/bot-form/BotForm.tsx +++ b/web/src/app/home/bots/components/bot-form/BotForm.tsx @@ -64,13 +64,11 @@ const getFormSchema = (t: (key: string) => string) => export default function BotForm({ initBotId, onFormSubmit, - onFormCancel, onBotDeleted, onNewBotCreated, }: { initBotId?: string; onFormSubmit: (value: z.infer>) => void; - onFormCancel: () => void; onBotDeleted: () => void; onNewBotCreated: (botId: string) => void; }) { diff --git a/web/src/app/home/components/dynamic-form/DynamicFormComponent.tsx b/web/src/app/home/components/dynamic-form/DynamicFormComponent.tsx index f3df9e87..de040db3 100644 --- a/web/src/app/home/components/dynamic-form/DynamicFormComponent.tsx +++ b/web/src/app/home/components/dynamic-form/DynamicFormComponent.tsx @@ -50,6 +50,9 @@ export default function DynamicFormComponent({ case 'llm-model-selector': fieldSchema = z.string(); break; + case 'knowledge-base-selector': + fieldSchema = z.string(); + break; case 'prompt-editor': fieldSchema = z.array( z.object({ diff --git a/web/src/app/home/components/dynamic-form/DynamicFormItemComponent.tsx b/web/src/app/home/components/dynamic-form/DynamicFormItemComponent.tsx index 28d963d3..69cb79e1 100644 --- 
a/web/src/app/home/components/dynamic-form/DynamicFormItemComponent.tsx +++ b/web/src/app/home/components/dynamic-form/DynamicFormItemComponent.tsx @@ -17,6 +17,7 @@ import { Button } from '@/components/ui/button'; import { useEffect, useState } from 'react'; import { httpClient } from '@/app/infra/http/HttpClient'; import { LLMModel } from '@/app/infra/entities/api'; +import { KnowledgeBase } from '@/app/infra/entities/api'; import { toast } from 'sonner'; import { HoverCard, @@ -35,6 +36,7 @@ export default function DynamicFormItemComponent({ field: ControllerRenderProps; }) { const [llmModels, setLlmModels] = useState([]); + const [knowledgeBases, setKnowledgeBases] = useState([]); const { t } = useTranslation(); useEffect(() => { @@ -50,6 +52,19 @@ export default function DynamicFormItemComponent({ } }, [config.type]); + useEffect(() => { + if (config.type === DynamicFormItemType.KNOWLEDGE_BASE_SELECTOR) { + httpClient + .getKnowledgeBases() + .then((resp) => { + setKnowledgeBases(resp.bases); + }) + .catch((err) => { + toast.error('获取知识库列表失败:' + err.message); + }); + } + }, [config.type]); + switch (config.type) { case DynamicFormItemType.INT: case DynamicFormItemType.FLOAT: @@ -249,6 +264,25 @@ export default function DynamicFormItemComponent({ ); + case DynamicFormItemType.KNOWLEDGE_BASE_SELECTOR: + return ( + + ); + case DynamicFormItemType.PROMPT_EDITOR: return (
diff --git a/web/src/app/home/knowledge/KBDetailDialog.tsx b/web/src/app/home/knowledge/KBDetailDialog.tsx index 3854e933..7ad8d4a4 100644 --- a/web/src/app/home/knowledge/KBDetailDialog.tsx +++ b/web/src/app/home/knowledge/KBDetailDialog.tsx @@ -7,7 +7,6 @@ import { DialogHeader, DialogTitle, DialogFooter, - DialogDescription, } from '@/components/ui/dialog'; import { Sidebar, @@ -21,36 +20,34 @@ import { } from '@/components/ui/sidebar'; import { Button } from '@/components/ui/button'; import { useTranslation } from 'react-i18next'; -import { z } from 'zod'; import { httpClient } from '@/app/infra/http/HttpClient'; // import { KnowledgeBase } from '@/app/infra/entities/api'; import KBForm from '@/app/home/knowledge/components/kb-form/KBForm'; import KBDoc from '@/app/home/knowledge/components/kb-docs/KBDoc'; +import KBRetrieve from '@/app/home/knowledge/components/kb-retrieve/KBRetrieve'; interface KBDetailDialogProps { open: boolean; onOpenChange: (open: boolean) => void; kbId?: string; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - onFormSubmit: (value: z.infer) => void; onFormCancel: () => void; onKbDeleted: () => void; onNewKbCreated: (kbId: string) => void; + onKbUpdated: (kbId: string) => void; } export default function KBDetailDialog({ open, onOpenChange, kbId: propKbId, - onFormSubmit, onFormCancel, onKbDeleted, onNewKbCreated, + onKbUpdated, }: KBDetailDialogProps) { const { t } = useTranslation(); const [kbId, setKbId] = useState(propKbId); const [activeMenu, setActiveMenu] = useState('metadata'); - const [fileId, setFileId] = useState(undefined); const [showDeleteConfirm, setShowDeleteConfirm] = useState(false); useEffect(() => { @@ -85,6 +82,19 @@ export default function KBDetailDialog({ ), }, + { + key: 'retrieve', + label: t('knowledge.retrieve'), + icon: ( + + + + ), + }, ]; const confirmDelete = () => { @@ -107,10 +117,8 @@ export default function KBDetailDialog({ {activeMenu === 'metadata' && ( )} {activeMenu === 'documents' &&
documents
} @@ -174,20 +182,21 @@ export default function KBDetailDialog({ {activeMenu === 'metadata' ? t('knowledge.editKnowledgeBase') - : t('knowledge.editDocument')} + : activeMenu === 'documents' + ? t('knowledge.editDocument') + : t('knowledge.retrieveTest')}
{activeMenu === 'metadata' && ( )} {activeMenu === 'documents' && } + {activeMenu === 'retrieve' && }
{activeMenu === 'metadata' && ( diff --git a/web/src/app/home/knowledge/components/kb-docs/FileUploadZone.tsx b/web/src/app/home/knowledge/components/kb-docs/FileUploadZone.tsx index aa8adede..3b4123ec 100644 --- a/web/src/app/home/knowledge/components/kb-docs/FileUploadZone.tsx +++ b/web/src/app/home/knowledge/components/kb-docs/FileUploadZone.tsx @@ -104,7 +104,7 @@ export default function FileUploadZone({ id="file-upload" className="hidden" onChange={handleFileSelect} - accept=".pdf,.doc,.docx,.txt,.md" + accept=".pdf,.doc,.docx,.txt,.md,.html" disabled={isUploading} /> diff --git a/web/src/app/home/knowledge/components/kb-docs/KBDoc.tsx b/web/src/app/home/knowledge/components/kb-docs/KBDoc.tsx index 0a779112..fb94dace 100644 --- a/web/src/app/home/knowledge/components/kb-docs/KBDoc.tsx +++ b/web/src/app/home/knowledge/components/kb-docs/KBDoc.tsx @@ -28,7 +28,7 @@ export default function KBDoc({ kbId }: { kbId: string }) { setDocumentsList( resp.files.map((file: KnowledgeBaseFile) => { return { - id: file.id, + uuid: file.uuid, name: file.file_name, status: file.status, }; @@ -66,7 +66,7 @@ export default function KBDoc({ kbId }: { kbId: string }) { onUploadSuccess={handleUploadSuccess} onUploadError={handleUploadError} /> - +
); } diff --git a/web/src/app/home/knowledge/components/kb-docs/documents/columns.tsx b/web/src/app/home/knowledge/components/kb-docs/documents/columns.tsx index 0a43cf8f..6142cfc4 100644 --- a/web/src/app/home/knowledge/components/kb-docs/documents/columns.tsx +++ b/web/src/app/home/knowledge/components/kb-docs/documents/columns.tsx @@ -8,21 +8,21 @@ import { DropdownMenuContent, DropdownMenuItem, DropdownMenuLabel, - DropdownMenuSeparator, DropdownMenuTrigger, } from '@/components/ui/dropdown-menu'; -import { useTranslation } from 'react-i18next'; +import { Badge } from '@/components/ui/badge'; +import { TFunction } from 'i18next'; export type DocumentFile = { - id: string; + uuid: string; name: string; status: string; }; export const columns = ( onDelete: (id: string) => void, + t: TFunction, ): ColumnDef[] => { - const { t } = useTranslation(); return [ { accessorKey: 'name', @@ -31,6 +31,36 @@ export const columns = ( { accessorKey: 'status', header: t('knowledge.documentsTab.status'), + cell: ({ row }) => { + const document = row.original; + + switch (document.status) { + case 'processing': + return ( + + {t('knowledge.documentsTab.processing')} + + ); + case 'completed': + return ( + + {t('knowledge.documentsTab.completed')} + + ); + case 'failed': + return ( + + {t('knowledge.documentsTab.failed')} + + ); + default: + return ( + + {document.status} + + ); + } + }, }, { id: 'actions', @@ -52,7 +82,7 @@ export const columns = ( {t('knowledge.documentsTab.actions')} - onDelete(document.id)}> + onDelete(document.uuid)}> {t('knowledge.documentsTab.delete')} diff --git a/web/src/app/home/knowledge/components/kb-form/KBForm.tsx b/web/src/app/home/knowledge/components/kb-form/KBForm.tsx index 6d0b42f9..398d2302 100644 --- a/web/src/app/home/knowledge/components/kb-form/KBForm.tsx +++ b/web/src/app/home/knowledge/components/kb-form/KBForm.tsx @@ -24,6 +24,7 @@ import { SelectValue, } from '@/components/ui/select'; import { KnowledgeBase } from 
'@/app/infra/entities/api'; +import { toast } from 'sonner'; const getFormSchema = (t: (key: string) => string) => z.object({ @@ -42,17 +43,12 @@ const getFormSchema = (t: (key: string) => string) => export default function KBForm({ initKbId, - onFormSubmit, - onFormCancel, - onKbDeleted, onNewKbCreated, + onKbUpdated, }: { initKbId?: string; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - onFormSubmit: (value: any) => void; - onFormCancel: () => void; - onKbDeleted: () => void; onNewKbCreated: (kbId: string) => void; + onKbUpdated: (kbId: string) => void; }) { const { t } = useTranslation(); const formSchema = getFormSchema(t); @@ -87,7 +83,7 @@ export default function KBForm({ const getKbConfig = async ( kbId: string, ): Promise> => { - return new Promise((resolve, reject) => { + return new Promise((resolve) => { httpClient.getKnowledgeBase(kbId).then((res) => { resolve({ name: res.base.name, @@ -122,6 +118,17 @@ export default function KBForm({ embedding_model_uuid: data.embeddingModelUUID, top_k: data.top_k, }; + httpClient + .updateKnowledgeBase(initKbId, updateKb) + .then((res) => { + console.log('update knowledge base success', res); + onKbUpdated(res.uuid); + toast.success(t('knowledge.updateKnowledgeBaseSuccess')); + }) + .catch((err) => { + console.error('update knowledge base failed', err); + toast.error(t('knowledge.updateKnowledgeBaseFailed')); + }); } else { // create knowledge base const newKb: KnowledgeBase = { @@ -195,6 +202,7 @@ export default function KBForm({
setQuery(e.target.value)} + placeholder={t('knowledge.queryPlaceholder')} + onKeyPress={(e) => e.key === 'Enter' && handleRetrieve()} + /> + +
+ +
+ {results.length === 0 && !loading && ( +

{t('knowledge.noResults')}

+ )} + + {loading ? ( +

{t('common.loading')}

+ ) : ( + results.map((result) => ( + + + + {getFileName(result.metadata.file_id)} + + {t('knowledge.distance')}: {result.distance.toFixed(4)} + + + + +

+ {result.metadata.text} +

+
+
+ )) + )} +
+ + ); +} diff --git a/web/src/app/home/knowledge/page.tsx b/web/src/app/home/knowledge/page.tsx index a2841b24..b290f7be 100644 --- a/web/src/app/home/knowledge/page.tsx +++ b/web/src/app/home/knowledge/page.tsx @@ -63,11 +63,6 @@ export default function KnowledgePage() { setDetailDialogOpen(true); }; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const handleFormSubmit = (value: any) => { - console.log('handleFormSubmit', value); - }; - const handleFormCancel = () => { setDetailDialogOpen(false); }; @@ -77,9 +72,14 @@ export default function KnowledgePage() { setDetailDialogOpen(false); }; - const handleNewKbCreated = () => { + const handleNewKbCreated = (newKbId: string) => { + getKnowledgeBaseList(); + setSelectedKbId(newKbId); + setDetailDialogOpen(true); + }; + + const handleKbUpdated = () => { getKnowledgeBaseList(); - setDetailDialogOpen(false); }; return ( @@ -88,10 +88,10 @@ export default function KnowledgePage() { open={detailDialogOpen} onOpenChange={setDetailDialogOpen} kbId={selectedKbId || undefined} - onFormSubmit={handleFormSubmit} onFormCancel={handleFormCancel} onKbDeleted={handleKbDeleted} onNewKbCreated={handleNewKbCreated} + onKbUpdated={handleKbUpdated} />
diff --git a/web/src/app/home/pipelines/page.tsx b/web/src/app/home/pipelines/page.tsx index 40875f6e..ecead827 100644 --- a/web/src/app/home/pipelines/page.tsx +++ b/web/src/app/home/pipelines/page.tsx @@ -9,6 +9,13 @@ import styles from './pipelineConfig.module.css'; import { toast } from 'sonner'; import { useTranslation } from 'react-i18next'; import PipelineDialog from './PipelineDetailDialog'; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from '@/components/ui/select'; export default function PluginConfigPage() { const { t } = useTranslation(); @@ -26,14 +33,19 @@ export default function PluginConfigPage() { }); const [selectedPipelineIsDefault, setSelectedPipelineIsDefault] = useState(false); + const [sortByValue, setSortByValue] = useState('created_at'); + const [sortOrderValue, setSortOrderValue] = useState('DESC'); useEffect(() => { getPipelines(); }, []); - function getPipelines() { + function getPipelines( + sortBy: string = sortByValue, + sortOrder: string = sortOrderValue, + ) { httpClient - .getPipelines() + .getPipelines(sortBy, sortOrder) .then((value) => { const currentTime = new Date(); const pipelineList = value.pipelines.map((pipeline) => { @@ -106,6 +118,13 @@ export default function PluginConfigPage() { setDialogOpen(true); }; + function handleSortChange(value: string) { + const [newSortBy, newSortOrder] = value.split(',').map((s) => s.trim()); + setSortByValue(newSortBy); + setSortOrderValue(newSortOrder); + getPipelines(newSortBy, newSortOrder); + } + return (
+
+ +
('pushed_at');
   const [sortOrderValue, setSortOrderValue] = useState('DESC');
   const searchTimeout = useRef(null);
-  const pageSize = 10;
+  const pageSize = 12;

   useEffect(() => {
     initData();
diff --git a/web/src/app/infra/entities/api/index.ts b/web/src/app/infra/entities/api/index.ts
index 77dfff05..787631a3 100644
--- a/web/src/app/infra/entities/api/index.ts
+++ b/web/src/app/infra/entities/api/index.ts
@@ -55,6 +55,15 @@ export interface LLMModel {
   // updated_at: string;
 }

+export interface KnowledgeBase {
+  uuid?: string;
+  name: string;
+  description: string;
+  embedding_model_uuid: string;
+  created_at?: string;
+  top_k?: number;
+}
+
 export interface ApiRespProviderEmbeddingModels {
   models: EmbeddingModel[];
 }
@@ -156,7 +165,7 @@ export interface ApiRespKnowledgeBaseFiles {
 }

 export interface KnowledgeBaseFile {
-  id: string;
+  uuid: string;
   file_name: string;
   status: string;
 }
@@ -288,3 +297,18 @@ export interface ApiRespWebChatMessage {
 export interface ApiRespWebChatMessages {
   messages: Message[];
 }
+
+export interface RetrieveResult {
+  id: string;
+  metadata: {
+    file_id: string;
+    text: string;
+    uuid: string;
+    [key: string]: unknown;
+  };
+  distance: number;
+}
+
+export interface ApiRespKnowledgeBaseRetrieve {
+  results: RetrieveResult[];
+}
diff --git a/web/src/app/infra/entities/form/dynamic.ts b/web/src/app/infra/entities/form/dynamic.ts
index 6a185c8b..6d6de096 100644
--- a/web/src/app/infra/entities/form/dynamic.ts
+++ b/web/src/app/infra/entities/form/dynamic.ts
@@ -21,6 +21,7 @@ export enum DynamicFormItemType {
   LLM_MODEL_SELECTOR = 'llm-model-selector',
   PROMPT_EDITOR = 'prompt-editor',
   UNKNOWN = 'unknown',
+  KNOWLEDGE_BASE_SELECTOR = 'knowledge-base-selector',
 }

 export interface IDynamicFormItemOption {
diff --git a/web/src/app/infra/http/HttpClient.ts b/web/src/app/infra/http/HttpClient.ts
index d3530a4e..9a49c1e3 100644
--- a/web/src/app/infra/http/HttpClient.ts
+++ b/web/src/app/infra/http/HttpClient.ts
@@ -38,6 +38,7 @@ import {
   ApiRespKnowledgeBase,
   KnowledgeBase,
   ApiRespKnowledgeBaseFiles,
+  ApiRespKnowledgeBaseRetrieve,
 } from '@/app/infra/entities/api';
 import { GetBotLogsRequest } from '@/app/infra/http/requestParam/bots/GetBotLogsRequest';
 import { GetBotLogsResponse } from '@/app/infra/http/requestParam/bots/GetBotLogsResponse';
@@ -323,8 +324,15 @@ class HttpClient {
     return this.get('/api/v1/pipelines/_/metadata');
   }

-  public getPipelines(): Promise {
-    return this.get('/api/v1/pipelines');
+  public getPipelines(
+    sortBy?: string,
+    sortOrder?: string,
+  ): Promise {
+    const params = new URLSearchParams();
+    if (sortBy) params.append('sort_by', sortBy);
+    if (sortOrder) params.append('sort_order', sortOrder);
+    const queryString = params.toString();
+    return this.get(`/api/v1/pipelines${queryString ? `?${queryString}` : ''}`);
   }

   public getPipeline(uuid: string): Promise {
@@ -459,6 +467,13 @@ class HttpClient {
     return this.post('/api/v1/knowledge/bases', base);
   }

+  public updateKnowledgeBase(
+    uuid: string,
+    base: KnowledgeBase,
+  ): Promise<{ uuid: string }> {
+    return this.put(`/api/v1/knowledge/bases/${uuid}`, base);
+  }
+
   public uploadKnowledgeBaseFile(
     uuid: string,
     file_id: string,
@@ -485,6 +500,13 @@ class HttpClient {
     return this.delete(`/api/v1/knowledge/bases/${uuid}`);
   }

+  public retrieveKnowledgeBase(
+    uuid: string,
+    query: string,
+  ): Promise {
+    return this.post(`/api/v1/knowledge/bases/${uuid}/retrieve`, { query });
+  }
+
   // ============ Plugins API ============
   public getPlugins(): Promise {
     return this.get('/api/v1/plugins');
diff --git a/web/src/i18n/locales/en-US.ts b/web/src/i18n/locales/en-US.ts
index ddf7ad0c..ff855d31 100644
--- a/web/src/i18n/locales/en-US.ts
+++ b/web/src/i18n/locales/en-US.ts
@@ -40,6 +40,7 @@ const enUS = {
     copySuccess: 'Copy Successfully',
     test: 'Test',
     forgotPassword: 'Forgot Password?',
+    loading: 'Loading...',
   },
   notFound: {
     title: 'Page not found',
@@ -194,6 +195,10 @@ const enUS = {
     today: 'Today',
     updateTime: 'Updated ',
     defaultBadge: 'Default',
+    sortBy: 'Sort by',
+    newestCreated: 'Newest Created',
+    recentlyEdited: 'Recently Edited',
+    earliestEdited: 'Earliest Edited',
     basicInfo: 'Basic',
     aiCapabilities: 'AI',
     triggerConditions: 'Trigger',
@@ -234,6 +239,8 @@ const enUS = {
     title: 'Knowledge',
     createKnowledgeBase: 'Create Knowledge Base',
     editKnowledgeBase: 'Edit Knowledge Base',
+    selectKnowledgeBase: 'Select Knowledge Base',
+    empty: 'Empty',
     editDocument: 'Documents',
     description: 'Configuring knowledge bases for improved LLM responses',
     metadata: 'Metadata',
@@ -255,6 +262,10 @@ const enUS = {
     embeddingModelDescription:
       'Used to vectorize the text, you can configure it in the Models page',
     updateTime: 'Updated ',
+    cannotChangeEmbeddingModel:
+      'Knowledge base created cannot be modified embedding model',
+    updateKnowledgeBaseSuccess: 'Knowledge base updated successfully',
+    updateKnowledgeBaseFailed: 'Knowledge base update failed',
     documentsTab: {
       name: 'Name',
       status: 'Status',
@@ -270,9 +281,21 @@ const enUS = {
       delete: 'Delete File',
       fileDeleteSuccess: 'File deleted successfully',
       fileDeleteFailed: 'File deletion failed',
+      processing: 'Processing',
+      completed: 'Completed',
+      failed: 'Failed',
     },
     deleteKnowledgeBaseConfirmation:
       'Are you sure you want to delete this knowledge base? All documents in this knowledge base will be deleted.',
+    retrieve: 'Retrieve Test',
+    retrieveTest: 'Retrieve Test',
+    query: 'Query',
+    queryPlaceholder: 'Enter query text...',
+    distance: 'Distance',
+    content: 'Content',
+    fileName: 'File Name',
+    noResults: 'No results',
+    retrieveError: 'Retrieve failed',
   },
   register: {
     title: 'Initialize LangBot 👋',
diff --git a/web/src/i18n/locales/ja-JP.ts b/web/src/i18n/locales/ja-JP.ts
index 69357cb9..1fa337a0 100644
--- a/web/src/i18n/locales/ja-JP.ts
+++ b/web/src/i18n/locales/ja-JP.ts
@@ -41,6 +41,7 @@ const jaJP = {
     copySuccess: 'コピーに成功しました',
     test: 'テスト',
     forgotPassword: 'パスワードを忘れた?',
+    loading: '読み込み中...',
   },
   notFound: {
     title: 'ページが見つかりません',
@@ -195,6 +196,10 @@ const jaJP = {
     today: '今日',
     updateTime: '更新日時',
     defaultBadge: 'デフォルト',
+    sortBy: '並び順',
+    newestCreated: '最新作成',
+    recentlyEdited: '最近編集',
+    earliestEdited: '最古編集',
     basicInfo: '基本情報',
     aiCapabilities: 'AI機能',
     triggerConditions: 'トリガー条件',
@@ -236,6 +241,8 @@ const jaJP = {
     title: '知識ベース',
     createKnowledgeBase: '知識ベースを作成',
     editKnowledgeBase: '知識ベースを編集',
+    selectKnowledgeBase: '知識ベースを選択',
+    empty: 'なし',
     editDocument: 'ドキュメント',
     description: 'LLMの回答品質向上のための知識ベースを設定します',
     metadata: 'メタデータ',
@@ -257,6 +264,10 @@ const jaJP = {
     embeddingModelDescription:
       'テキストのベクトル化に使用する埋め込みモデルを管理します',
     updateTime: '更新日時',
+    cannotChangeEmbeddingModel:
+      '知識ベース作成後は埋め込みモデルを変更できません',
+    updateKnowledgeBaseSuccess: '知識ベースの更新に成功しました',
+    updateKnowledgeBaseFailed: '知識ベースの更新に失敗しました',
     documentsTab: {
       name: '名前',
       status: 'ステータス',
@@ -273,9 +284,21 @@ const jaJP = {
       delete: 'ドキュメントを削除',
       fileDeleteSuccess: 'ドキュメントの削除に成功しました',
       fileDeleteFailed: 'ドキュメントの削除に失敗しました',
+      processing: '処理中',
+      completed: '完了',
+      failed: '失敗',
     },
     deleteKnowledgeBaseConfirmation:
       '本当にこの知識ベースを削除しますか?この知識ベースに紐付けられたドキュメントは削除されます。',
+    retrieve: '検索テスト',
+    retrieveTest: '検索テスト',
+    query: '検索',
+    queryPlaceholder: '検索内容を入力...',
+    distance: '距離',
+    content: '内容',
+    fileName: 'ファイル名',
+    noResults: '検索結果がありません',
+    retrieveError: '検索に失敗しました',
   },
   register: {
     title: 'LangBot を初期化 👋',
diff --git a/web/src/i18n/locales/zh-Hans.ts b/web/src/i18n/locales/zh-Hans.ts
index a0ea2b8e..2575094a 100644
--- a/web/src/i18n/locales/zh-Hans.ts
+++ b/web/src/i18n/locales/zh-Hans.ts
@@ -40,6 +40,7 @@ const zhHans = {
     copySuccess: '复制成功',
     test: '测试',
     forgotPassword: '忘记密码?',
+    loading: '加载中...',
   },
   notFound: {
     title: '页面不存在',
@@ -189,6 +190,10 @@ const zhHans = {
     today: '今天',
     updateTime: '更新于',
     defaultBadge: '默认',
+    sortBy: '排序方式',
+    newestCreated: '最新创建',
+    recentlyEdited: '最近编辑',
+    earliestEdited: '最早编辑',
     basicInfo: '基础信息',
     aiCapabilities: 'AI 能力',
     triggerConditions: '触发条件',
@@ -229,6 +234,8 @@ const zhHans = {
     title: '知识库',
     createKnowledgeBase: '创建知识库',
     editKnowledgeBase: '编辑知识库',
+    selectKnowledgeBase: '选择知识库',
+    empty: '无',
     editDocument: '文档',
     description: '配置可用于提升模型回复质量的知识库',
     metadata: '元数据',
@@ -249,6 +256,9 @@ const zhHans = {
     selectEmbeddingModel: '选择嵌入模型',
     embeddingModelDescription: '用于向量化文本,可在模型配置页面配置',
     updateTime: '更新于',
+    cannotChangeEmbeddingModel: '知识库创建后不可修改嵌入模型',
+    updateKnowledgeBaseSuccess: '知识库更新成功',
+    updateKnowledgeBaseFailed: '知识库更新失败',
     documentsTab: {
       name: '名称',
       status: '状态',
@@ -263,9 +273,21 @@ const zhHans = {
       delete: '删除文件',
       fileDeleteSuccess: '文件删除成功',
       fileDeleteFailed: '文件删除失败',
+      processing: '处理中',
+      completed: '完成',
+      failed: '失败',
     },
     deleteKnowledgeBaseConfirmation:
       '你确定要删除这个知识库吗?此知识库下的所有文档将被删除。',
+    retrieve: '检索测试',
+    retrieveTest: '检索测试',
+    query: '查询',
+    queryPlaceholder: '输入查询内容...',
+    distance: '距离',
+    content: '内容',
+    fileName: '文件名',
+    noResults: '暂无结果',
+    retrieveError: '检索失败',
   },
   register: {
     title: '初始化 LangBot 👋',