mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
Compare commits
1 Commits
v4.9.2
...
feat/napca
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
46a993b9a3 |
17
README.md
17
README.md
@@ -19,10 +19,9 @@ English / [简体中文](README_CN.md) / [繁體中文](README_TW.md) / [日本
|
||||
[](https://github.com/langbot-app/LangBot/stargazers)
|
||||
|
||||
<a href="https://langbot.app">Website</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/features">Features</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/guide">Docs</a> |
|
||||
<a href="https://docs.langbot.app/en/tags/readme">API</a> |
|
||||
<a href="https://space.langbot.app/cloud">Cloud</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/features.html">Features</a> |
|
||||
<a href="https://docs.langbot.app/en/insight/guide.html">Docs</a> |
|
||||
<a href="https://docs.langbot.app/en/tags/readme.html">API</a> |
|
||||
<a href="https://space.langbot.app">Plugin Market</a> |
|
||||
<a href="https://langbot.featurebase.app/roadmap">Roadmap</a>
|
||||
|
||||
@@ -45,16 +44,12 @@ LangBot is an **open-source, production-grade platform** for building AI-powered
|
||||
- **Web Management Panel** — Configure, manage, and monitor your bots through an intuitive browser interface. No YAML editing required.
|
||||
- **Multi-Pipeline Architecture** — Different bots for different scenarios, with comprehensive monitoring and exception handling.
|
||||
|
||||
[→ Learn more about all features](https://docs.langbot.app/en/insight/features)
|
||||
[→ Learn more about all features](https://docs.langbot.app/en/insight/features.html)
|
||||
|
||||
---
|
||||
|
||||
## Quick Start
|
||||
|
||||
### ☁️ LangBot Cloud (Recommended)
|
||||
|
||||
**[LangBot Cloud](https://space.langbot.app/cloud)** — Zero deployment, ready to use.
|
||||
|
||||
### One-Line Launch
|
||||
|
||||
```bash
|
||||
@@ -76,7 +71,7 @@ docker compose up -d
|
||||
[](https://zeabur.com/en-US/templates/ZKTBDH)
|
||||
[](https://railway.app/template/yRrAyL?referralCode=vogKPF)
|
||||
|
||||
**More options:** [Docker](https://docs.langbot.app/en/deploy/langbot/docker) · [Manual](https://docs.langbot.app/en/deploy/langbot/manual) · [BTPanel](https://docs.langbot.app/en/deploy/langbot/one-click/bt) · [Kubernetes](./docker/README_K8S.md)
|
||||
**More options:** [Docker](https://docs.langbot.app/en/deploy/langbot/docker.html) · [Manual](https://docs.langbot.app/en/deploy/langbot/manual.html) · [BTPanel](https://docs.langbot.app/en/deploy/langbot/one-click/bt.html) · [Kubernetes](./docker/README_K8S.md)
|
||||
|
||||
---
|
||||
|
||||
@@ -124,7 +119,7 @@ docker compose up -d
|
||||
| [接口 AI](https://jiekou.ai/) | Gateway | ✅ |
|
||||
| [302.AI](https://share.302.ai/SuTG99) | Gateway | ✅ |
|
||||
|
||||
[→ View all integrations](https://docs.langbot.app/en/insight/features)
|
||||
[→ View all integrations](https://docs.langbot.app/en/insight/features.html)
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -24,7 +24,6 @@
|
||||
<a href="https://docs.langbot.app/zh/insight/features.html">特性</a> |
|
||||
<a href="https://docs.langbot.app/zh/insight/guide.html">文档</a> |
|
||||
<a href="https://docs.langbot.app/zh/tags/readme.html">API</a> |
|
||||
<a href="https://space.langbot.app/cloud">Cloud</a> |
|
||||
<a href="https://space.langbot.app">插件市场</a> |
|
||||
<a href="https://langbot.featurebase.app/roadmap">路线图</a>
|
||||
|
||||
@@ -53,10 +52,6 @@ LangBot 是一个**开源的生产级平台**,用于构建 AI 驱动的即时
|
||||
|
||||
## 快速开始
|
||||
|
||||
### ☁️ LangBot Cloud(推荐)
|
||||
|
||||
**[LangBot Cloud](https://space.langbot.app/cloud)** — 免部署,开箱即用。
|
||||
|
||||
### 一键启动
|
||||
|
||||
```bash
|
||||
|
||||
@@ -50,10 +50,6 @@ LangBot es una **plataforma de código abierto y grado de producción** para con
|
||||
|
||||
## Inicio Rápido
|
||||
|
||||
### ☁️ LangBot Cloud (Recomendado)
|
||||
|
||||
**[LangBot Cloud](https://space.langbot.app/cloud)** — Sin despliegue, listo para usar.
|
||||
|
||||
### Lanzamiento en una línea
|
||||
|
||||
```bash
|
||||
|
||||
@@ -50,10 +50,6 @@ LangBot est une **plateforme open-source de niveau production** pour créer des
|
||||
|
||||
## Démarrage Rapide
|
||||
|
||||
### ☁️ LangBot Cloud (Recommandé)
|
||||
|
||||
**[LangBot Cloud](https://space.langbot.app/cloud)** — Sans déploiement, prêt à utiliser.
|
||||
|
||||
### Lancement en une ligne
|
||||
|
||||
```bash
|
||||
|
||||
@@ -50,10 +50,6 @@ LangBot は、AI搭載のインスタントメッセージングボットを構
|
||||
|
||||
## クイックスタート
|
||||
|
||||
### ☁️ LangBot Cloud(推奨)
|
||||
|
||||
**[LangBot Cloud](https://space.langbot.app/cloud)** — デプロイ不要、すぐに使えます。
|
||||
|
||||
### ワンライン起動
|
||||
|
||||
```bash
|
||||
|
||||
@@ -50,10 +50,6 @@ LangBot은 AI 기반 인스턴트 메시징 봇을 구축하기 위한 **오픈
|
||||
|
||||
## 빠른 시작
|
||||
|
||||
### ☁️ LangBot Cloud (추천)
|
||||
|
||||
**[LangBot Cloud](https://space.langbot.app/cloud)** — 배포 없이 바로 사용.
|
||||
|
||||
### 원라인 실행
|
||||
|
||||
```bash
|
||||
|
||||
@@ -50,10 +50,6 @@ LangBot — это **платформа с открытым исходным к
|
||||
|
||||
## Быстрый старт
|
||||
|
||||
### ☁️ LangBot Cloud (Рекомендуется)
|
||||
|
||||
**[LangBot Cloud](https://space.langbot.app/cloud)** — Без развёртывания, готово к использованию.
|
||||
|
||||
### Запуск одной командой
|
||||
|
||||
```bash
|
||||
|
||||
@@ -52,10 +52,6 @@ LangBot 是一個**開源的生產級平台**,用於建構 AI 驅動的即時
|
||||
|
||||
## 快速開始
|
||||
|
||||
### ☁️ LangBot Cloud(推薦)
|
||||
|
||||
**[LangBot Cloud](https://space.langbot.app/cloud)** — 免部署,開箱即用。
|
||||
|
||||
### 一鍵啟動
|
||||
|
||||
```bash
|
||||
|
||||
@@ -50,10 +50,6 @@ LangBot là một **nền tảng mã nguồn mở, cấp sản xuất** để x
|
||||
|
||||
## Bắt đầu nhanh
|
||||
|
||||
### ☁️ LangBot Cloud (Khuyên dùng)
|
||||
|
||||
**[LangBot Cloud](https://space.langbot.app/cloud)** — Không cần triển khai, sẵn sàng sử dụng.
|
||||
|
||||
### Khởi chạy một dòng
|
||||
|
||||
```bash
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "langbot"
|
||||
version = "4.9.2"
|
||||
version = "4.8.4"
|
||||
description = "Production-grade platform for building agentic IM bots"
|
||||
readme = "README.md"
|
||||
license-files = ["LICENSE"]
|
||||
@@ -61,17 +61,16 @@ dependencies = [
|
||||
"html2text>=2024.2.26",
|
||||
"langchain>=0.2.0",
|
||||
"langchain-text-splitters>=0.0.1",
|
||||
"chromadb>=1.0.0,<2.0.0",
|
||||
"chromadb>=0.4.24",
|
||||
"qdrant-client (>=1.15.1,<2.0.0)",
|
||||
"pyseekdb==1.1.0.post3",
|
||||
"langbot-plugin==0.3.1",
|
||||
"pyseekdb==1.0.0b7",
|
||||
"langbot-plugin==0.2.6",
|
||||
"asyncpg>=0.30.0",
|
||||
"line-bot-sdk>=3.19.0",
|
||||
"tboxsdk>=0.0.10",
|
||||
"boto3>=1.35.0",
|
||||
"pymilvus>=2.6.4",
|
||||
"pgvector>=0.4.1",
|
||||
"botocore>=1.42.39",
|
||||
]
|
||||
keywords = [
|
||||
"bot",
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
"""LangBot - Production-grade platform for building agentic IM bots"""
|
||||
|
||||
__version__ = '4.9.2'
|
||||
__version__ = '4.8.4'
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import requests
|
||||
from langbot.pkg.utils import httpclient
|
||||
import aiohttp
|
||||
|
||||
|
||||
def post_json(base_url, token, data=None):
|
||||
@@ -63,16 +63,16 @@ async def async_request(
|
||||
"""
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
url = f'{base_url}?key={token_key}'
|
||||
session = httpclient.get_session()
|
||||
async with session.request(
|
||||
method=method, url=url, params=params, headers=headers, data=data, json=json
|
||||
) as response:
|
||||
response.raise_for_status() # 如果状态码不是200,抛出异常
|
||||
result = await response.json()
|
||||
# print(result)
|
||||
return result
|
||||
# if result.get('Code') == 200:
|
||||
#
|
||||
# return await result
|
||||
# else:
|
||||
# raise RuntimeError("请求失败",response.text)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.request(
|
||||
method=method, url=url, params=params, headers=headers, data=data, json=json
|
||||
) as response:
|
||||
response.raise_for_status() # 如果状态码不是200,抛出异常
|
||||
result = await response.json()
|
||||
# print(result)
|
||||
return result
|
||||
# if result.get('Code') == 200:
|
||||
#
|
||||
# return await result
|
||||
# else:
|
||||
# raise RuntimeError("请求失败",response.text)
|
||||
|
||||
@@ -199,253 +199,6 @@ class StreamSessionManager:
|
||||
self._msg_index.pop(msg_id, None)
|
||||
|
||||
|
||||
async def download_encrypted_file(download_url: str, encoding_aes_key: str, logger: EventLogger) -> Optional[str]:
|
||||
"""Download an AES-encrypted file from WeChat Work and return as data URI.
|
||||
|
||||
Args:
|
||||
download_url: The encrypted file download URL.
|
||||
encoding_aes_key: The AES key used for decryption (base64-encoded, without trailing '=').
|
||||
logger: Logger instance.
|
||||
|
||||
Returns:
|
||||
A data URI string (e.g. 'data:image/jpeg;base64,...') or None on failure.
|
||||
"""
|
||||
if not download_url:
|
||||
return None
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(download_url)
|
||||
if response.status_code != 200:
|
||||
await logger.error(f'failed to get file: {response.text}')
|
||||
return None
|
||||
encrypted_bytes = response.content
|
||||
|
||||
aes_key = base64.b64decode(encoding_aes_key + '=')
|
||||
iv = aes_key[:16]
|
||||
|
||||
cipher = AES.new(aes_key, AES.MODE_CBC, iv)
|
||||
decrypted = cipher.decrypt(encrypted_bytes)
|
||||
|
||||
pad_len = decrypted[-1]
|
||||
decrypted = decrypted[:-pad_len]
|
||||
|
||||
if decrypted.startswith(b'\xff\xd8'):
|
||||
mime_type = 'image/jpeg'
|
||||
elif decrypted.startswith(b'\x89PNG'):
|
||||
mime_type = 'image/png'
|
||||
elif decrypted.startswith((b'GIF87a', b'GIF89a')):
|
||||
mime_type = 'image/gif'
|
||||
elif decrypted.startswith(b'BM'):
|
||||
mime_type = 'image/bmp'
|
||||
elif decrypted.startswith(b'II*\x00') or decrypted.startswith(b'MM\x00*'):
|
||||
mime_type = 'image/tiff'
|
||||
else:
|
||||
mime_type = 'application/octet-stream'
|
||||
|
||||
base64_str = base64.b64encode(decrypted).decode('utf-8')
|
||||
return f'data:{mime_type};base64,{base64_str}'
|
||||
|
||||
|
||||
async def parse_wecom_bot_message(
|
||||
msg_json: dict[str, Any], encoding_aes_key: str, logger: EventLogger
|
||||
) -> dict[str, Any]:
|
||||
"""Parse a decrypted WeChat Work AI Bot message JSON into a unified message dict.
|
||||
|
||||
This is the shared message parsing logic used by both webhook and WebSocket modes.
|
||||
|
||||
Args:
|
||||
msg_json: The decrypted message JSON from WeChat Work.
|
||||
encoding_aes_key: AES key for file decryption.
|
||||
logger: Logger instance.
|
||||
|
||||
Returns:
|
||||
A dict suitable for constructing a WecomBotEvent.
|
||||
"""
|
||||
message_data: dict[str, Any] = {}
|
||||
|
||||
msg_type = msg_json.get('msgtype', '')
|
||||
if msg_type:
|
||||
message_data['msgtype'] = msg_type
|
||||
|
||||
if msg_json.get('chattype', '') == 'single':
|
||||
message_data['type'] = 'single'
|
||||
elif msg_json.get('chattype', '') == 'group':
|
||||
message_data['type'] = 'group'
|
||||
|
||||
max_inline_file_size = 5 * 1024 * 1024
|
||||
|
||||
async def _safe_download(url: str):
|
||||
if not url:
|
||||
return None
|
||||
return await download_encrypted_file(url, encoding_aes_key, logger)
|
||||
|
||||
if msg_type == 'text':
|
||||
message_data['content'] = msg_json.get('text', {}).get('content')
|
||||
elif msg_type == 'markdown':
|
||||
message_data['content'] = msg_json.get('markdown', {}).get('content') or msg_json.get('text', {}).get(
|
||||
'content', ''
|
||||
)
|
||||
elif msg_type == 'image':
|
||||
picurl = msg_json.get('image', {}).get('url', '')
|
||||
base64_data = await _safe_download(picurl)
|
||||
if base64_data:
|
||||
message_data['picurl'] = base64_data
|
||||
message_data['images'] = [base64_data]
|
||||
elif msg_type == 'voice':
|
||||
voice_info = msg_json.get('voice', {}) or {}
|
||||
download_url = voice_info.get('url')
|
||||
message_data['voice'] = {
|
||||
'url': download_url,
|
||||
'md5sum': voice_info.get('md5sum') or voice_info.get('md5'),
|
||||
'filesize': voice_info.get('filesize') or voice_info.get('size'),
|
||||
'sdkfileid': voice_info.get('sdkfileid') or voice_info.get('fileid'),
|
||||
}
|
||||
if voice_info.get('content'):
|
||||
message_data['content'] = voice_info.get('content')
|
||||
if (message_data['voice'].get('filesize') or 0) <= max_inline_file_size:
|
||||
voice_base64 = await _safe_download(download_url)
|
||||
if voice_base64:
|
||||
message_data['voice']['base64'] = voice_base64
|
||||
elif msg_type == 'video':
|
||||
video_info = msg_json.get('video', {}) or {}
|
||||
download_url = video_info.get('url')
|
||||
video_data = {
|
||||
'url': download_url,
|
||||
'filesize': video_info.get('filesize') or video_info.get('size'),
|
||||
'sdkfileid': video_info.get('sdkfileid') or video_info.get('fileid'),
|
||||
'md5sum': video_info.get('md5sum') or video_info.get('md5'),
|
||||
'filename': video_info.get('filename') or video_info.get('name'),
|
||||
}
|
||||
if (video_data.get('filesize') or 0) <= max_inline_file_size:
|
||||
video_base64 = await _safe_download(download_url)
|
||||
if video_base64:
|
||||
video_data['base64'] = video_base64
|
||||
message_data['video'] = video_data
|
||||
elif msg_type == 'file':
|
||||
file_info = msg_json.get('file', {}) or {}
|
||||
download_url = file_info.get('url') or file_info.get('fileurl')
|
||||
file_data = {
|
||||
'filename': file_info.get('filename') or file_info.get('name'),
|
||||
'filesize': file_info.get('filesize') or file_info.get('size'),
|
||||
'md5sum': file_info.get('md5sum') or file_info.get('md5'),
|
||||
'sdkfileid': file_info.get('sdkfileid') or file_info.get('fileid'),
|
||||
'download_url': download_url,
|
||||
'extra': file_info,
|
||||
}
|
||||
if (file_data.get('filesize') or 0) <= max_inline_file_size:
|
||||
file_base64 = await _safe_download(download_url)
|
||||
if file_base64:
|
||||
file_data['base64'] = file_base64
|
||||
message_data['file'] = file_data
|
||||
elif msg_type == 'link':
|
||||
message_data['link'] = msg_json.get('link', {})
|
||||
if not message_data.get('content'):
|
||||
title = message_data['link'].get('title', '')
|
||||
desc = message_data['link'].get('description') or message_data['link'].get('digest', '')
|
||||
message_data['content'] = '\n'.join(filter(None, [title, desc]))
|
||||
elif msg_type == 'mixed':
|
||||
items = msg_json.get('mixed', {}).get('msg_item', [])
|
||||
texts = []
|
||||
images = []
|
||||
files = []
|
||||
voices = []
|
||||
videos = []
|
||||
links = []
|
||||
for item in items:
|
||||
item_type = item.get('msgtype')
|
||||
if item_type == 'text':
|
||||
texts.append(item.get('text', {}).get('content', ''))
|
||||
elif item_type == 'image':
|
||||
img_url = item.get('image', {}).get('url')
|
||||
base64_data = await _safe_download(img_url)
|
||||
if base64_data:
|
||||
images.append(base64_data)
|
||||
elif item_type == 'file':
|
||||
file_info = item.get('file', {}) or {}
|
||||
download_url = file_info.get('url') or file_info.get('fileurl')
|
||||
file_data = {
|
||||
'filename': file_info.get('filename') or file_info.get('name'),
|
||||
'filesize': file_info.get('filesize') or file_info.get('size'),
|
||||
'md5sum': file_info.get('md5sum') or file_info.get('md5'),
|
||||
'sdkfileid': file_info.get('sdkfileid') or file_info.get('fileid'),
|
||||
'download_url': download_url,
|
||||
'extra': file_info,
|
||||
}
|
||||
if (file_data.get('filesize') or 0) <= max_inline_file_size:
|
||||
file_base64 = await _safe_download(download_url)
|
||||
if file_base64:
|
||||
file_data['base64'] = file_base64
|
||||
files.append(file_data)
|
||||
elif item_type == 'voice':
|
||||
voice_info = item.get('voice', {}) or {}
|
||||
download_url = voice_info.get('url')
|
||||
voice_data = {
|
||||
'url': download_url,
|
||||
'md5sum': voice_info.get('md5sum') or voice_info.get('md5'),
|
||||
'filesize': voice_info.get('filesize') or voice_info.get('size'),
|
||||
'sdkfileid': voice_info.get('sdkfileid') or voice_info.get('fileid'),
|
||||
}
|
||||
if voice_info.get('content'):
|
||||
texts.append(voice_info.get('content'))
|
||||
if (voice_data.get('filesize') or 0) <= max_inline_file_size:
|
||||
voice_base64 = await _safe_download(download_url)
|
||||
if voice_base64:
|
||||
voice_data['base64'] = voice_base64
|
||||
voices.append(voice_data)
|
||||
elif item_type == 'video':
|
||||
video_info = item.get('video', {}) or {}
|
||||
download_url = video_info.get('url')
|
||||
video_data = {
|
||||
'url': download_url,
|
||||
'filesize': video_info.get('filesize') or video_info.get('size'),
|
||||
'sdkfileid': video_info.get('sdkfileid') or video_info.get('fileid'),
|
||||
'md5sum': video_info.get('md5sum') or video_info.get('md5'),
|
||||
'filename': video_info.get('filename') or video_info.get('name'),
|
||||
}
|
||||
if (video_data.get('filesize') or 0) <= max_inline_file_size:
|
||||
video_base64 = await _safe_download(download_url)
|
||||
if video_base64:
|
||||
video_data['base64'] = video_base64
|
||||
videos.append(video_data)
|
||||
elif item_type == 'link':
|
||||
links.append(item.get('link', {}))
|
||||
|
||||
if texts:
|
||||
message_data['content'] = ' '.join(texts)
|
||||
if images:
|
||||
message_data['images'] = images
|
||||
message_data['picurl'] = images[0]
|
||||
if files:
|
||||
message_data['files'] = files
|
||||
message_data['file'] = files[0]
|
||||
if voices:
|
||||
message_data['voices'] = voices
|
||||
message_data['voice'] = voices[0]
|
||||
if videos:
|
||||
message_data['videos'] = videos
|
||||
message_data['video'] = videos[0]
|
||||
if links:
|
||||
message_data['link'] = links[0]
|
||||
if items:
|
||||
message_data['attachments'] = items
|
||||
else:
|
||||
message_data['raw_msg'] = msg_json
|
||||
|
||||
from_info = msg_json.get('from', {})
|
||||
message_data['userid'] = from_info.get('userid', '')
|
||||
message_data['username'] = from_info.get('alias', '') or from_info.get('name', '') or from_info.get('userid', '')
|
||||
|
||||
if msg_json.get('chattype', '') == 'group':
|
||||
message_data['chatid'] = msg_json.get('chatid', '')
|
||||
message_data['chatname'] = msg_json.get('chatname', '') or msg_json.get('chatid', '')
|
||||
|
||||
message_data['msgid'] = msg_json.get('msgid', '')
|
||||
|
||||
if msg_json.get('aibotid'):
|
||||
message_data['aibotid'] = msg_json.get('aibotid', '')
|
||||
|
||||
return message_data
|
||||
|
||||
|
||||
class WecomBotClient:
|
||||
def __init__(self, Token: str, EnCodingAESKey: str, Corpid: str, logger: EventLogger, unified_mode: bool = False):
|
||||
"""企业微信智能机器人客户端。
|
||||
@@ -702,7 +455,196 @@ class WecomBotClient:
|
||||
return await self._handle_post_initial_response(msg_json, nonce)
|
||||
|
||||
async def get_message(self, msg_json):
|
||||
return await parse_wecom_bot_message(msg_json, self.EnCodingAESKey, self.logger)
|
||||
message_data = {}
|
||||
|
||||
msg_type = msg_json.get('msgtype', '')
|
||||
if msg_type:
|
||||
message_data['msgtype'] = msg_type
|
||||
|
||||
if msg_json.get('chattype', '') == 'single':
|
||||
message_data['type'] = 'single'
|
||||
elif msg_json.get('chattype', '') == 'group':
|
||||
message_data['type'] = 'group'
|
||||
|
||||
max_inline_file_size = 5 * 1024 * 1024 # avoid decoding very large payloads by default
|
||||
|
||||
async def _safe_download(url: str):
|
||||
if not url:
|
||||
return None
|
||||
return await self.download_url_to_base64(url, self.EnCodingAESKey)
|
||||
|
||||
if msg_type == 'text':
|
||||
message_data['content'] = msg_json.get('text', {}).get('content')
|
||||
elif msg_type == 'markdown':
|
||||
message_data['content'] = msg_json.get('markdown', {}).get('content') or msg_json.get('text', {}).get(
|
||||
'content', ''
|
||||
)
|
||||
elif msg_type == 'image':
|
||||
picurl = msg_json.get('image', {}).get('url', '')
|
||||
base64_data = await _safe_download(picurl)
|
||||
if base64_data:
|
||||
message_data['picurl'] = base64_data
|
||||
message_data['images'] = [base64_data]
|
||||
elif msg_type == 'voice':
|
||||
voice_info = msg_json.get('voice', {}) or {}
|
||||
download_url = voice_info.get('url')
|
||||
message_data['voice'] = {
|
||||
'url': download_url,
|
||||
'md5sum': voice_info.get('md5sum') or voice_info.get('md5'),
|
||||
'filesize': voice_info.get('filesize') or voice_info.get('size'),
|
||||
'sdkfileid': voice_info.get('sdkfileid') or voice_info.get('fileid'),
|
||||
}
|
||||
# 企业微信智能转写文本(如果已有)直接复用,避免重复转写
|
||||
if voice_info.get('content'):
|
||||
message_data['content'] = voice_info.get('content')
|
||||
if (message_data['voice'].get('filesize') or 0) <= max_inline_file_size:
|
||||
voice_base64 = await _safe_download(download_url)
|
||||
if voice_base64:
|
||||
message_data['voice']['base64'] = voice_base64
|
||||
elif msg_type == 'video':
|
||||
video_info = msg_json.get('video', {}) or {}
|
||||
download_url = video_info.get('url')
|
||||
video_data = {
|
||||
'url': download_url,
|
||||
'filesize': video_info.get('filesize') or video_info.get('size'),
|
||||
'sdkfileid': video_info.get('sdkfileid') or video_info.get('fileid'),
|
||||
'md5sum': video_info.get('md5sum') or video_info.get('md5'),
|
||||
'filename': video_info.get('filename') or video_info.get('name'),
|
||||
}
|
||||
if (video_data.get('filesize') or 0) <= max_inline_file_size:
|
||||
video_base64 = await _safe_download(download_url)
|
||||
if video_base64:
|
||||
video_data['base64'] = video_base64
|
||||
message_data['video'] = video_data
|
||||
elif msg_type == 'file':
|
||||
file_info = msg_json.get('file', {}) or {}
|
||||
download_url = file_info.get('url') or file_info.get('fileurl')
|
||||
file_data = {
|
||||
'filename': file_info.get('filename') or file_info.get('name'),
|
||||
'filesize': file_info.get('filesize') or file_info.get('size'),
|
||||
'md5sum': file_info.get('md5sum') or file_info.get('md5'),
|
||||
'sdkfileid': file_info.get('sdkfileid') or file_info.get('fileid'),
|
||||
'download_url': download_url,
|
||||
'extra': file_info,
|
||||
}
|
||||
if (file_data.get('filesize') or 0) <= max_inline_file_size:
|
||||
file_base64 = await _safe_download(download_url)
|
||||
if file_base64:
|
||||
file_data['base64'] = file_base64
|
||||
message_data['file'] = file_data
|
||||
elif msg_type == 'link':
|
||||
message_data['link'] = msg_json.get('link', {})
|
||||
if not message_data.get('content'):
|
||||
title = message_data['link'].get('title', '')
|
||||
desc = message_data['link'].get('description') or message_data['link'].get('digest', '')
|
||||
message_data['content'] = '\n'.join(filter(None, [title, desc]))
|
||||
elif msg_type == 'mixed':
|
||||
items = msg_json.get('mixed', {}).get('msg_item', [])
|
||||
texts = []
|
||||
images = []
|
||||
files = []
|
||||
voices = []
|
||||
videos = []
|
||||
links = []
|
||||
for item in items:
|
||||
item_type = item.get('msgtype')
|
||||
if item_type == 'text':
|
||||
texts.append(item.get('text', {}).get('content', ''))
|
||||
elif item_type == 'image':
|
||||
img_url = item.get('image', {}).get('url')
|
||||
base64_data = await _safe_download(img_url)
|
||||
if base64_data:
|
||||
images.append(base64_data)
|
||||
elif item_type == 'file':
|
||||
file_info = item.get('file', {}) or {}
|
||||
download_url = file_info.get('url') or file_info.get('fileurl')
|
||||
file_data = {
|
||||
'filename': file_info.get('filename') or file_info.get('name'),
|
||||
'filesize': file_info.get('filesize') or file_info.get('size'),
|
||||
'md5sum': file_info.get('md5sum') or file_info.get('md5'),
|
||||
'sdkfileid': file_info.get('sdkfileid') or file_info.get('fileid'),
|
||||
'download_url': download_url,
|
||||
'extra': file_info,
|
||||
}
|
||||
if (file_data.get('filesize') or 0) <= max_inline_file_size:
|
||||
file_base64 = await _safe_download(download_url)
|
||||
if file_base64:
|
||||
file_data['base64'] = file_base64
|
||||
files.append(file_data)
|
||||
elif item_type == 'voice':
|
||||
voice_info = item.get('voice', {}) or {}
|
||||
download_url = voice_info.get('url')
|
||||
voice_data = {
|
||||
'url': download_url,
|
||||
'md5sum': voice_info.get('md5sum') or voice_info.get('md5'),
|
||||
'filesize': voice_info.get('filesize') or voice_info.get('size'),
|
||||
'sdkfileid': voice_info.get('sdkfileid') or voice_info.get('fileid'),
|
||||
}
|
||||
if voice_info.get('content'):
|
||||
texts.append(voice_info.get('content'))
|
||||
if (voice_data.get('filesize') or 0) <= max_inline_file_size:
|
||||
voice_base64 = await _safe_download(download_url)
|
||||
if voice_base64:
|
||||
voice_data['base64'] = voice_base64
|
||||
voices.append(voice_data)
|
||||
elif item_type == 'video':
|
||||
video_info = item.get('video', {}) or {}
|
||||
download_url = video_info.get('url')
|
||||
video_data = {
|
||||
'url': download_url,
|
||||
'filesize': video_info.get('filesize') or video_info.get('size'),
|
||||
'sdkfileid': video_info.get('sdkfileid') or video_info.get('fileid'),
|
||||
'md5sum': video_info.get('md5sum') or video_info.get('md5'),
|
||||
'filename': video_info.get('filename') or video_info.get('name'),
|
||||
}
|
||||
if (video_data.get('filesize') or 0) <= max_inline_file_size:
|
||||
video_base64 = await _safe_download(download_url)
|
||||
if video_base64:
|
||||
video_data['base64'] = video_base64
|
||||
videos.append(video_data)
|
||||
elif item_type == 'link':
|
||||
links.append(item.get('link', {}))
|
||||
|
||||
if texts:
|
||||
message_data['content'] = ' '.join(texts) # 拼接所有 text
|
||||
if images:
|
||||
message_data['images'] = images
|
||||
message_data['picurl'] = images[0] # 只保留第一个 image
|
||||
if files:
|
||||
message_data['files'] = files
|
||||
message_data['file'] = files[0]
|
||||
if voices:
|
||||
message_data['voices'] = voices
|
||||
message_data['voice'] = voices[0]
|
||||
if videos:
|
||||
message_data['videos'] = videos
|
||||
message_data['video'] = videos[0]
|
||||
if links:
|
||||
message_data['link'] = links[0]
|
||||
if items:
|
||||
message_data['attachments'] = items
|
||||
else:
|
||||
message_data['raw_msg'] = msg_json
|
||||
|
||||
# Extract user information
|
||||
from_info = msg_json.get('from', {})
|
||||
message_data['userid'] = from_info.get('userid', '')
|
||||
message_data['username'] = (
|
||||
from_info.get('alias', '') or from_info.get('name', '') or from_info.get('userid', '')
|
||||
)
|
||||
|
||||
# Extract chat/group information
|
||||
if msg_json.get('chattype', '') == 'group':
|
||||
message_data['chatid'] = msg_json.get('chatid', '')
|
||||
# Try to get group name if available
|
||||
message_data['chatname'] = msg_json.get('chatname', '') or msg_json.get('chatid', '')
|
||||
|
||||
message_data['msgid'] = msg_json.get('msgid', '')
|
||||
|
||||
if msg_json.get('aibotid'):
|
||||
message_data['aibotid'] = msg_json.get('aibotid', '')
|
||||
|
||||
return message_data
|
||||
|
||||
async def _handle_message(self, event: wecombotevent.WecomBotEvent):
|
||||
"""
|
||||
@@ -770,7 +712,39 @@ class WecomBotClient:
|
||||
return decorator
|
||||
|
||||
async def download_url_to_base64(self, download_url, encoding_aes_key):
|
||||
return await download_encrypted_file(download_url, encoding_aes_key, self.logger)
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(download_url)
|
||||
if response.status_code != 200:
|
||||
await self.logger.error(f'failed to get file: {response.text}')
|
||||
return None
|
||||
|
||||
encrypted_bytes = response.content
|
||||
|
||||
aes_key = base64.b64decode(encoding_aes_key + '=') # base64 补齐
|
||||
iv = aes_key[:16]
|
||||
|
||||
cipher = AES.new(aes_key, AES.MODE_CBC, iv)
|
||||
decrypted = cipher.decrypt(encrypted_bytes)
|
||||
|
||||
pad_len = decrypted[-1]
|
||||
decrypted = decrypted[:-pad_len]
|
||||
|
||||
if decrypted.startswith(b'\xff\xd8'): # JPEG
|
||||
mime_type = 'image/jpeg'
|
||||
elif decrypted.startswith(b'\x89PNG'): # PNG
|
||||
mime_type = 'image/png'
|
||||
elif decrypted.startswith((b'GIF87a', b'GIF89a')): # GIF
|
||||
mime_type = 'image/gif'
|
||||
elif decrypted.startswith(b'BM'): # BMP
|
||||
mime_type = 'image/bmp'
|
||||
elif decrypted.startswith(b'II*\x00') or decrypted.startswith(b'MM\x00*'): # TIFF
|
||||
mime_type = 'image/tiff'
|
||||
else:
|
||||
mime_type = 'application/octet-stream'
|
||||
|
||||
# 转 base64
|
||||
base64_str = base64.b64encode(decrypted).decode('utf-8')
|
||||
return f'data:{mime_type};base64,{base64_str}'
|
||||
|
||||
async def run_task(self, host: str, port: int, *args, **kwargs):
|
||||
"""
|
||||
|
||||
@@ -1,596 +0,0 @@
|
||||
"""WeChat Work AI Bot WebSocket long connection client.
|
||||
|
||||
Implements the WebSocket protocol for receiving messages and sending replies
|
||||
via a persistent connection to wss://openws.work.weixin.qq.com, as an
|
||||
alternative to the HTTP callback (webhook) mode.
|
||||
|
||||
Protocol reference: https://developer.work.weixin.qq.com/document/path/101463
|
||||
Official Node.js SDK: https://github.com/WecomTeam/aibot-node-sdk
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import secrets
|
||||
import time
|
||||
import traceback
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
from langbot.libs.wecom_ai_bot_api import wecombotevent
|
||||
from langbot.libs.wecom_ai_bot_api.api import parse_wecom_bot_message
|
||||
from langbot.pkg.platform.logger import EventLogger
|
||||
|
||||
DEFAULT_WS_URL = 'wss://openws.work.weixin.qq.com'
|
||||
|
||||
# WebSocket frame command constants
|
||||
CMD_SUBSCRIBE = 'aibot_subscribe'
|
||||
CMD_HEARTBEAT = 'ping'
|
||||
CMD_MSG_CALLBACK = 'aibot_msg_callback'
|
||||
CMD_EVENT_CALLBACK = 'aibot_event_callback'
|
||||
CMD_RESPOND_MSG = 'aibot_respond_msg'
|
||||
CMD_RESPOND_WELCOME = 'aibot_respond_welcome_msg'
|
||||
CMD_RESPOND_UPDATE = 'aibot_respond_update_msg'
|
||||
CMD_SEND_MSG = 'aibot_send_msg'
|
||||
|
||||
|
||||
def _generate_req_id(prefix: str) -> str:
|
||||
"""Generate a unique request ID in the format: {prefix}_{timestamp}_{random}."""
|
||||
ts = int(time.time() * 1000)
|
||||
rand = secrets.token_hex(4)
|
||||
return f'{prefix}_{ts}_{rand}'
|
||||
|
||||
|
||||
class WecomBotWsClient:
|
||||
"""WeChat Work AI Bot WebSocket long connection client.
|
||||
|
||||
Provides message receiving, streaming reply, proactive message sending,
|
||||
and event callback handling over a persistent WebSocket connection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bot_id: str,
|
||||
secret: str,
|
||||
logger: EventLogger,
|
||||
encoding_aes_key: str = '',
|
||||
ws_url: str = DEFAULT_WS_URL,
|
||||
heartbeat_interval: float = 30.0,
|
||||
max_reconnect_attempts: int = -1,
|
||||
reconnect_base_delay: float = 1.0,
|
||||
reconnect_max_delay: float = 30.0,
|
||||
):
|
||||
self.bot_id = bot_id
|
||||
self.secret = secret
|
||||
self.logger = logger
|
||||
self.encoding_aes_key = encoding_aes_key
|
||||
self.ws_url = ws_url
|
||||
self.heartbeat_interval = heartbeat_interval
|
||||
self.max_reconnect_attempts = max_reconnect_attempts
|
||||
self.reconnect_base_delay = reconnect_base_delay
|
||||
self.reconnect_max_delay = reconnect_max_delay
|
||||
|
||||
self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self._running = False
|
||||
self._heartbeat_task: Optional[asyncio.Task] = None
|
||||
self._missed_pong_count = 0
|
||||
self._max_missed_pong = 2
|
||||
self._reconnect_attempts = 0
|
||||
|
||||
# Message handler registry (same pattern as WecomBotClient)
|
||||
self._message_handlers: dict[str, list[Callable]] = {}
|
||||
# Message deduplication
|
||||
self._msg_id_map: dict[str, int] = {}
|
||||
|
||||
# Pending ACK futures: req_id -> Future[dict]
|
||||
self._pending_acks: dict[str, asyncio.Future] = {}
|
||||
# Per-req_id serial reply queues
|
||||
self._reply_queues: dict[str, asyncio.Queue] = {}
|
||||
self._reply_workers: dict[str, asyncio.Task] = {}
|
||||
self._reply_ack_timeout = 5.0
|
||||
|
||||
# Stream ID tracking for WebSocket mode
|
||||
self._stream_ids: dict[str, str] = {} # msg_id -> req_id|stream_id
|
||||
# Dedup: skip sending when content hasn't changed
|
||||
self._stream_last_content: dict[str, str] = {} # msg_id -> last content sent
|
||||
|
||||
# ── Public API ──────────────────────────────────────────────────
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to WebSocket server with automatic reconnection.
|
||||
|
||||
This method blocks until disconnect() is called or max reconnect
|
||||
attempts are exhausted.
|
||||
"""
|
||||
self._running = True
|
||||
self._reconnect_attempts = 0
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
await self._connect_once()
|
||||
except Exception:
|
||||
if not self._running:
|
||||
break
|
||||
await self.logger.error(f'WebSocket connection error: {traceback.format_exc()}')
|
||||
|
||||
if not self._running:
|
||||
break
|
||||
|
||||
# Reconnect with exponential backoff
|
||||
if self.max_reconnect_attempts != -1 and self._reconnect_attempts >= self.max_reconnect_attempts:
|
||||
await self.logger.error(f'Max reconnect attempts reached ({self.max_reconnect_attempts}), giving up')
|
||||
break
|
||||
|
||||
self._reconnect_attempts += 1
|
||||
delay = min(
|
||||
self.reconnect_base_delay * (2 ** (self._reconnect_attempts - 1)),
|
||||
self.reconnect_max_delay,
|
||||
)
|
||||
await self.logger.info(f'Reconnecting in {delay:.1f}s (attempt {self._reconnect_attempts})...')
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
async def disconnect(self):
|
||||
"""Gracefully disconnect from the WebSocket server."""
|
||||
self._running = False
|
||||
if self._heartbeat_task and not self._heartbeat_task.done():
|
||||
self._heartbeat_task.cancel()
|
||||
for task in self._reply_workers.values():
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
if self._ws and not self._ws.closed:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
def on_message(self, msg_type: str) -> Callable:
|
||||
"""Decorator to register a message handler.
|
||||
|
||||
Same interface as WecomBotClient.on_message for compatibility.
|
||||
|
||||
Args:
|
||||
msg_type: 'single', 'group', or specific message type.
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[[wecombotevent.WecomBotEvent], Any]):
|
||||
if msg_type not in self._message_handlers:
|
||||
self._message_handlers[msg_type] = []
|
||||
self._message_handlers[msg_type].append(func)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
async def reply_stream(
|
||||
self,
|
||||
req_id: str,
|
||||
stream_id: str,
|
||||
content: str,
|
||||
finish: bool = False,
|
||||
) -> Optional[dict]:
|
||||
"""Send a streaming reply frame.
|
||||
|
||||
Args:
|
||||
req_id: The req_id from the original message frame (must be passed through).
|
||||
stream_id: The stream ID for this streaming session.
|
||||
content: The content to send (supports Markdown).
|
||||
finish: Whether this is the final chunk.
|
||||
|
||||
Returns:
|
||||
The ACK frame dict, or None on failure.
|
||||
"""
|
||||
body = {
|
||||
'msgtype': 'stream',
|
||||
'stream': {
|
||||
'id': stream_id,
|
||||
'finish': finish,
|
||||
'content': content,
|
||||
},
|
||||
}
|
||||
return await self._send_reply(req_id, body)
|
||||
|
||||
async def reply_text(self, req_id: str, content: str) -> Optional[dict]:
|
||||
"""Send a non-streaming text reply.
|
||||
|
||||
Args:
|
||||
req_id: The req_id from the original message frame.
|
||||
content: The text content to reply.
|
||||
|
||||
Returns:
|
||||
The ACK frame dict, or None on failure.
|
||||
"""
|
||||
body = {
|
||||
'msgtype': 'markdown',
|
||||
'markdown': {
|
||||
'content': content,
|
||||
},
|
||||
}
|
||||
return await self._send_reply(req_id, body)
|
||||
|
||||
async def send_message(self, chat_id: str, content: str, msgtype: str = 'markdown') -> Optional[dict]:
|
||||
"""Proactively send a message to a specified chat.
|
||||
|
||||
Args:
|
||||
chat_id: The chat ID (userid for single chat, chatid for group chat).
|
||||
content: The message content.
|
||||
msgtype: Message type, 'markdown' by default.
|
||||
|
||||
Returns:
|
||||
The ACK frame dict, or None on failure.
|
||||
"""
|
||||
req_id = _generate_req_id(CMD_SEND_MSG)
|
||||
body: dict[str, Any] = {
|
||||
'chatid': chat_id,
|
||||
'msgtype': msgtype,
|
||||
}
|
||||
if msgtype == 'markdown':
|
||||
body['markdown'] = {'content': content}
|
||||
elif msgtype == 'text':
|
||||
body['text'] = {'content': content}
|
||||
return await self._send_reply(req_id, body, cmd=CMD_SEND_MSG)
|
||||
|
||||
async def push_stream_chunk(self, msg_id: str, content: str, is_final: bool = False) -> bool:
|
||||
"""Push a streaming chunk for a given message ID.
|
||||
|
||||
Compatible interface with WecomBotClient.push_stream_chunk.
|
||||
|
||||
Args:
|
||||
msg_id: The original message ID.
|
||||
content: The cumulative content from the pipeline.
|
||||
is_final: Whether this is the final chunk.
|
||||
|
||||
Returns:
|
||||
True if the stream session exists and chunk was sent.
|
||||
"""
|
||||
key = self._stream_ids.get(msg_id)
|
||||
if not key:
|
||||
return False
|
||||
req_id, stream_id = key.split('|', 1)
|
||||
try:
|
||||
# Skip sending if content hasn't changed (e.g. during tool call argument streaming)
|
||||
if not is_final and content == self._stream_last_content.get(msg_id):
|
||||
return True
|
||||
await self.reply_stream(req_id, stream_id, content, finish=is_final)
|
||||
self._stream_last_content[msg_id] = content
|
||||
if is_final:
|
||||
self._stream_ids.pop(msg_id, None)
|
||||
self._stream_last_content.pop(msg_id, None)
|
||||
return True
|
||||
except Exception:
|
||||
await self.logger.error(f'Failed to push stream chunk: {traceback.format_exc()}')
|
||||
return False
|
||||
|
||||
async def set_message(self, msg_id: str, content: str):
|
||||
"""Fallback: send content as a final stream chunk or direct reply.
|
||||
|
||||
Compatible interface with WecomBotClient.set_message.
|
||||
"""
|
||||
handled = await self.push_stream_chunk(msg_id, content, is_final=True)
|
||||
if not handled:
|
||||
await self.logger.warning(f'No active stream for msg_id={msg_id}, message dropped')
|
||||
|
||||
# ── Connection lifecycle ────────────────────────────────────────
|
||||
|
||||
async def _connect_once(self):
|
||||
"""Establish a single WebSocket connection, authenticate, and listen."""
|
||||
await self.logger.info(f'Connecting to {self.ws_url}...')
|
||||
|
||||
self._session = aiohttp.ClientSession()
|
||||
try:
|
||||
self._ws = await self._session.ws_connect(self.ws_url)
|
||||
self._missed_pong_count = 0
|
||||
self._reconnect_attempts = 0
|
||||
await self.logger.info('WebSocket connected, sending auth...')
|
||||
|
||||
await self._send_auth()
|
||||
|
||||
# Wait for auth response
|
||||
auth_ok = await self._wait_for_auth()
|
||||
if not auth_ok:
|
||||
await self.logger.error('Authentication failed')
|
||||
return
|
||||
|
||||
await self.logger.info('Authenticated successfully')
|
||||
|
||||
# Start heartbeat
|
||||
self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
|
||||
|
||||
try:
|
||||
await self._listen_loop()
|
||||
finally:
|
||||
if self._heartbeat_task and not self._heartbeat_task.done():
|
||||
self._heartbeat_task.cancel()
|
||||
self._clear_pending_acks('Connection closed')
|
||||
finally:
|
||||
if self._ws and not self._ws.closed:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
async def _send_auth(self):
|
||||
"""Send the authentication frame."""
|
||||
frame = {
|
||||
'cmd': CMD_SUBSCRIBE,
|
||||
'headers': {'req_id': _generate_req_id(CMD_SUBSCRIBE)},
|
||||
'body': {
|
||||
'bot_id': self.bot_id,
|
||||
'secret': self.secret,
|
||||
},
|
||||
}
|
||||
await self._send_frame(frame)
|
||||
|
||||
async def _wait_for_auth(self) -> bool:
|
||||
"""Wait for and validate the authentication response."""
|
||||
try:
|
||||
msg = await asyncio.wait_for(self._ws.receive(), timeout=10.0)
|
||||
if msg.type in (aiohttp.WSMsgType.TEXT,):
|
||||
frame = json.loads(msg.data)
|
||||
req_id = frame.get('headers', {}).get('req_id', '')
|
||||
if req_id.startswith(CMD_SUBSCRIBE) and frame.get('errcode') == 0:
|
||||
return True
|
||||
await self.logger.error(f'Auth response: errcode={frame.get("errcode")}, errmsg={frame.get("errmsg")}')
|
||||
return False
|
||||
elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
|
||||
await self.logger.error(f'WebSocket closed during auth: {msg.type}')
|
||||
return False
|
||||
await self.logger.error(f'Unexpected message type during auth: {msg.type}')
|
||||
return False
|
||||
except asyncio.TimeoutError:
|
||||
await self.logger.error('Auth response timeout')
|
||||
return False
|
||||
|
||||
async def _heartbeat_loop(self):
|
||||
"""Periodically send heartbeat pings."""
|
||||
try:
|
||||
while self._running and self._ws and not self._ws.closed:
|
||||
await asyncio.sleep(self.heartbeat_interval)
|
||||
if not self._running or not self._ws or self._ws.closed:
|
||||
break
|
||||
|
||||
if self._missed_pong_count >= self._max_missed_pong:
|
||||
await self.logger.warning(
|
||||
f'No heartbeat ack for {self._missed_pong_count} consecutive pings, connection considered dead'
|
||||
)
|
||||
await self._ws.close()
|
||||
break
|
||||
|
||||
self._missed_pong_count += 1
|
||||
frame = {
|
||||
'cmd': CMD_HEARTBEAT,
|
||||
'headers': {'req_id': _generate_req_id(CMD_HEARTBEAT)},
|
||||
}
|
||||
try:
|
||||
await self._send_frame(frame)
|
||||
except Exception:
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def _listen_loop(self):
|
||||
"""Listen for incoming WebSocket frames and dispatch them."""
|
||||
async for msg in self._ws:
|
||||
if not self._running:
|
||||
break
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
try:
|
||||
frame = json.loads(msg.data)
|
||||
await self._handle_frame(frame)
|
||||
except json.JSONDecodeError:
|
||||
await self.logger.error(f'Failed to parse WebSocket message: {str(msg.data)[:200]}')
|
||||
except Exception:
|
||||
await self.logger.error(f'Error handling frame: {traceback.format_exc()}')
|
||||
elif msg.type == aiohttp.WSMsgType.BINARY:
|
||||
try:
|
||||
frame = json.loads(msg.data)
|
||||
await self._handle_frame(frame)
|
||||
except Exception:
|
||||
await self.logger.error(f'Error handling binary frame: {traceback.format_exc()}')
|
||||
elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
|
||||
await self.logger.warning(f'WebSocket connection closed: {msg.type}')
|
||||
break
|
||||
|
||||
# ── Frame handling ──────────────────────────────────────────────
|
||||
|
||||
async def _handle_frame(self, frame: dict):
|
||||
"""Route an incoming frame to the appropriate handler."""
|
||||
cmd = frame.get('cmd', '')
|
||||
|
||||
# Message push
|
||||
if cmd == CMD_MSG_CALLBACK:
|
||||
asyncio.create_task(self._handle_message_callback(frame))
|
||||
return
|
||||
|
||||
# Event push
|
||||
if cmd == CMD_EVENT_CALLBACK:
|
||||
asyncio.create_task(self._handle_event_callback(frame))
|
||||
return
|
||||
|
||||
# No cmd → response/ACK frame, dispatch by req_id prefix
|
||||
req_id = frame.get('headers', {}).get('req_id', '')
|
||||
|
||||
# Check pending ACKs first
|
||||
if req_id in self._pending_acks:
|
||||
future = self._pending_acks.pop(req_id)
|
||||
if not future.done():
|
||||
future.set_result(frame)
|
||||
return
|
||||
|
||||
# Heartbeat response
|
||||
if req_id.startswith(CMD_HEARTBEAT):
|
||||
if frame.get('errcode') == 0:
|
||||
self._missed_pong_count = 0
|
||||
return
|
||||
|
||||
# Unknown frame
|
||||
await self.logger.warning(f'Unknown frame: {json.dumps(frame, ensure_ascii=False)[:200]}')
|
||||
|
||||
async def _handle_message_callback(self, frame: dict):
|
||||
"""Handle an incoming message callback frame."""
|
||||
try:
|
||||
body = frame.get('body', {})
|
||||
req_id = frame.get('headers', {}).get('req_id', '')
|
||||
|
||||
# Parse message using shared logic
|
||||
message_data = await parse_wecom_bot_message(body, self.encoding_aes_key, self.logger)
|
||||
if not message_data:
|
||||
return
|
||||
|
||||
# Generate stream_id for this message and store the mapping
|
||||
stream_id = _generate_req_id('stream')
|
||||
msg_id = message_data.get('msgid', '')
|
||||
if msg_id:
|
||||
self._stream_ids[msg_id] = f'{req_id}|{stream_id}'
|
||||
message_data['stream_id'] = stream_id
|
||||
message_data['req_id'] = req_id
|
||||
|
||||
event = wecombotevent.WecomBotEvent(message_data)
|
||||
await self._dispatch_event(event)
|
||||
except Exception:
|
||||
await self.logger.error(f'Error in message callback: {traceback.format_exc()}')
|
||||
|
||||
async def _handle_event_callback(self, frame: dict):
|
||||
"""Handle an incoming event callback frame (enter_chat, template_card_event, etc.)."""
|
||||
try:
|
||||
body = frame.get('body', {})
|
||||
req_id = frame.get('headers', {}).get('req_id', '')
|
||||
|
||||
event_info = body.get('event', {})
|
||||
event_type = event_info.get('eventtype', '')
|
||||
|
||||
message_data = {
|
||||
'msgtype': 'event',
|
||||
'type': body.get('chattype', 'single'),
|
||||
'event': event_info,
|
||||
'eventtype': event_type,
|
||||
'msgid': body.get('msgid', ''),
|
||||
'aibotid': body.get('aibotid', ''),
|
||||
'req_id': req_id,
|
||||
}
|
||||
|
||||
from_info = body.get('from', {})
|
||||
message_data['userid'] = from_info.get('userid', '')
|
||||
message_data['username'] = from_info.get('alias', '') or from_info.get('userid', '')
|
||||
|
||||
if body.get('chatid'):
|
||||
message_data['chatid'] = body.get('chatid', '')
|
||||
|
||||
event = wecombotevent.WecomBotEvent(message_data)
|
||||
|
||||
# Dispatch to event-specific handlers
|
||||
if event_type in self._message_handlers:
|
||||
for handler in self._message_handlers[event_type]:
|
||||
await handler(event)
|
||||
|
||||
# Also dispatch to generic 'event' handlers
|
||||
if 'event' in self._message_handlers:
|
||||
for handler in self._message_handlers['event']:
|
||||
await handler(event)
|
||||
|
||||
except Exception:
|
||||
await self.logger.error(f'Error in event callback: {traceback.format_exc()}')
|
||||
|
||||
async def _dispatch_event(self, event: wecombotevent.WecomBotEvent):
|
||||
"""Dispatch a message event to registered handlers with deduplication."""
|
||||
try:
|
||||
message_id = event.message_id
|
||||
if message_id in self._msg_id_map:
|
||||
self._msg_id_map[message_id] += 1
|
||||
return
|
||||
self._msg_id_map[message_id] = 1
|
||||
|
||||
msg_type = event.type
|
||||
if msg_type in self._message_handlers:
|
||||
for handler in self._message_handlers[msg_type]:
|
||||
await handler(event)
|
||||
except Exception:
|
||||
await self.logger.error(f'Error dispatching event: {traceback.format_exc()}')
|
||||
|
||||
# ── Reply sending with serial queue ─────────────────────────────
|
||||
|
||||
async def _send_reply(
|
||||
self,
|
||||
req_id: str,
|
||||
body: dict,
|
||||
cmd: str = CMD_RESPOND_MSG,
|
||||
) -> Optional[dict]:
|
||||
"""Send a reply frame and wait for ACK.
|
||||
|
||||
Replies with the same req_id are serialized to maintain ordering.
|
||||
"""
|
||||
if not self._ws or self._ws.closed:
|
||||
return None
|
||||
|
||||
frame = {
|
||||
'cmd': cmd,
|
||||
'headers': {'req_id': req_id},
|
||||
'body': body,
|
||||
}
|
||||
|
||||
# Ensure serial delivery per req_id
|
||||
if req_id not in self._reply_queues:
|
||||
self._reply_queues[req_id] = asyncio.Queue()
|
||||
self._reply_workers[req_id] = asyncio.create_task(self._reply_queue_worker(req_id))
|
||||
|
||||
future: asyncio.Future = asyncio.get_event_loop().create_future()
|
||||
await self._reply_queues[req_id].put((frame, future))
|
||||
return await future
|
||||
|
||||
async def _reply_queue_worker(self, req_id: str):
|
||||
"""Process reply queue items serially for a given req_id."""
|
||||
queue = self._reply_queues[req_id]
|
||||
try:
|
||||
while self._running:
|
||||
try:
|
||||
frame, future = await asyncio.wait_for(queue.get(), timeout=60.0)
|
||||
except asyncio.TimeoutError:
|
||||
# Queue idle, clean up worker
|
||||
break
|
||||
|
||||
try:
|
||||
ack = await self._send_and_wait_ack(frame)
|
||||
if not future.done():
|
||||
future.set_result(ack)
|
||||
except Exception as e:
|
||||
if not future.done():
|
||||
future.set_exception(e)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
self._reply_queues.pop(req_id, None)
|
||||
self._reply_workers.pop(req_id, None)
|
||||
|
||||
async def _send_and_wait_ack(self, frame: dict) -> Optional[dict]:
|
||||
"""Send a frame and wait for the corresponding ACK."""
|
||||
req_id = frame['headers']['req_id']
|
||||
ack_future: asyncio.Future = asyncio.get_event_loop().create_future()
|
||||
self._pending_acks[req_id] = ack_future
|
||||
|
||||
try:
|
||||
await self._send_frame(frame)
|
||||
result = await asyncio.wait_for(ack_future, timeout=self._reply_ack_timeout)
|
||||
if result.get('errcode', 0) != 0:
|
||||
await self.logger.warning(
|
||||
f'Reply ACK error: errcode={result.get("errcode")}, errmsg={result.get("errmsg")}'
|
||||
)
|
||||
return result
|
||||
except asyncio.TimeoutError:
|
||||
self._pending_acks.pop(req_id, None)
|
||||
await self.logger.warning(f'Reply ACK timeout ({self._reply_ack_timeout}s) for req_id={req_id}')
|
||||
return None
|
||||
|
||||
async def _send_frame(self, frame: dict):
|
||||
"""Send a JSON frame over the WebSocket connection."""
|
||||
if self._ws and not self._ws.closed:
|
||||
await self._ws.send_str(json.dumps(frame, ensure_ascii=False))
|
||||
|
||||
def _clear_pending_acks(self, reason: str):
|
||||
"""Reject all pending ACK futures on disconnection."""
|
||||
for req_id, future in self._pending_acks.items():
|
||||
if not future.done():
|
||||
future.set_exception(ConnectionError(reason))
|
||||
self._pending_acks.clear()
|
||||
@@ -10,7 +10,6 @@ from typing import Callable
|
||||
from .wecomcsevent import WecomCSEvent
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import aiofiles
|
||||
import time
|
||||
|
||||
|
||||
class WecomCSClient:
|
||||
@@ -35,10 +34,6 @@ class WecomCSClient:
|
||||
self.unified_mode = unified_mode
|
||||
self.app = Quart(__name__)
|
||||
|
||||
# Customer info cache: {external_userid: (info_dict, timestamp)}
|
||||
self._customer_cache: dict[str, tuple[dict, float]] = {}
|
||||
self._cache_ttl = 60 # Cache TTL in seconds (1 minute)
|
||||
|
||||
# 只有在非统一模式下才注册独立路由
|
||||
if not self.unified_mode:
|
||||
self.app.add_url_rule(
|
||||
@@ -383,53 +378,3 @@ class WecomCSClient:
|
||||
async def get_media_id(self, image: platform_message.Image):
|
||||
media_id = await self.upload_to_work(image=image)
|
||||
return media_id
|
||||
|
||||
async def get_customer_info(self, external_userid: str) -> dict | None:
|
||||
"""
|
||||
Get customer information by external_userid with caching.
|
||||
|
||||
Uses a 1-minute cache to avoid repeated API calls for the same user.
|
||||
|
||||
Args:
|
||||
external_userid: The external user ID of the customer.
|
||||
|
||||
Returns:
|
||||
Customer info dict with 'nickname', 'avatar', etc., or None if not found.
|
||||
"""
|
||||
# Check cache first
|
||||
current_time = time.time()
|
||||
if external_userid in self._customer_cache:
|
||||
cached_info, cached_time = self._customer_cache[external_userid]
|
||||
if current_time - cached_time < self._cache_ttl:
|
||||
return cached_info
|
||||
|
||||
# Cache miss or expired, fetch from API
|
||||
if not await self.check_access_token():
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
|
||||
url = f'{self.base_url}/kf/customer/batchget?access_token={self.access_token}'
|
||||
|
||||
payload = {
|
||||
'external_userid_list': [external_userid],
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(url, json=payload)
|
||||
data = response.json()
|
||||
|
||||
if data.get('errcode') in [40014, 42001]:
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
return await self.get_customer_info(external_userid)
|
||||
|
||||
if data.get('errcode', 0) != 0:
|
||||
if self.logger:
|
||||
await self.logger.warning(f'Failed to get customer info: {data}')
|
||||
return None
|
||||
|
||||
customer_list = data.get('customer_list', [])
|
||||
if customer_list:
|
||||
customer_info = customer_list[0]
|
||||
# Store in cache
|
||||
self._customer_cache[external_userid] = (customer_info, current_time)
|
||||
return customer_info
|
||||
return None
|
||||
|
||||
@@ -13,10 +13,7 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
|
||||
|
||||
elif quart.request.method == 'POST':
|
||||
json_data = await quart.request.json
|
||||
try:
|
||||
knowledge_base_uuid = await self.ap.knowledge_service.create_knowledge_base(json_data)
|
||||
except ValueError as e:
|
||||
return self.http_status(400, -1, str(e))
|
||||
knowledge_base_uuid = await self.ap.knowledge_service.create_knowledge_base(json_data)
|
||||
return self.success(data={'uuid': knowledge_base_uuid})
|
||||
|
||||
return self.http_status(405, -1, 'Method not allowed')
|
||||
@@ -42,7 +39,7 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
|
||||
elif quart.request.method == 'PUT':
|
||||
json_data = await quart.request.json
|
||||
await self.ap.knowledge_service.update_knowledge_base(knowledge_base_uuid, json_data)
|
||||
return self.success(data={'uuid': knowledge_base_uuid})
|
||||
return self.success({})
|
||||
|
||||
elif quart.request.method == 'DELETE':
|
||||
await self.ap.knowledge_service.delete_knowledge_base(knowledge_base_uuid)
|
||||
@@ -68,12 +65,8 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
|
||||
if not file_id:
|
||||
return self.http_status(400, -1, 'File ID is required')
|
||||
|
||||
parser_plugin_id = json_data.get('parser_plugin_id')
|
||||
|
||||
# 调用服务层方法将文件与知识库关联
|
||||
task_id = await self.ap.knowledge_service.store_file(
|
||||
knowledge_base_uuid, file_id, parser_plugin_id=parser_plugin_id
|
||||
)
|
||||
task_id = await self.ap.knowledge_service.store_file(knowledge_base_uuid, file_id)
|
||||
return self.success(
|
||||
{
|
||||
'task_id': task_id,
|
||||
@@ -97,13 +90,5 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
|
||||
async def retrieve_knowledge_base(knowledge_base_uuid: str) -> str:
|
||||
json_data = await quart.request.json
|
||||
query = json_data.get('query')
|
||||
|
||||
if not query or not query.strip():
|
||||
return self.http_status(400, -1, 'Query is required and cannot be empty')
|
||||
|
||||
# Extract retrieval_settings to allow dynamic control over Knowledge Engine behavior (e.g. top_k, filters)
|
||||
retrieval_settings = json_data.get('retrieval_settings', {})
|
||||
results = await self.ap.knowledge_service.retrieve_knowledge_base(
|
||||
knowledge_base_uuid, query, retrieval_settings
|
||||
)
|
||||
results = await self.ap.knowledge_service.retrieve_knowledge_base(knowledge_base_uuid, query)
|
||||
return self.success(data={'results': results})
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
import quart
|
||||
from urllib.parse import unquote
|
||||
from ... import group
|
||||
|
||||
|
||||
@group.group_class('knowledge_engines', '/api/v1/knowledge/engines')
|
||||
class KnowledgeEnginesRouterGroup(group.RouterGroup):
|
||||
async def initialize(self) -> None:
|
||||
@self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def list_knowledge_engines() -> quart.Response:
|
||||
"""List all available Knowledge Engines from plugins.
|
||||
|
||||
Returns a list of Knowledge Engines with their capabilities and configuration schemas.
|
||||
This is used by the frontend to render the knowledge base creation wizard.
|
||||
"""
|
||||
engines = await self.ap.knowledge_service.list_knowledge_engines()
|
||||
return self.success(data={'engines': engines})
|
||||
|
||||
@self.route(
|
||||
'/<path:plugin_id>/creation-schema', methods=['GET'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY
|
||||
)
|
||||
async def get_engine_creation_schema(plugin_id: str) -> quart.Response:
|
||||
"""Get creation settings schema for a specific Knowledge Engine.
|
||||
|
||||
plugin_id is in 'author/name' format, captured via <path:> converter.
|
||||
"""
|
||||
plugin_id = unquote(plugin_id)
|
||||
if '/' not in plugin_id:
|
||||
return self.http_status(400, -1, 'Invalid plugin_id format. Expected author/name.')
|
||||
schema = await self.ap.knowledge_service.get_engine_creation_schema(plugin_id)
|
||||
return self.success(data={'schema': schema})
|
||||
|
||||
@self.route(
|
||||
'/<path:plugin_id>/retrieval-schema', methods=['GET'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY
|
||||
)
|
||||
async def get_engine_retrieval_schema(plugin_id: str) -> quart.Response:
|
||||
"""Get retrieval settings schema for a specific Knowledge Engine.
|
||||
|
||||
plugin_id is in 'author/name' format, captured via <path:> converter.
|
||||
"""
|
||||
plugin_id = unquote(plugin_id)
|
||||
if '/' not in plugin_id:
|
||||
return self.http_status(400, -1, 'Invalid plugin_id format. Expected author/name.')
|
||||
schema = await self.ap.knowledge_service.get_engine_retrieval_schema(plugin_id)
|
||||
return self.success(data={'schema': schema})
|
||||
@@ -0,0 +1,61 @@
|
||||
import quart
|
||||
from ... import group
|
||||
|
||||
|
||||
@group.group_class('external_knowledge_base', '/api/v1/knowledge/external-bases')
|
||||
class ExternalKnowledgeBaseRouterGroup(group.RouterGroup):
|
||||
async def initialize(self) -> None:
|
||||
@self.route('/retrievers', methods=['GET'])
|
||||
async def list_knowledge_retrievers() -> quart.Response:
|
||||
"""List all available knowledge retrievers from plugins."""
|
||||
retrievers = await self.ap.plugin_connector.list_knowledge_retrievers()
|
||||
return self.success(data={'retrievers': retrievers})
|
||||
|
||||
@self.route('', methods=['POST', 'GET'])
|
||||
async def handle_external_knowledge_bases() -> quart.Response:
|
||||
if quart.request.method == 'GET':
|
||||
external_kbs = await self.ap.external_kb_service.get_external_knowledge_bases()
|
||||
return self.success(data={'bases': external_kbs})
|
||||
|
||||
elif quart.request.method == 'POST':
|
||||
json_data = await quart.request.json
|
||||
kb_uuid = await self.ap.external_kb_service.create_external_knowledge_base(json_data)
|
||||
return self.success(data={'uuid': kb_uuid})
|
||||
|
||||
return self.http_status(405, -1, 'Method not allowed')
|
||||
|
||||
@self.route(
|
||||
'/<kb_uuid>',
|
||||
methods=['GET', 'DELETE', 'PUT'],
|
||||
)
|
||||
async def handle_specific_external_knowledge_base(kb_uuid: str) -> quart.Response:
|
||||
if quart.request.method == 'GET':
|
||||
external_kb = await self.ap.external_kb_service.get_external_knowledge_base(kb_uuid)
|
||||
|
||||
if external_kb is None:
|
||||
return self.http_status(404, -1, 'external knowledge base not found')
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'base': external_kb,
|
||||
}
|
||||
)
|
||||
|
||||
elif quart.request.method == 'PUT':
|
||||
json_data = await quart.request.json
|
||||
await self.ap.external_kb_service.update_external_knowledge_base(kb_uuid, json_data)
|
||||
return self.success({})
|
||||
|
||||
elif quart.request.method == 'DELETE':
|
||||
await self.ap.external_kb_service.delete_external_knowledge_base(kb_uuid)
|
||||
return self.success({})
|
||||
|
||||
@self.route(
|
||||
'/<kb_uuid>/retrieve',
|
||||
methods=['POST'],
|
||||
)
|
||||
async def retrieve_external_knowledge_base(kb_uuid: str) -> str:
|
||||
json_data = await quart.request.json
|
||||
query = json_data.get('query')
|
||||
results = await self.ap.external_kb_service.retrieve_external_knowledge_base(kb_uuid, query)
|
||||
return self.success(data={'results': results})
|
||||
@@ -1,372 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import httpx
|
||||
import quart
|
||||
import sqlalchemy
|
||||
|
||||
from ... import group
|
||||
from ......core import taskmgr
|
||||
from ......entity.persistence import metadata as persistence_metadata
|
||||
from langbot_plugin.runtime.plugin.mgr import PluginInstallSource
|
||||
|
||||
LANGRAG_PLUGIN_AUTHOR = 'langbot-team'
|
||||
LANGRAG_PLUGIN_NAME = 'LangRAG'
|
||||
LANGRAG_PLUGIN_ID = f'{LANGRAG_PLUGIN_AUTHOR}/{LANGRAG_PLUGIN_NAME}'
|
||||
DEFAULT_SPACE_URL = 'https://space.langbot.app'
|
||||
|
||||
# Old Retriever plugin_name -> New Connector plugin_name
|
||||
EXTERNAL_PLUGIN_NAME_MAPPING = {
|
||||
'DifyDatasetsRetriever': 'DifyDatasetsConnector',
|
||||
'RAGFlowRetriever': 'RAGFlowConnector',
|
||||
'FastGPTRetriever': 'FastGPTConnector',
|
||||
}
|
||||
|
||||
# Per-plugin: which old retriever_config fields belong to creation_settings.
|
||||
# Remaining fields go to retrieval_settings.
|
||||
# None means ALL fields go to creation_settings (no retrieval_schema).
|
||||
EXTERNAL_PLUGIN_CREATION_FIELDS: dict[str, set[str] | None] = {
|
||||
'langbot-team/DifyDatasetsConnector': {'api_base_url', 'dify_apikey', 'dataset_id'},
|
||||
'langbot-team/RAGFlowConnector': {'api_base_url', 'api_key', 'dataset_ids'},
|
||||
'langbot-team/FastGPTConnector': None, # all fields -> creation_settings
|
||||
}
|
||||
|
||||
|
||||
@group.group_class('knowledge/migration', '/api/v1/knowledge/migration')
|
||||
class KnowledgeMigrationRouterGroup(group.RouterGroup):
|
||||
async def _get_migration_flag(self) -> bool:
|
||||
"""Check if rag_plugin_migration_needed flag is set."""
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_metadata.Metadata).where(
|
||||
persistence_metadata.Metadata.key == 'rag_plugin_migration_needed'
|
||||
)
|
||||
)
|
||||
row = result.first()
|
||||
return row is not None and row.value == 'true'
|
||||
|
||||
async def _set_migration_flag(self, value: str):
|
||||
"""Set rag_plugin_migration_needed flag."""
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_metadata.Metadata)
|
||||
.where(persistence_metadata.Metadata.key == 'rag_plugin_migration_needed')
|
||||
.values(value=value)
|
||||
)
|
||||
|
||||
async def _table_exists(self, table_name: str) -> bool:
|
||||
"""Check if a table exists."""
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = :table_name);'
|
||||
).bindparams(table_name=table_name)
|
||||
)
|
||||
return result.scalar()
|
||||
else:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text("SELECT name FROM sqlite_master WHERE type='table' AND name=:table_name;").bindparams(
|
||||
table_name=table_name
|
||||
)
|
||||
)
|
||||
return result.first() is not None
|
||||
|
||||
async def _install_plugin_from_marketplace(
|
||||
self, plugin_id: str, task_context: taskmgr.TaskContext, space_url: str
|
||||
) -> None:
|
||||
"""Install a single plugin from the marketplace."""
|
||||
p_author, p_name = plugin_id.split('/', 1)
|
||||
self.ap.logger.info(f'RAG migration: installing plugin {plugin_id} from marketplace...')
|
||||
task_context.trace(f'Installing plugin {plugin_id} from marketplace...')
|
||||
|
||||
async with httpx.AsyncClient(trust_env=True, timeout=15) as client:
|
||||
resp = await client.get(f'{space_url}/api/v1/marketplace/plugins/{p_author}/{p_name}')
|
||||
resp.raise_for_status()
|
||||
p_data = resp.json().get('data', {}).get('plugin', {})
|
||||
p_version = p_data.get('latest_version')
|
||||
if not p_version:
|
||||
raise Exception(f'Could not determine latest version for {plugin_id}')
|
||||
|
||||
await self.ap.plugin_connector.install_plugin(
|
||||
PluginInstallSource.MARKETPLACE,
|
||||
{
|
||||
'plugin_author': p_author,
|
||||
'plugin_name': p_name,
|
||||
'plugin_version': p_version,
|
||||
},
|
||||
task_context=task_context,
|
||||
)
|
||||
self.ap.logger.info(f'RAG migration: plugin {plugin_id} install request sent.')
|
||||
|
||||
async def _execute_rag_migration(self, task_context: taskmgr.TaskContext, install_plugin: bool = True):
|
||||
"""Execute RAG migration: install required plugins and restore backup data."""
|
||||
warnings = []
|
||||
|
||||
# Collect all plugins we need: LangRAG (always) + connector plugins (from external KBs)
|
||||
needed_plugins: dict[str, str] = {
|
||||
LANGRAG_PLUGIN_ID: LANGRAG_PLUGIN_NAME,
|
||||
}
|
||||
|
||||
has_external = await self._table_exists('external_knowledge_bases')
|
||||
if has_external:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('SELECT DISTINCT plugin_author, plugin_name FROM external_knowledge_bases;')
|
||||
)
|
||||
for row in result.fetchall():
|
||||
plugin_author = row[0] or ''
|
||||
plugin_name = row[1] or ''
|
||||
mapped_name = EXTERNAL_PLUGIN_NAME_MAPPING.get(plugin_name, plugin_name)
|
||||
plugin_id = f'{plugin_author}/{mapped_name}'
|
||||
if plugin_id not in needed_plugins:
|
||||
needed_plugins[plugin_id] = mapped_name
|
||||
|
||||
self.ap.logger.info(f'RAG migration: plugins needed: {list(needed_plugins.keys())}')
|
||||
|
||||
if install_plugin:
|
||||
# Step 1: Install all required plugins from marketplace
|
||||
task_context.trace('Installing required plugins...', action='install-plugin')
|
||||
space_url = self.ap.instance_config.data.get('space', {}).get('url', DEFAULT_SPACE_URL).rstrip('/')
|
||||
|
||||
for plugin_id in needed_plugins:
|
||||
try:
|
||||
await self._install_plugin_from_marketplace(plugin_id, task_context, space_url)
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(f'RAG migration: plugin {plugin_id} install returned: {e}')
|
||||
task_context.trace(f'Plugin install note ({plugin_id}): {e}')
|
||||
|
||||
# Step 2: Wait for all plugins to become available as knowledge engines
|
||||
task_context.trace(
|
||||
f'Waiting for plugins to become available: {list(needed_plugins.keys())}...',
|
||||
action='wait-plugin',
|
||||
)
|
||||
max_retries = 30
|
||||
engine_id_set: set[str] = set()
|
||||
for i in range(max_retries):
|
||||
try:
|
||||
engines = await self.ap.plugin_connector.list_knowledge_engines()
|
||||
engine_id_set = {e.get('plugin_id') for e in engines}
|
||||
except Exception:
|
||||
pass
|
||||
if all(pid in engine_id_set for pid in needed_plugins):
|
||||
self.ap.logger.info(f'RAG migration: all plugins ready: {engine_id_set}')
|
||||
task_context.trace('All required plugins are ready.')
|
||||
break
|
||||
if i == max_retries - 1:
|
||||
still_missing = [pid for pid in needed_plugins if pid not in engine_id_set]
|
||||
warning = f'Plugin(s) {still_missing} did not become available after {max_retries} retries'
|
||||
self.ap.logger.warning(f'RAG migration: {warning}')
|
||||
warnings.append(warning)
|
||||
task_context.trace(warning)
|
||||
await asyncio.sleep(2)
|
||||
else:
|
||||
try:
|
||||
engines = await self.ap.plugin_connector.list_knowledge_engines()
|
||||
engine_id_set = {e.get('plugin_id') for e in engines}
|
||||
except Exception:
|
||||
engine_id_set = set()
|
||||
|
||||
# Step 3: Restore internal knowledge bases from backup
|
||||
task_context.trace('Restoring internal knowledge bases...', action='restore-internal')
|
||||
if await self._table_exists('knowledge_bases_backup'):
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('SELECT * FROM knowledge_bases_backup;')
|
||||
)
|
||||
rows = result.fetchall()
|
||||
columns = result.keys()
|
||||
|
||||
for row in rows:
|
||||
row_dict = dict(zip(columns, row))
|
||||
kb_uuid = row_dict.get('uuid')
|
||||
name = row_dict.get('name', 'Untitled')
|
||||
description = row_dict.get('description', '')
|
||||
emoji = row_dict.get('emoji', '\U0001f4da')
|
||||
embedding_model_uuid = row_dict.get('embedding_model_uuid', '')
|
||||
top_k = row_dict.get('top_k', 5)
|
||||
created_at = row_dict.get('created_at')
|
||||
updated_at = row_dict.get('updated_at')
|
||||
|
||||
creation_settings = json.dumps({'embedding_model_uuid': embedding_model_uuid})
|
||||
retrieval_settings = json.dumps({'top_k': top_k})
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'INSERT INTO knowledge_bases '
|
||||
'(uuid, name, description, emoji, created_at, updated_at, '
|
||||
'knowledge_engine_plugin_id, collection_id, creation_settings, retrieval_settings) '
|
||||
'VALUES (:uuid, :name, :description, :emoji, :created_at, :updated_at, '
|
||||
':plugin_id, :collection_id, :creation_settings, :retrieval_settings);'
|
||||
).bindparams(
|
||||
uuid=kb_uuid,
|
||||
name=name,
|
||||
description=description,
|
||||
emoji=emoji,
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
plugin_id=LANGRAG_PLUGIN_ID,
|
||||
collection_id=kb_uuid,
|
||||
creation_settings=creation_settings,
|
||||
retrieval_settings=retrieval_settings,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
config = {'embedding_model_uuid': embedding_model_uuid}
|
||||
await self.ap.plugin_connector.rag_on_kb_create(LANGRAG_PLUGIN_ID, kb_uuid, config)
|
||||
task_context.trace(f'Restored internal KB: {name} ({kb_uuid})')
|
||||
except Exception as e:
|
||||
warning = f'Failed to notify plugin for KB {name} ({kb_uuid}): {e}'
|
||||
warnings.append(warning)
|
||||
task_context.trace(warning)
|
||||
|
||||
await self.ap.rag_mgr.load_knowledge_bases_from_db()
|
||||
|
||||
# Step 4: Restore external knowledge bases
|
||||
task_context.trace('Restoring external knowledge bases...', action='restore-external')
|
||||
if has_external:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('SELECT * FROM external_knowledge_bases;')
|
||||
)
|
||||
rows = result.fetchall()
|
||||
columns = result.keys()
|
||||
|
||||
self.ap.logger.info(
|
||||
f'RAG migration: {len(rows)} external KB(s) to restore. Available engines: {engine_id_set}'
|
||||
)
|
||||
task_context.trace(f'Found {len(rows)} external KB(s). Available engines: {engine_id_set}')
|
||||
|
||||
for row in rows:
|
||||
row_dict = dict(zip(columns, row))
|
||||
kb_uuid = row_dict.get('uuid')
|
||||
name = row_dict.get('name', 'Untitled')
|
||||
description = row_dict.get('description', '')
|
||||
emoji = row_dict.get('emoji', '\U0001f517')
|
||||
plugin_author = row_dict.get('plugin_author', '')
|
||||
plugin_name = row_dict.get('plugin_name', '')
|
||||
retriever_config = row_dict.get('retriever_config', {})
|
||||
created_at = row_dict.get('created_at')
|
||||
|
||||
mapped_plugin_name = EXTERNAL_PLUGIN_NAME_MAPPING.get(plugin_name, plugin_name)
|
||||
external_plugin_id = f'{plugin_author}/{mapped_plugin_name}'
|
||||
|
||||
self.ap.logger.info(
|
||||
f'RAG migration: processing external KB "{name}" ({kb_uuid}), '
|
||||
f'plugin: {plugin_author}/{plugin_name} -> {external_plugin_id}'
|
||||
)
|
||||
|
||||
if isinstance(retriever_config, str):
|
||||
try:
|
||||
retriever_config = json.loads(retriever_config)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
retriever_config = {}
|
||||
|
||||
creation_fields = EXTERNAL_PLUGIN_CREATION_FIELDS.get(external_plugin_id)
|
||||
if creation_fields is None:
|
||||
creation_settings_dict = retriever_config
|
||||
retrieval_settings_dict = {}
|
||||
else:
|
||||
creation_settings_dict = {k: v for k, v in retriever_config.items() if k in creation_fields}
|
||||
retrieval_settings_dict = {k: v for k, v in retriever_config.items() if k not in creation_fields}
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'INSERT INTO knowledge_bases '
|
||||
'(uuid, name, description, emoji, created_at, updated_at, '
|
||||
'knowledge_engine_plugin_id, collection_id, creation_settings, retrieval_settings) '
|
||||
'VALUES (:uuid, :name, :description, :emoji, :created_at, :updated_at, '
|
||||
':plugin_id, :collection_id, :creation_settings, :retrieval_settings);'
|
||||
).bindparams(
|
||||
uuid=kb_uuid,
|
||||
name=name,
|
||||
description=description,
|
||||
emoji=emoji,
|
||||
created_at=created_at,
|
||||
updated_at=created_at,
|
||||
plugin_id=external_plugin_id,
|
||||
collection_id=kb_uuid,
|
||||
creation_settings=json.dumps(creation_settings_dict),
|
||||
retrieval_settings=json.dumps(retrieval_settings_dict),
|
||||
)
|
||||
)
|
||||
|
||||
if external_plugin_id not in engine_id_set:
|
||||
warning = (
|
||||
f'External KB "{name}" ({kb_uuid}) record saved, but plugin {external_plugin_id} '
|
||||
f'is not installed yet. Install the connector plugin to use it.'
|
||||
)
|
||||
warnings.append(warning)
|
||||
task_context.trace(warning)
|
||||
else:
|
||||
try:
|
||||
await self.ap.plugin_connector.rag_on_kb_create(
|
||||
external_plugin_id, kb_uuid, creation_settings_dict
|
||||
)
|
||||
task_context.trace(f'Restored external KB: {name} ({kb_uuid})')
|
||||
except Exception as e:
|
||||
warning = f'Failed to notify plugin for external KB {name} ({kb_uuid}): {e}'
|
||||
warnings.append(warning)
|
||||
task_context.trace(warning)
|
||||
|
||||
await self.ap.rag_mgr.load_knowledge_bases_from_db()
|
||||
|
||||
# Step 5: Clear migration flag
|
||||
await self._set_migration_flag('false')
|
||||
task_context.trace('RAG migration completed.', action='done')
|
||||
|
||||
if warnings:
|
||||
task_context.trace(f'Completed with {len(warnings)} warning(s).')
|
||||
|
||||
async def initialize(self) -> None:
|
||||
@self.route('/status', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
needed = await self._get_migration_flag()
|
||||
|
||||
internal_kb_count = 0
|
||||
external_kb_count = 0
|
||||
|
||||
if needed:
|
||||
if await self._table_exists('knowledge_bases_backup'):
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('SELECT COUNT(*) FROM knowledge_bases_backup;')
|
||||
)
|
||||
internal_kb_count = result.scalar() or 0
|
||||
|
||||
if await self._table_exists('external_knowledge_bases'):
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('SELECT COUNT(*) FROM external_knowledge_bases;')
|
||||
)
|
||||
external_kb_count = result.scalar() or 0
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'needed': needed,
|
||||
'internal_kb_count': internal_kb_count,
|
||||
'external_kb_count': external_kb_count,
|
||||
}
|
||||
)
|
||||
|
||||
@self.route('/execute', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
needed = await self._get_migration_flag()
|
||||
if not needed:
|
||||
return self.http_status(400, -1, 'RAG migration is not needed')
|
||||
|
||||
data = await quart.request.get_json(silent=True) or {}
|
||||
install_plugin = data.get('install_plugin', True)
|
||||
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
wrapper = self.ap.task_mgr.create_user_task(
|
||||
self._execute_rag_migration(task_context=ctx, install_plugin=install_plugin),
|
||||
kind='rag-migration',
|
||||
name='rag-migration-execute',
|
||||
label='Migrating knowledge bases to plugin architecture',
|
||||
context=ctx,
|
||||
)
|
||||
|
||||
return self.success(data={'task_id': wrapper.id})
|
||||
|
||||
@self.route('/dismiss', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
needed = await self._get_migration_flag()
|
||||
if not needed:
|
||||
return self.http_status(400, -1, 'RAG migration is not needed')
|
||||
|
||||
await self._set_migration_flag('false')
|
||||
return self.success()
|
||||
@@ -1,16 +0,0 @@
|
||||
import quart
|
||||
from ... import group
|
||||
|
||||
|
||||
@group.group_class('parsers', '/api/v1/knowledge/parsers')
|
||||
class ParsersRouterGroup(group.RouterGroup):
|
||||
async def initialize(self) -> None:
|
||||
@self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||
async def list_parsers() -> quart.Response:
|
||||
"""List all available parsers from plugins.
|
||||
|
||||
Optional query parameter `mime_type` to filter parsers by supported MIME type.
|
||||
"""
|
||||
mime_type = quart.request.args.get('mime_type')
|
||||
parsers = await self.ap.knowledge_service.list_parsers(mime_type)
|
||||
return self.success(data={'parsers': parsers})
|
||||
@@ -52,7 +52,6 @@ class MonitoringRouterGroup(group.RouterGroup):
|
||||
# Parse query parameters
|
||||
bot_ids = quart.request.args.getlist('botId')
|
||||
pipeline_ids = quart.request.args.getlist('pipelineId')
|
||||
session_ids = quart.request.args.getlist('sessionId')
|
||||
start_time_str = quart.request.args.get('startTime')
|
||||
end_time_str = quart.request.args.get('endTime')
|
||||
limit = int(quart.request.args.get('limit', 100))
|
||||
@@ -65,7 +64,6 @@ class MonitoringRouterGroup(group.RouterGroup):
|
||||
messages, total = await self.ap.monitoring_service.get_messages(
|
||||
bot_ids=bot_ids if bot_ids else None,
|
||||
pipeline_ids=pipeline_ids if pipeline_ids else None,
|
||||
session_ids=session_ids if session_ids else None,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
limit=limit,
|
||||
|
||||
@@ -68,7 +68,7 @@ class PipelinesRouterGroup(group.RouterGroup):
|
||||
return self.http_status(404, -1, 'pipeline not found')
|
||||
|
||||
# Only include plugins with pipeline-related components (Command, EventListener, Tool)
|
||||
# Plugins that only have KnowledgeEngine components are not suitable for pipeline extensions
|
||||
# Plugins that only have KnowledgeRetriever components are not suitable for pipeline extensions
|
||||
pipeline_component_kinds = ['Command', 'EventListener', 'Tool']
|
||||
plugins = await self.ap.plugin_connector.list_plugins(component_kinds=pipeline_component_kinds)
|
||||
mcp_servers = await self.ap.mcp_service.get_mcp_servers(contain_runtime_info=True)
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
import quart
|
||||
|
||||
from .. import group
|
||||
|
||||
|
||||
@group.group_class('survey', '/api/v1/survey')
|
||||
class SurveyRouterGroup(group.RouterGroup):
|
||||
async def initialize(self) -> None:
|
||||
@self.route('/pending', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _get_pending() -> str:
|
||||
"""Get pending survey for the frontend to display."""
|
||||
survey = self.ap.survey.get_pending_survey() if self.ap.survey else None
|
||||
return self.success(data={'survey': survey})
|
||||
|
||||
@self.route('/respond', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _respond() -> str:
|
||||
"""Submit survey response."""
|
||||
json_data = await quart.request.json
|
||||
survey_id = json_data.get('survey_id')
|
||||
answers = json_data.get('answers', {})
|
||||
completed = json_data.get('completed', True)
|
||||
|
||||
if not survey_id:
|
||||
return self.fail(1, 'survey_id required')
|
||||
|
||||
if self.ap.survey:
|
||||
ok = await self.ap.survey.submit_response(survey_id, answers, completed)
|
||||
if ok:
|
||||
return self.success()
|
||||
return self.fail(2, 'Failed to submit response')
|
||||
return self.fail(3, 'Survey not available')
|
||||
|
||||
@self.route('/dismiss', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _dismiss() -> str:
|
||||
"""Dismiss survey."""
|
||||
json_data = await quart.request.json
|
||||
survey_id = json_data.get('survey_id')
|
||||
|
||||
if not survey_id:
|
||||
return self.fail(1, 'survey_id required')
|
||||
|
||||
if self.ap.survey:
|
||||
ok = await self.ap.survey.dismiss_survey(survey_id)
|
||||
if ok:
|
||||
return self.success()
|
||||
return self.fail(2, 'Failed to dismiss')
|
||||
return self.fail(3, 'Survey not available')
|
||||
@@ -70,17 +70,12 @@ class BotService:
|
||||
'lark',
|
||||
]:
|
||||
webhook_prefix = self.ap.instance_config.data['api'].get('webhook_prefix', 'http://127.0.0.1:5300')
|
||||
extra_webhook_prefix = self.ap.instance_config.data['api'].get('extra_webhook_prefix', '')
|
||||
webhook_url = f'/bots/{bot_uuid}'
|
||||
adapter_runtime_values['webhook_url'] = webhook_url
|
||||
adapter_runtime_values['webhook_full_url'] = f'{webhook_prefix}{webhook_url}'
|
||||
adapter_runtime_values['extra_webhook_full_url'] = (
|
||||
f'{extra_webhook_prefix}{webhook_url}' if extra_webhook_prefix else ''
|
||||
)
|
||||
else:
|
||||
adapter_runtime_values['webhook_url'] = None
|
||||
adapter_runtime_values['webhook_full_url'] = None
|
||||
adapter_runtime_values['extra_webhook_full_url'] = None
|
||||
|
||||
persistence_bot['adapter_runtime_values'] = adapter_runtime_values
|
||||
|
||||
|
||||
80
src/langbot/pkg/api/http/service/external_kb.py
Normal file
80
src/langbot/pkg/api/http/service/external_kb.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ....core import app
|
||||
import sqlalchemy
|
||||
from langbot.pkg.entity.persistence import rag as persistence_rag
|
||||
import uuid
|
||||
|
||||
|
||||
class ExternalKBService:
|
||||
"""External KB service"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application) -> None:
|
||||
self.ap = ap
|
||||
|
||||
# External Knowledge Base methods
|
||||
async def get_external_knowledge_bases(self) -> list[dict]:
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_rag.ExternalKnowledgeBase))
|
||||
external_kbs = result.all()
|
||||
return [
|
||||
self.ap.persistence_mgr.serialize_model(persistence_rag.ExternalKnowledgeBase, external_kb)
|
||||
for external_kb in external_kbs
|
||||
]
|
||||
|
||||
async def get_external_knowledge_base(self, kb_uuid: str) -> dict | None:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_rag.ExternalKnowledgeBase).where(
|
||||
persistence_rag.ExternalKnowledgeBase.uuid == kb_uuid
|
||||
)
|
||||
)
|
||||
external_kb = result.first()
|
||||
if external_kb is None:
|
||||
return None
|
||||
return self.ap.persistence_mgr.serialize_model(persistence_rag.ExternalKnowledgeBase, external_kb)
|
||||
|
||||
async def create_external_knowledge_base(self, kb_data: dict) -> str:
|
||||
kb_data['uuid'] = str(uuid.uuid4())
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(persistence_rag.ExternalKnowledgeBase).values(kb_data)
|
||||
)
|
||||
|
||||
kb = await self.get_external_knowledge_base(kb_data['uuid'])
|
||||
|
||||
await self.ap.rag_mgr.load_external_knowledge_base(kb)
|
||||
|
||||
return kb_data['uuid']
|
||||
|
||||
async def retrieve_external_knowledge_base(self, kb_uuid: str, query: str) -> list[dict]:
|
||||
"""Retrieve external knowledge base"""
|
||||
runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
|
||||
if runtime_kb is None:
|
||||
raise Exception('Knowledge base not found')
|
||||
return [
|
||||
result.model_dump() for result in await runtime_kb.retrieve(query, 5)
|
||||
] # top_k is just a placeholder for external knowledge base
|
||||
|
||||
async def update_external_knowledge_base(self, kb_uuid: str, kb_data: dict) -> None:
|
||||
if 'uuid' in kb_data:
|
||||
del kb_data['uuid']
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_rag.ExternalKnowledgeBase)
|
||||
.values(kb_data)
|
||||
.where(persistence_rag.ExternalKnowledgeBase.uuid == kb_uuid)
|
||||
)
|
||||
await self.ap.rag_mgr.remove_knowledge_base_from_runtime(kb_uuid)
|
||||
|
||||
kb = await self.get_external_knowledge_base(kb_uuid)
|
||||
|
||||
await self.ap.rag_mgr.load_external_knowledge_base(kb)
|
||||
|
||||
async def delete_external_knowledge_base(self, kb_uuid: str) -> None:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_rag.ExternalKnowledgeBase).where(
|
||||
persistence_rag.ExternalKnowledgeBase.uuid == kb_uuid
|
||||
)
|
||||
)
|
||||
|
||||
await self.ap.rag_mgr.delete_knowledge_base(kb_uuid)
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
import sqlalchemy
|
||||
|
||||
from ....core import app
|
||||
@@ -16,77 +17,64 @@ class KnowledgeService:
|
||||
|
||||
async def get_knowledge_bases(self) -> list[dict]:
|
||||
"""获取所有知识库"""
|
||||
return await self.ap.rag_mgr.get_all_knowledge_base_details()
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_rag.KnowledgeBase))
|
||||
knowledge_bases = result.all()
|
||||
return [
|
||||
self.ap.persistence_mgr.serialize_model(persistence_rag.KnowledgeBase, knowledge_base)
|
||||
for knowledge_base in knowledge_bases
|
||||
]
|
||||
|
||||
async def get_knowledge_base(self, kb_uuid: str) -> dict | None:
|
||||
"""获取知识库"""
|
||||
return await self.ap.rag_mgr.get_knowledge_base_details(kb_uuid)
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_rag.KnowledgeBase).where(persistence_rag.KnowledgeBase.uuid == kb_uuid)
|
||||
)
|
||||
knowledge_base = result.first()
|
||||
if knowledge_base is None:
|
||||
return None
|
||||
return self.ap.persistence_mgr.serialize_model(persistence_rag.KnowledgeBase, knowledge_base)
|
||||
|
||||
async def create_knowledge_base(self, kb_data: dict) -> str:
|
||||
"""创建知识库"""
|
||||
# In new architecture, we delegate entirely to RAGManager which uses plugins.
|
||||
# Legacy internal KB creation is removed.
|
||||
kb_data['uuid'] = str(uuid.uuid4())
|
||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.KnowledgeBase).values(kb_data))
|
||||
|
||||
knowledge_engine_plugin_id = kb_data.get('knowledge_engine_plugin_id')
|
||||
if not knowledge_engine_plugin_id:
|
||||
raise ValueError('knowledge_engine_plugin_id is required')
|
||||
kb = await self.get_knowledge_base(kb_data['uuid'])
|
||||
|
||||
kb = await self.ap.rag_mgr.create_knowledge_base(
|
||||
name=kb_data.get('name', 'Untitled'),
|
||||
knowledge_engine_plugin_id=knowledge_engine_plugin_id,
|
||||
creation_settings=kb_data.get('creation_settings', {}),
|
||||
retrieval_settings=kb_data.get('retrieval_settings', {}),
|
||||
description=kb_data.get('description', ''),
|
||||
)
|
||||
return kb.uuid
|
||||
await self.ap.rag_mgr.load_knowledge_base(kb)
|
||||
|
||||
return kb_data['uuid']
|
||||
|
||||
async def update_knowledge_base(self, kb_uuid: str, kb_data: dict) -> None:
|
||||
"""更新知识库"""
|
||||
# Filter to only mutable fields
|
||||
filtered_data = {k: v for k, v in kb_data.items() if k in persistence_rag.KnowledgeBase.MUTABLE_FIELDS}
|
||||
if 'uuid' in kb_data:
|
||||
del kb_data['uuid']
|
||||
|
||||
if not filtered_data:
|
||||
return
|
||||
if 'embedding_model_uuid' in kb_data:
|
||||
del kb_data['embedding_model_uuid']
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_rag.KnowledgeBase)
|
||||
.values(filtered_data)
|
||||
.values(kb_data)
|
||||
.where(persistence_rag.KnowledgeBase.uuid == kb_uuid)
|
||||
)
|
||||
await self.ap.rag_mgr.remove_knowledge_base_from_runtime(kb_uuid)
|
||||
|
||||
kb = await self.get_knowledge_base(kb_uuid)
|
||||
if kb is None:
|
||||
raise Exception('Knowledge base not found after update')
|
||||
|
||||
await self.ap.rag_mgr.load_knowledge_base(kb)
|
||||
|
||||
async def _check_doc_capability(self, kb_uuid: str, operation: str) -> None:
|
||||
"""Check if the KB's Knowledge Engine supports document operations.
|
||||
|
||||
Args:
|
||||
kb_uuid: Knowledge base UUID.
|
||||
operation: Human-readable operation name for error messages.
|
||||
|
||||
Raises:
|
||||
Exception: If the KB does not support doc_ingestion.
|
||||
"""
|
||||
kb_info = await self.ap.rag_mgr.get_knowledge_base_details(kb_uuid)
|
||||
if not kb_info:
|
||||
raise Exception('Knowledge base not found')
|
||||
capabilities = kb_info.get('knowledge_engine', {}).get('capabilities', [])
|
||||
if 'doc_ingestion' not in capabilities:
|
||||
raise Exception(f'This knowledge base does not support {operation}')
|
||||
|
||||
async def store_file(self, kb_uuid: str, file_id: str, parser_plugin_id: str | None = None) -> str:
|
||||
async def store_file(self, kb_uuid: str, file_id: str) -> int:
|
||||
"""存储文件"""
|
||||
# await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.File).values(kb_id=kb_uuid, file_id=file_id))
|
||||
# await self.ap.rag_mgr.store_file(file_id)
|
||||
runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
|
||||
if runtime_kb is None:
|
||||
raise Exception('Knowledge base not found')
|
||||
|
||||
await self._check_doc_capability(kb_uuid, 'document upload')
|
||||
|
||||
result = await runtime_kb.store_file(file_id, parser_plugin_id=parser_plugin_id)
|
||||
# Only internal KBs support file storage
|
||||
if runtime_kb.get_type() != 'internal':
|
||||
raise Exception('Only internal knowledge bases support file storage')
|
||||
result = await runtime_kb.store_file(file_id)
|
||||
|
||||
# Update the KB's updated_at timestamp
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
@@ -97,18 +85,14 @@ class KnowledgeService:
|
||||
|
||||
return result
|
||||
|
||||
async def retrieve_knowledge_base(
|
||||
self, kb_uuid: str, query: str, retrieval_settings: dict | None = None
|
||||
) -> list[dict]:
|
||||
async def retrieve_knowledge_base(self, kb_uuid: str, query: str) -> list[dict]:
|
||||
"""检索知识库"""
|
||||
runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
|
||||
if runtime_kb is None:
|
||||
raise Exception('Knowledge base not found')
|
||||
|
||||
# Pass retrieval_settings
|
||||
results = await runtime_kb.retrieve(query, settings=retrieval_settings)
|
||||
|
||||
return [result.model_dump() for result in results]
|
||||
return [
|
||||
result.model_dump() for result in await runtime_kb.retrieve(query, runtime_kb.knowledge_base_entity.top_k)
|
||||
]
|
||||
|
||||
async def get_files_by_knowledge_base(self, kb_uuid: str) -> list[dict]:
|
||||
"""获取知识库文件"""
|
||||
@@ -123,9 +107,9 @@ class KnowledgeService:
|
||||
runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
|
||||
if runtime_kb is None:
|
||||
raise Exception('Knowledge base not found')
|
||||
|
||||
await self._check_doc_capability(kb_uuid, 'document deletion')
|
||||
|
||||
# Only internal KBs support file deletion
|
||||
if runtime_kb.get_type() != 'internal':
|
||||
raise Exception('Only internal knowledge bases support file deletion')
|
||||
await runtime_kb.delete_file(file_id)
|
||||
|
||||
# Update the KB's updated_at timestamp
|
||||
@@ -137,14 +121,13 @@ class KnowledgeService:
|
||||
|
||||
async def delete_knowledge_base(self, kb_uuid: str) -> None:
|
||||
"""删除知识库"""
|
||||
# Delete from DB first to commit the deletion, then clean up runtime/plugin (best-effort)
|
||||
await self.ap.rag_mgr.delete_knowledge_base(kb_uuid)
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_rag.KnowledgeBase).where(persistence_rag.KnowledgeBase.uuid == kb_uuid)
|
||||
)
|
||||
|
||||
# delete files
|
||||
# NOTE: Chunk cleanup is for legacy (pre-plugin) KBs that stored chunks locally.
|
||||
# For plugin-based Knowledge Engines, the Chunk table is not populated, so this is a no-op.
|
||||
files = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_rag.File).where(persistence_rag.File.kb_id == kb_uuid)
|
||||
)
|
||||
@@ -157,53 +140,3 @@ class KnowledgeService:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_rag.File).where(persistence_rag.File.uuid == file.uuid)
|
||||
)
|
||||
|
||||
# Remove from runtime and notify plugin (best-effort, DB is already cleaned up)
|
||||
await self.ap.rag_mgr.delete_knowledge_base(kb_uuid)
|
||||
|
||||
# ================= Knowledge Engine Discovery =================
|
||||
|
||||
async def list_knowledge_engines(self) -> list[dict]:
|
||||
"""List all available Knowledge Engines from plugins."""
|
||||
engines = []
|
||||
|
||||
if not self.ap.plugin_connector.is_enable_plugin:
|
||||
return engines
|
||||
|
||||
# Get KnowledgeEngine plugins
|
||||
try:
|
||||
knowledge_engines = await self.ap.plugin_connector.list_knowledge_engines()
|
||||
engines.extend(knowledge_engines)
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(f'Failed to list Knowledge Engines from plugins: {e}')
|
||||
|
||||
return engines
|
||||
|
||||
async def list_parsers(self, mime_type: str | None = None) -> list[dict]:
|
||||
"""List available parsers, optionally filtered by MIME type."""
|
||||
if not self.ap.plugin_connector.is_enable_plugin:
|
||||
return []
|
||||
try:
|
||||
parsers = await self.ap.plugin_connector.list_parsers()
|
||||
if mime_type:
|
||||
parsers = [p for p in parsers if mime_type in p.get('supported_mime_types', [])]
|
||||
return parsers
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(f'Failed to list parsers: {e}')
|
||||
return []
|
||||
|
||||
async def get_engine_creation_schema(self, plugin_id: str) -> dict:
|
||||
"""Get creation settings schema for a specific Knowledge Engine."""
|
||||
try:
|
||||
return await self.ap.plugin_connector.get_rag_creation_schema(plugin_id)
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(f'Failed to get creation schema for {plugin_id}: {e}')
|
||||
return {}
|
||||
|
||||
async def get_engine_retrieval_schema(self, plugin_id: str) -> dict:
|
||||
"""Get retrieval settings schema for a specific Knowledge Engine."""
|
||||
try:
|
||||
return await self.ap.plugin_connector.get_rag_retrieval_schema(plugin_id)
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(f'Failed to get retrieval schema for {plugin_id}: {e}')
|
||||
return {}
|
||||
|
||||
@@ -105,16 +105,11 @@ class LLMModelsService:
|
||||
)
|
||||
)
|
||||
pipeline = result.first()
|
||||
if pipeline is not None:
|
||||
model_config = pipeline.config.get('ai', {}).get('local-agent', {}).get('model', {})
|
||||
if not model_config.get('primary', ''):
|
||||
pipeline_config = pipeline.config
|
||||
pipeline_config['ai']['local-agent']['model'] = {
|
||||
'primary': model_data['uuid'],
|
||||
'fallbacks': [],
|
||||
}
|
||||
pipeline_data = {'config': pipeline_config}
|
||||
await self.ap.pipeline_service.update_pipeline(pipeline.uuid, pipeline_data)
|
||||
if pipeline is not None and pipeline.config['ai']['local-agent']['model'] == '':
|
||||
pipeline_config = pipeline.config
|
||||
pipeline_config['ai']['local-agent']['model'] = model_data['uuid']
|
||||
pipeline_data = {'config': pipeline_config}
|
||||
await self.ap.pipeline_service.update_pipeline(pipeline.uuid, pipeline_data)
|
||||
|
||||
return model_data['uuid']
|
||||
|
||||
|
||||
@@ -30,10 +30,8 @@ class MonitoringService:
|
||||
level: str = 'info',
|
||||
platform: str | None = None,
|
||||
user_id: str | None = None,
|
||||
user_name: str | None = None,
|
||||
runner_name: str | None = None,
|
||||
variables: str | None = None,
|
||||
role: str = 'user',
|
||||
) -> str:
|
||||
"""Record a message"""
|
||||
message_id = str(uuid.uuid4())
|
||||
@@ -50,10 +48,8 @@ class MonitoringService:
|
||||
'level': level,
|
||||
'platform': platform,
|
||||
'user_id': user_id,
|
||||
'user_name': user_name,
|
||||
'runner_name': runner_name,
|
||||
'variables': variables,
|
||||
'role': role,
|
||||
}
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
@@ -154,7 +150,6 @@ class MonitoringService:
|
||||
pipeline_name: str,
|
||||
platform: str | None = None,
|
||||
user_id: str | None = None,
|
||||
user_name: str | None = None,
|
||||
) -> None:
|
||||
"""Record a new session"""
|
||||
session_data = {
|
||||
@@ -169,7 +164,6 @@ class MonitoringService:
|
||||
'is_active': True,
|
||||
'platform': platform,
|
||||
'user_id': user_id,
|
||||
'user_name': user_name,
|
||||
}
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
@@ -361,7 +355,6 @@ class MonitoringService:
|
||||
self,
|
||||
bot_ids: list[str] | None = None,
|
||||
pipeline_ids: list[str] | None = None,
|
||||
session_ids: list[str] | None = None,
|
||||
start_time: datetime.datetime | None = None,
|
||||
end_time: datetime.datetime | None = None,
|
||||
limit: int = 100,
|
||||
@@ -374,8 +367,6 @@ class MonitoringService:
|
||||
conditions.append(persistence_monitoring.MonitoringMessage.bot_id.in_(bot_ids))
|
||||
if pipeline_ids:
|
||||
conditions.append(persistence_monitoring.MonitoringMessage.pipeline_id.in_(pipeline_ids))
|
||||
if session_ids:
|
||||
conditions.append(persistence_monitoring.MonitoringMessage.session_id.in_(session_ids))
|
||||
if start_time:
|
||||
conditions.append(persistence_monitoring.MonitoringMessage.timestamp >= start_time)
|
||||
if end_time:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from langbot.pkg.utils import httpclient
|
||||
import aiohttp
|
||||
import typing
|
||||
import datetime
|
||||
import time
|
||||
@@ -99,49 +99,49 @@ class SpaceService:
|
||||
space_config = self._get_space_config()
|
||||
space_url = space_config['url']
|
||||
|
||||
session = httpclient.get_session()
|
||||
async with session.post(
|
||||
f'{space_url}/api/v1/accounts/oauth/token',
|
||||
json={'code': code, 'instance_id': constants.instance_id},
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise ValueError(f'Failed to exchange OAuth code: {await response.text()}')
|
||||
data = await response.json()
|
||||
if data.get('code') != 0:
|
||||
raise ValueError(f'Failed to exchange OAuth code: {data.get("msg")}')
|
||||
return data.get('data', {})
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f'{space_url}/api/v1/accounts/oauth/token',
|
||||
json={'code': code, 'instance_id': constants.instance_id},
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise ValueError(f'Failed to exchange OAuth code: {await response.text()}')
|
||||
data = await response.json()
|
||||
if data.get('code') != 0:
|
||||
raise ValueError(f'Failed to exchange OAuth code: {data.get("msg")}')
|
||||
return data.get('data', {})
|
||||
|
||||
async def refresh_token(self, refresh_token: str) -> typing.Dict:
|
||||
"""Refresh Space access token"""
|
||||
space_config = self._get_space_config()
|
||||
space_url = space_config['url']
|
||||
|
||||
session = httpclient.get_session()
|
||||
async with session.post(
|
||||
f'{space_url}/api/v1/accounts/token/refresh', json={'refresh_token': refresh_token}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise ValueError(f'Failed to refresh token: {await response.text()}')
|
||||
data = await response.json()
|
||||
if data.get('code') != 0:
|
||||
raise ValueError(f'Failed to refresh token: {data.get("msg")}')
|
||||
return data.get('data', {})
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f'{space_url}/api/v1/accounts/token/refresh', json={'refresh_token': refresh_token}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise ValueError(f'Failed to refresh token: {await response.text()}')
|
||||
data = await response.json()
|
||||
if data.get('code') != 0:
|
||||
raise ValueError(f'Failed to refresh token: {data.get("msg")}')
|
||||
return data.get('data', {})
|
||||
|
||||
async def get_user_info_raw(self, access_token: str) -> typing.Dict:
|
||||
"""Get user info from Space using access token (no validation)"""
|
||||
space_config = self._get_space_config()
|
||||
space_url = space_config['url']
|
||||
|
||||
session = httpclient.get_session()
|
||||
async with session.get(
|
||||
f'{space_url}/api/v1/accounts/me', headers={'Authorization': f'Bearer {access_token}'}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise ValueError(f'Failed to get user info: {await response.text()}')
|
||||
data = await response.json()
|
||||
if data.get('code') != 0:
|
||||
raise ValueError(f'Failed to get user info: {data.get("msg")}')
|
||||
return data.get('data', {})
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f'{space_url}/api/v1/accounts/me', headers={'Authorization': f'Bearer {access_token}'}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise ValueError(f'Failed to get user info: {await response.text()}')
|
||||
data = await response.json()
|
||||
if data.get('code') != 0:
|
||||
raise ValueError(f'Failed to get user info: {data.get("msg")}')
|
||||
return data.get('data', {})
|
||||
|
||||
# === API calls with token validation ===
|
||||
|
||||
@@ -178,12 +178,12 @@ class SpaceService:
|
||||
space_config = self._get_space_config()
|
||||
space_url = space_config['url']
|
||||
|
||||
session = httpclient.get_session()
|
||||
async with session.get(f'{space_url}/api/v1/models') as response:
|
||||
if response.status != 200:
|
||||
raise ValueError(f'Failed to get models: {await response.text()}')
|
||||
data = await response.json()
|
||||
if data.get('code') != 0:
|
||||
raise ValueError(f'Failed to get models: {data.get("msg")}')
|
||||
models_data = data.get('data', {}).get('models', [])
|
||||
return [SpaceModel.model_validate(model_dict) for model_dict in models_data]
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(f'{space_url}/api/v1/models') as response:
|
||||
if response.status != 200:
|
||||
raise ValueError(f'Failed to get models: {await response.text()}')
|
||||
data = await response.json()
|
||||
if data.get('code') != 0:
|
||||
raise ValueError(f'Failed to get models: {data.get("msg")}')
|
||||
models_data = data.get('data', {}).get('models', [])
|
||||
return [SpaceModel.model_validate(model_dict) for model_dict in models_data]
|
||||
|
||||
@@ -9,14 +9,12 @@ from ..platform import botmgr as im_mgr
|
||||
from ..platform.webhook_pusher import WebhookPusher
|
||||
from ..provider.session import sessionmgr as llm_session_mgr
|
||||
from ..provider.modelmgr import modelmgr as llm_model_mgr
|
||||
|
||||
from langbot.pkg.provider.tools import toolmgr as llm_tool_mgr
|
||||
from ..config import manager as config_mgr
|
||||
from ..command import cmdmgr
|
||||
from ..plugin import connector as plugin_connector
|
||||
from ..pipeline import pool
|
||||
from ..pipeline import controller, pipelinemgr
|
||||
from ..pipeline import aggregator as message_aggregator
|
||||
from ..utils import version as version_mgr, proxy as proxy_mgr
|
||||
from ..persistence import mgr as persistencemgr
|
||||
from ..api.http.controller import main as http_controller
|
||||
@@ -30,18 +28,16 @@ from ..api.http.service import knowledge as knowledge_service
|
||||
from ..api.http.service import mcp as mcp_service
|
||||
from ..api.http.service import apikey as apikey_service
|
||||
from ..api.http.service import webhook as webhook_service
|
||||
from ..api.http.service import external_kb as external_kb_service
|
||||
from ..api.http.service import monitoring as monitoring_service
|
||||
|
||||
from ..discover import engine as discover_engine
|
||||
from ..storage import mgr as storagemgr
|
||||
from ..utils import logcache
|
||||
from . import taskmgr
|
||||
from . import entities as core_entities
|
||||
from ..rag.knowledge import kbmgr as rag_mgr
|
||||
from ..rag.service import RAGRuntimeService
|
||||
from ..vector import mgr as vectordb_mgr
|
||||
from ..telemetry import telemetry as telemetry_module
|
||||
from ..survey import manager as survey_module
|
||||
|
||||
|
||||
class Application:
|
||||
@@ -65,7 +61,6 @@ class Application:
|
||||
model_mgr: llm_model_mgr.ModelManager = None
|
||||
|
||||
rag_mgr: rag_mgr.RAGManager = None
|
||||
rag_runtime_service: RAGRuntimeService = None
|
||||
|
||||
# TODO move to pipeline
|
||||
tool_mgr: llm_tool_mgr.ToolManager = None
|
||||
@@ -101,8 +96,6 @@ class Application:
|
||||
|
||||
query_pool: pool.QueryPool = None
|
||||
|
||||
msg_aggregator: message_aggregator.MessageAggregator = None
|
||||
|
||||
ctrl: controller.Controller = None
|
||||
|
||||
pipeline_mgr: pipelinemgr.PipelineManager = None
|
||||
@@ -141,6 +134,8 @@ class Application:
|
||||
|
||||
knowledge_service: knowledge_service.KnowledgeService = None
|
||||
|
||||
external_kb_service: external_kb_service.ExternalKBService = None
|
||||
|
||||
mcp_service: mcp_service.MCPService = None
|
||||
|
||||
apikey_service: apikey_service.ApiKeyService = None
|
||||
@@ -149,8 +144,6 @@ class Application:
|
||||
|
||||
telemetry: telemetry_module.TelemetryManager = None
|
||||
|
||||
survey: survey_module.SurveyManager = None
|
||||
|
||||
monitoring_service: monitoring_service.MonitoringService = None
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import importlib.util
|
||||
import pip
|
||||
import os
|
||||
from ...utils import pkgmgr
|
||||
@@ -50,10 +49,9 @@ async def check_deps() -> list[str]:
|
||||
|
||||
missing_deps = []
|
||||
for dep in required_deps:
|
||||
# Use find_spec instead of __import__ to avoid actually loading
|
||||
# all modules into memory. find_spec only checks if the module
|
||||
# can be found, without executing module-level code.
|
||||
if importlib.util.find_spec(dep) is None:
|
||||
try:
|
||||
__import__(dep)
|
||||
except ImportError:
|
||||
missing_deps.append(dep)
|
||||
return missing_deps
|
||||
|
||||
|
||||
@@ -5,14 +5,12 @@ import asyncio
|
||||
from .. import stage, app
|
||||
from ...utils import version, proxy
|
||||
from ...pipeline import pool, controller, pipelinemgr
|
||||
from ...pipeline import aggregator as message_aggregator
|
||||
from ...plugin import connector as plugin_connector
|
||||
from ...command import cmdmgr
|
||||
from ...provider.session import sessionmgr as llm_session_mgr
|
||||
from ...provider.modelmgr import modelmgr as llm_model_mgr
|
||||
from ...provider.tools import toolmgr as llm_tool_mgr
|
||||
from ...rag.knowledge import kbmgr as rag_mgr
|
||||
from ...rag.service import RAGRuntimeService
|
||||
from ...platform import botmgr as im_mgr
|
||||
from ...platform.webhook_pusher import WebhookPusher
|
||||
from ...persistence import mgr as persistencemgr
|
||||
@@ -27,6 +25,7 @@ from ...api.http.service import knowledge as knowledge_service
|
||||
from ...api.http.service import mcp as mcp_service
|
||||
from ...api.http.service import apikey as apikey_service
|
||||
from ...api.http.service import webhook as webhook_service
|
||||
from ...api.http.service import external_kb as external_kb_service
|
||||
from ...api.http.service import monitoring as monitoring_service
|
||||
from ...discover import engine as discover_engine
|
||||
from ...storage import mgr as storagemgr
|
||||
@@ -34,7 +33,6 @@ from ...utils import logcache
|
||||
from ...vector import mgr as vectordb_mgr
|
||||
from .. import taskmgr
|
||||
from ...telemetry import telemetry as telemetry_module
|
||||
from ...survey import manager as survey_module
|
||||
|
||||
|
||||
@stage.stage_class('BuildAppStage')
|
||||
@@ -73,6 +71,9 @@ class BuildAppStage(stage.BootingStage):
|
||||
knowledge_service_inst = knowledge_service.KnowledgeService(ap)
|
||||
ap.knowledge_service = knowledge_service_inst
|
||||
|
||||
external_kb_service_inst = external_kb_service.ExternalKBService(ap)
|
||||
ap.external_kb_service = external_kb_service_inst
|
||||
|
||||
mcp_service_inst = mcp_service.MCPService(ap)
|
||||
ap.mcp_service = mcp_service_inst
|
||||
|
||||
@@ -108,11 +109,6 @@ class BuildAppStage(stage.BootingStage):
|
||||
await telemetry_inst.initialize()
|
||||
ap.telemetry = telemetry_inst
|
||||
|
||||
# Survey manager
|
||||
survey_inst = survey_module.SurveyManager(ap)
|
||||
await survey_inst.initialize()
|
||||
ap.survey = survey_inst
|
||||
|
||||
cmd_mgr_inst = cmdmgr.CommandManager(ap)
|
||||
await cmd_mgr_inst.initialize()
|
||||
ap.cmd_mgr = cmd_mgr_inst
|
||||
@@ -141,17 +137,10 @@ class BuildAppStage(stage.BootingStage):
|
||||
await pipeline_mgr.initialize()
|
||||
ap.pipeline_mgr = pipeline_mgr
|
||||
|
||||
# Initialize message aggregator (after pipeline_mgr, as it needs pipeline config)
|
||||
msg_aggregator_inst = message_aggregator.MessageAggregator(ap)
|
||||
ap.msg_aggregator = msg_aggregator_inst
|
||||
|
||||
rag_mgr_inst = rag_mgr.RAGManager(ap)
|
||||
await rag_mgr_inst.initialize()
|
||||
ap.rag_mgr = rag_mgr_inst
|
||||
|
||||
# Initialize RAG Runtime Service for plugins
|
||||
ap.rag_runtime_service = RAGRuntimeService(ap)
|
||||
|
||||
# 初始化向量数据库管理器
|
||||
vectordb_mgr_inst = vectordb_mgr.VectorDBManager(ap)
|
||||
await vectordb_mgr_inst.initialize()
|
||||
|
||||
@@ -20,10 +20,8 @@ class MonitoringMessage(Base):
|
||||
level = sqlalchemy.Column(sqlalchemy.String(50), nullable=False) # info, warning, error, debug
|
||||
platform = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
user_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
user_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) # User display name
|
||||
runner_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) # Runner name for this query
|
||||
variables = sqlalchemy.Column(sqlalchemy.Text, nullable=True) # Query variables as JSON string
|
||||
role = sqlalchemy.Column(sqlalchemy.String(50), nullable=True, default='user') # user, assistant
|
||||
|
||||
|
||||
class MonitoringLLMCall(Base):
|
||||
@@ -65,7 +63,6 @@ class MonitoringSession(Base):
|
||||
is_active = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=True, index=True)
|
||||
platform = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
user_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
user_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) # User display name
|
||||
|
||||
|
||||
class MonitoringError(Base):
|
||||
|
||||
@@ -10,21 +10,8 @@ class KnowledgeBase(Base):
|
||||
emoji = sqlalchemy.Column(sqlalchemy.String(10), nullable=True, default='📚')
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, default=sqlalchemy.func.now())
|
||||
updated_at = sqlalchemy.Column(sqlalchemy.DateTime, default=sqlalchemy.func.now(), onupdate=sqlalchemy.func.now())
|
||||
# New fields for plugin-based RAG
|
||||
knowledge_engine_plugin_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
||||
collection_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
||||
creation_settings = sqlalchemy.Column(sqlalchemy.JSON, nullable=True, default=None)
|
||||
retrieval_settings = sqlalchemy.Column(sqlalchemy.JSON, nullable=True, default=None)
|
||||
|
||||
# Field sets for different operations
|
||||
MUTABLE_FIELDS = {'name', 'description', 'retrieval_settings'}
|
||||
"""Fields that can be updated after creation."""
|
||||
|
||||
CREATE_FIELDS = MUTABLE_FIELDS | {'uuid', 'knowledge_engine_plugin_id', 'collection_id', 'creation_settings'}
|
||||
"""Fields used when creating a new knowledge base."""
|
||||
|
||||
ALL_DB_FIELDS = CREATE_FIELDS | {'emoji', 'created_at', 'updated_at'}
|
||||
"""All fields stored in database (for loading from DB row)."""
|
||||
embedding_model_uuid = sqlalchemy.Column(sqlalchemy.String, default='')
|
||||
top_k = sqlalchemy.Column(sqlalchemy.Integer, default=5)
|
||||
|
||||
|
||||
class File(Base):
|
||||
@@ -42,3 +29,16 @@ class Chunk(Base):
|
||||
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
|
||||
file_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
text = sqlalchemy.Column(sqlalchemy.Text)
|
||||
|
||||
|
||||
class ExternalKnowledgeBase(Base):
|
||||
__tablename__ = 'external_knowledge_bases'
|
||||
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
|
||||
name = sqlalchemy.Column(sqlalchemy.String, index=True)
|
||||
description = sqlalchemy.Column(sqlalchemy.Text)
|
||||
emoji = sqlalchemy.Column(sqlalchemy.String(10), nullable=True, default='🔗')
|
||||
plugin_author = sqlalchemy.Column(sqlalchemy.String, nullable=False)
|
||||
plugin_name = sqlalchemy.Column(sqlalchemy.String, nullable=False)
|
||||
retriever_name = sqlalchemy.Column(sqlalchemy.String, nullable=False)
|
||||
retriever_config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, default=sqlalchemy.func.now())
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
import sqlalchemy
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class(19)
|
||||
class DBMigrateMonitoringMessageRole(migration.DBMigration):
|
||||
"""Add role column to monitoring_messages table"""
|
||||
|
||||
async def upgrade(self):
|
||||
"""Upgrade"""
|
||||
try:
|
||||
sql_text = sqlalchemy.text("ALTER TABLE monitoring_messages ADD COLUMN role VARCHAR(50) DEFAULT 'user'")
|
||||
await self.ap.persistence_mgr.execute_async(sql_text)
|
||||
except Exception:
|
||||
# Column may already exist
|
||||
pass
|
||||
|
||||
async def downgrade(self):
|
||||
"""Downgrade"""
|
||||
try:
|
||||
sql_text = sqlalchemy.text('ALTER TABLE monitoring_messages DROP COLUMN role')
|
||||
await self.ap.persistence_mgr.execute_async(sql_text)
|
||||
except Exception:
|
||||
pass
|
||||
@@ -1,161 +0,0 @@
|
||||
import sqlalchemy
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class(20)
|
||||
class DBMigrateKnowledgeEnginePluginArchitecture(migration.DBMigration):
|
||||
"""Migrate to unified Knowledge Engine plugin architecture.
|
||||
|
||||
Changes:
|
||||
- Backup existing knowledge_bases data to knowledge_bases_backup
|
||||
- Clear knowledge_bases table and add new plugin architecture columns
|
||||
- Drop old columns (PostgreSQL only; SQLite leaves them unmapped)
|
||||
- Preserve external_knowledge_bases table as-is for future migration
|
||||
- Set rag_plugin_migration_needed flag in metadata if old data exists
|
||||
"""
|
||||
|
||||
async def upgrade(self):
|
||||
"""Upgrade"""
|
||||
has_internal_data = await self._backup_knowledge_bases()
|
||||
has_external_data = await self._check_external_knowledge_bases()
|
||||
await self._clear_knowledge_bases()
|
||||
await self._add_columns_to_knowledge_bases()
|
||||
await self._drop_old_columns()
|
||||
if has_internal_data or has_external_data:
|
||||
await self._set_migration_flag()
|
||||
|
||||
async def _get_table_columns(self, table_name: str) -> list[str]:
|
||||
"""Get column names from a table (works for both SQLite and PostgreSQL)."""
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'SELECT column_name FROM information_schema.columns WHERE table_name = :table_name;'
|
||||
).bindparams(table_name=table_name)
|
||||
)
|
||||
return [row[0] for row in result.fetchall()]
|
||||
else:
|
||||
# SQLite PRAGMA does not support bind parameters; validate identifier.
|
||||
if not table_name.isidentifier():
|
||||
raise ValueError(f'Invalid table name: {table_name}')
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.text(f'PRAGMA table_info({table_name});'))
|
||||
return [row[1] for row in result.fetchall()]
|
||||
|
||||
async def _table_exists(self, table_name: str) -> bool:
|
||||
"""Check if a table exists."""
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = :table_name);'
|
||||
).bindparams(table_name=table_name)
|
||||
)
|
||||
return result.scalar()
|
||||
else:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text("SELECT name FROM sqlite_master WHERE type='table' AND name=:table_name;").bindparams(
|
||||
table_name=table_name
|
||||
)
|
||||
)
|
||||
return result.first() is not None
|
||||
|
||||
async def _backup_knowledge_bases(self) -> bool:
|
||||
"""Backup knowledge_bases data. Returns True if data was backed up."""
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.text('SELECT COUNT(*) FROM knowledge_bases;'))
|
||||
count = result.scalar()
|
||||
if count == 0:
|
||||
return False
|
||||
|
||||
# Drop backup table if it already exists (from a previous failed migration)
|
||||
if await self._table_exists('knowledge_bases_backup'):
|
||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.text('DROP TABLE knowledge_bases_backup;'))
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('CREATE TABLE knowledge_bases_backup AS SELECT * FROM knowledge_bases;')
|
||||
)
|
||||
self.ap.logger.info(
|
||||
'Backed up %d knowledge base(s) to knowledge_bases_backup table.',
|
||||
count,
|
||||
)
|
||||
return True
|
||||
|
||||
async def _check_external_knowledge_bases(self) -> bool:
|
||||
"""Check if external_knowledge_bases table exists and has data.
|
||||
|
||||
The table is preserved as-is (not dropped) for future migration.
|
||||
"""
|
||||
if not await self._table_exists('external_knowledge_bases'):
|
||||
return False
|
||||
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('SELECT COUNT(*) FROM external_knowledge_bases;')
|
||||
)
|
||||
count = result.scalar()
|
||||
if count > 0:
|
||||
self.ap.logger.info(
|
||||
'Found %d external knowledge base(s) in external_knowledge_bases table. '
|
||||
'Table preserved for future migration.',
|
||||
count,
|
||||
)
|
||||
return count > 0
|
||||
|
||||
async def _clear_knowledge_bases(self):
|
||||
"""Clear all rows from knowledge_bases table (preserve table structure)."""
|
||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.text('DELETE FROM knowledge_bases;'))
|
||||
|
||||
async def _add_columns_to_knowledge_bases(self):
|
||||
"""Add new RAG plugin architecture columns to knowledge_bases table."""
|
||||
columns = await self._get_table_columns('knowledge_bases')
|
||||
|
||||
new_columns = {
|
||||
'knowledge_engine_plugin_id': 'VARCHAR',
|
||||
'collection_id': 'VARCHAR',
|
||||
'creation_settings': 'TEXT', # JSON stored as TEXT for SQLite compatibility
|
||||
'retrieval_settings': 'TEXT',
|
||||
}
|
||||
|
||||
for col_name, col_type in new_columns.items():
|
||||
if col_name not in columns:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(f'ALTER TABLE knowledge_bases ADD COLUMN {col_name} {col_type};')
|
||||
)
|
||||
|
||||
async def _drop_old_columns(self):
|
||||
"""Drop embedding_model_uuid and top_k columns (PostgreSQL only).
|
||||
|
||||
SQLite does not support DROP COLUMN in older versions, so we leave the
|
||||
columns in place — the SQLAlchemy entity simply won't map them.
|
||||
"""
|
||||
if self.ap.persistence_mgr.db.name != 'postgresql':
|
||||
return
|
||||
|
||||
columns = await self._get_table_columns('knowledge_bases')
|
||||
|
||||
if 'embedding_model_uuid' in columns:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('ALTER TABLE knowledge_bases DROP COLUMN embedding_model_uuid;')
|
||||
)
|
||||
|
||||
if 'top_k' in columns:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('ALTER TABLE knowledge_bases DROP COLUMN top_k;')
|
||||
)
|
||||
|
||||
async def _set_migration_flag(self):
|
||||
"""Set rag_plugin_migration_needed flag in metadata table."""
|
||||
# Check if the key already exists
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text("SELECT value FROM metadata WHERE key = 'rag_plugin_migration_needed';")
|
||||
)
|
||||
row = result.first()
|
||||
if row is not None:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text("UPDATE metadata SET value = 'true' WHERE key = 'rag_plugin_migration_needed';")
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text("INSERT INTO metadata (key, value) VALUES ('rag_plugin_migration_needed', 'true');")
|
||||
)
|
||||
self.ap.logger.info('Set rag_plugin_migration_needed=true in metadata.')
|
||||
|
||||
async def downgrade(self):
|
||||
"""Downgrade"""
|
||||
pass
|
||||
@@ -1,74 +0,0 @@
|
||||
from .. import migration
|
||||
|
||||
import sqlalchemy
|
||||
import json
|
||||
|
||||
|
||||
@migration.migration_class(21)
|
||||
class DBMigrateMergeExceptionHandling(migration.DBMigration):
|
||||
"""Merge hide-exception and block-failed-request-output into a single exception-handling select option,
|
||||
and add failure-hint field.
|
||||
|
||||
Conversion logic:
|
||||
- block-failed-request-output=true -> exception-handling: hide
|
||||
- hide-exception=true -> exception-handling: show-hint
|
||||
- hide-exception=false -> exception-handling: show-error
|
||||
"""
|
||||
|
||||
async def upgrade(self):
|
||||
"""Upgrade"""
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('SELECT uuid, config FROM legacy_pipelines')
|
||||
)
|
||||
pipelines = result.fetchall()
|
||||
|
||||
current_version = self.ap.ver_mgr.get_current_version()
|
||||
|
||||
for pipeline_row in pipelines:
|
||||
uuid = pipeline_row[0]
|
||||
config = json.loads(pipeline_row[1]) if isinstance(pipeline_row[1], str) else pipeline_row[1]
|
||||
|
||||
if 'output' not in config:
|
||||
config['output'] = {}
|
||||
if 'misc' not in config['output']:
|
||||
config['output']['misc'] = {}
|
||||
|
||||
misc = config['output']['misc']
|
||||
|
||||
# Determine new exception-handling value from legacy fields
|
||||
hide_exception = misc.get('hide-exception', True)
|
||||
block_failed = misc.get('block-failed-request-output', False)
|
||||
|
||||
if block_failed:
|
||||
exception_handling = 'hide'
|
||||
elif hide_exception:
|
||||
exception_handling = 'show-hint'
|
||||
else:
|
||||
exception_handling = 'show-error'
|
||||
|
||||
misc['exception-handling'] = exception_handling
|
||||
|
||||
# Add failure-hint with default value
|
||||
misc['failure-hint'] = 'Request failed.'
|
||||
|
||||
# Remove legacy fields
|
||||
misc.pop('hide-exception', None)
|
||||
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET config = :config::jsonb, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': json.dumps(config), 'for_version': current_version, 'uuid': uuid},
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET config = :config, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': json.dumps(config), 'for_version': current_version, 'uuid': uuid},
|
||||
)
|
||||
|
||||
async def downgrade(self):
|
||||
"""Downgrade"""
|
||||
pass
|
||||
@@ -1,73 +0,0 @@
|
||||
import sqlalchemy
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class(22)
|
||||
class DBMigrateMonitoringUserId(migration.DBMigration):
|
||||
"""Add user_id and user_name columns to monitoring_sessions table
|
||||
|
||||
This migration adds the missing user_id column and also ensures user_name
|
||||
column exists (in case migration 21 failed or was skipped).
|
||||
"""
|
||||
|
||||
async def _table_exists(self, table_name: str) -> bool:
|
||||
"""Check if a table exists (works for both SQLite and PostgreSQL)."""
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = :table_name);'
|
||||
).bindparams(table_name=table_name)
|
||||
)
|
||||
return bool(result.scalar())
|
||||
else:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text("SELECT name FROM sqlite_master WHERE type='table' AND name=:table_name;").bindparams(
|
||||
table_name=table_name
|
||||
)
|
||||
)
|
||||
return result.first() is not None
|
||||
|
||||
async def _get_table_columns(self, table_name: str) -> list[str]:
|
||||
"""Get column names from a table (works for both SQLite and PostgreSQL)."""
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'SELECT column_name FROM information_schema.columns WHERE table_name = :table_name;'
|
||||
).bindparams(table_name=table_name)
|
||||
)
|
||||
return [row[0] for row in result.fetchall()]
|
||||
else:
|
||||
if not table_name.isidentifier():
|
||||
raise ValueError(f'Invalid table name: {table_name}')
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.text(f'PRAGMA table_info({table_name});'))
|
||||
return [row[1] for row in result.fetchall()]
|
||||
|
||||
async def _add_column_if_not_exists(self, table_name: str, column_name: str, column_type: str):
|
||||
"""Add a column to a table if it does not already exist."""
|
||||
columns = await self._get_table_columns(table_name)
|
||||
if column_name in columns:
|
||||
self.ap.logger.debug('%s column already exists in %s.', column_name, table_name)
|
||||
return
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(f'ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type};')
|
||||
)
|
||||
self.ap.logger.info('Added %s column to %s table.', column_name, table_name)
|
||||
|
||||
async def upgrade(self):
|
||||
# Check if monitoring_sessions table exists
|
||||
if not await self._table_exists('monitoring_sessions'):
|
||||
self.ap.logger.warning('monitoring_sessions table does not exist, skipping migration.')
|
||||
return
|
||||
|
||||
# Add user_id column to monitoring_sessions table
|
||||
await self._add_column_if_not_exists('monitoring_sessions', 'user_id', 'VARCHAR(255)')
|
||||
|
||||
# Add user_name column to monitoring_sessions table (in case migration 21 failed)
|
||||
await self._add_column_if_not_exists('monitoring_sessions', 'user_name', 'VARCHAR(255)')
|
||||
|
||||
# Add user_name column to monitoring_messages table (in case migration 21 failed)
|
||||
if await self._table_exists('monitoring_messages'):
|
||||
await self._add_column_if_not_exists('monitoring_messages', 'user_name', 'VARCHAR(255)')
|
||||
|
||||
async def downgrade(self):
|
||||
pass
|
||||
@@ -1,102 +0,0 @@
|
||||
from .. import migration
|
||||
|
||||
import sqlalchemy
|
||||
import json
|
||||
|
||||
|
||||
@migration.migration_class(23)
|
||||
class DBMigrateModelFallbackConfig(migration.DBMigration):
|
||||
"""Convert model field from plain UUID string to object with primary/fallbacks"""
|
||||
|
||||
async def upgrade(self):
|
||||
"""Upgrade"""
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('SELECT uuid, config FROM legacy_pipelines')
|
||||
)
|
||||
pipelines = result.fetchall()
|
||||
|
||||
current_version = self.ap.ver_mgr.get_current_version()
|
||||
|
||||
for pipeline_row in pipelines:
|
||||
uuid = pipeline_row[0]
|
||||
config = json.loads(pipeline_row[1]) if isinstance(pipeline_row[1], str) else pipeline_row[1]
|
||||
|
||||
if 'ai' not in config or 'local-agent' not in config['ai']:
|
||||
continue
|
||||
|
||||
local_agent = config['ai']['local-agent']
|
||||
changed = False
|
||||
|
||||
# Convert model from string to object
|
||||
model_value = local_agent.get('model', '')
|
||||
if isinstance(model_value, str):
|
||||
local_agent['model'] = {
|
||||
'primary': model_value,
|
||||
'fallbacks': [],
|
||||
}
|
||||
changed = True
|
||||
|
||||
# Remove leftover fallback-models field if present
|
||||
if 'fallback-models' in local_agent:
|
||||
del local_agent['fallback-models']
|
||||
changed = True
|
||||
|
||||
if not changed:
|
||||
continue
|
||||
|
||||
# Update using raw SQL with compatibility for both SQLite and PostgreSQL
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET config = :config::jsonb, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': json.dumps(config), 'for_version': current_version, 'uuid': uuid},
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET config = :config, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': json.dumps(config), 'for_version': current_version, 'uuid': uuid},
|
||||
)
|
||||
|
||||
async def downgrade(self):
|
||||
"""Downgrade"""
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('SELECT uuid, config FROM legacy_pipelines')
|
||||
)
|
||||
pipelines = result.fetchall()
|
||||
|
||||
current_version = self.ap.ver_mgr.get_current_version()
|
||||
|
||||
for pipeline_row in pipelines:
|
||||
uuid = pipeline_row[0]
|
||||
config = json.loads(pipeline_row[1]) if isinstance(pipeline_row[1], str) else pipeline_row[1]
|
||||
|
||||
if 'ai' not in config or 'local-agent' not in config['ai']:
|
||||
continue
|
||||
|
||||
local_agent = config['ai']['local-agent']
|
||||
|
||||
# Convert model from object back to string
|
||||
model_value = local_agent.get('model', '')
|
||||
if isinstance(model_value, dict):
|
||||
local_agent['model'] = model_value.get('primary', '')
|
||||
else:
|
||||
continue
|
||||
|
||||
# Update using raw SQL with compatibility for both SQLite and PostgreSQL
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET config = :config::jsonb, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': json.dumps(config), 'for_version': current_version, 'uuid': uuid},
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
'UPDATE legacy_pipelines SET config = :config, for_version = :for_version WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': json.dumps(config), 'for_version': current_version, 'uuid': uuid},
|
||||
)
|
||||
@@ -1,49 +0,0 @@
|
||||
from .. import migration
|
||||
|
||||
import sqlalchemy
|
||||
import json
|
||||
|
||||
|
||||
@migration.migration_class(24)
|
||||
class DBMigrateWecomBotWebSocketMode(migration.DBMigration):
|
||||
"""Add enable-webhook field to existing wecombot adapter configs.
|
||||
|
||||
Existing wecombot bots were all using webhook mode, so we set
|
||||
enable-webhook=true to preserve their behavior after the new
|
||||
WebSocket long connection mode is introduced as default.
|
||||
"""
|
||||
|
||||
async def upgrade(self):
|
||||
"""Upgrade"""
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text("SELECT uuid, adapter_config FROM bots WHERE adapter = 'wecombot'")
|
||||
)
|
||||
bots = result.fetchall()
|
||||
|
||||
for bot_row in bots:
|
||||
bot_uuid = bot_row[0]
|
||||
adapter_config = json.loads(bot_row[1]) if isinstance(bot_row[1], str) else bot_row[1]
|
||||
|
||||
if 'enable-webhook' in adapter_config:
|
||||
continue
|
||||
|
||||
# Determine mode based on existing config: if webhook fields are present, keep webhook mode
|
||||
has_webhook_config = bool(
|
||||
adapter_config.get('Token') and adapter_config.get('EncodingAESKey') and adapter_config.get('Corpid')
|
||||
)
|
||||
adapter_config['enable-webhook'] = has_webhook_config
|
||||
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('UPDATE bots SET adapter_config = :config::jsonb WHERE uuid = :uuid'),
|
||||
{'config': json.dumps(adapter_config), 'uuid': bot_uuid},
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('UPDATE bots SET adapter_config = :config WHERE uuid = :uuid'),
|
||||
{'config': json.dumps(adapter_config), 'uuid': bot_uuid},
|
||||
)
|
||||
|
||||
async def downgrade(self):
|
||||
"""Downgrade"""
|
||||
pass
|
||||
@@ -1,289 +0,0 @@
|
||||
"""Message Aggregator Module
|
||||
|
||||
This module provides message aggregation/debounce functionality.
|
||||
When users send multiple messages consecutively, the aggregator will wait
|
||||
for a configurable delay period and merge them into a single message
|
||||
before processing.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import typing
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ..core import app
|
||||
|
||||
# Maximum number of messages to buffer before forcing a flush
|
||||
MAX_BUFFER_MESSAGES = 10
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingMessage:
|
||||
"""A pending message waiting to be aggregated"""
|
||||
|
||||
bot_uuid: str
|
||||
launcher_type: provider_session.LauncherTypes
|
||||
launcher_id: typing.Union[int, str]
|
||||
sender_id: typing.Union[int, str]
|
||||
message_event: platform_events.MessageEvent
|
||||
message_chain: platform_message.MessageChain
|
||||
adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter
|
||||
pipeline_uuid: typing.Optional[str]
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionBuffer:
|
||||
"""Buffer for a single session's pending messages"""
|
||||
|
||||
session_id: str
|
||||
messages: list[PendingMessage] = field(default_factory=list)
|
||||
timer_task: typing.Optional[asyncio.Task] = None
|
||||
last_message_time: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
class MessageAggregator:
|
||||
"""Message aggregator that buffers and merges consecutive messages
|
||||
|
||||
This class implements a debounce mechanism for incoming messages.
|
||||
When a message arrives, it starts a timer. If more messages arrive
|
||||
before the timer expires, they are buffered. When the timer expires,
|
||||
all buffered messages are merged and sent to the query pool.
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
buffers: dict[str, SessionBuffer]
|
||||
"""Session ID -> SessionBuffer mapping"""
|
||||
|
||||
lock: asyncio.Lock
|
||||
"""Lock for thread-safe buffer operations"""
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.buffers = {}
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
def _get_session_id(
|
||||
self,
|
||||
bot_uuid: str,
|
||||
launcher_type: provider_session.LauncherTypes,
|
||||
launcher_id: typing.Union[int, str],
|
||||
) -> str:
|
||||
"""Generate a unique session ID"""
|
||||
return f'{bot_uuid}:{launcher_type.value}:{launcher_id}'
|
||||
|
||||
async def _get_aggregation_config(self, pipeline_uuid: typing.Optional[str]) -> tuple[bool, float]:
|
||||
"""Get aggregation configuration for a pipeline
|
||||
|
||||
Returns:
|
||||
tuple: (enabled, delay_seconds)
|
||||
"""
|
||||
default_enabled = False
|
||||
default_delay = 1.5
|
||||
|
||||
if pipeline_uuid is None:
|
||||
return default_enabled, default_delay
|
||||
|
||||
# Get pipeline from pipeline manager
|
||||
pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(pipeline_uuid)
|
||||
if pipeline is None:
|
||||
return default_enabled, default_delay
|
||||
|
||||
config = pipeline.pipeline_entity.config or {}
|
||||
trigger_config = config.get('trigger', {})
|
||||
aggregation_config = trigger_config.get('message-aggregation', {})
|
||||
|
||||
enabled = aggregation_config.get('enabled', default_enabled)
|
||||
|
||||
delay_raw = aggregation_config.get('delay', default_delay)
|
||||
try:
|
||||
delay = float(delay_raw)
|
||||
except (TypeError, ValueError):
|
||||
delay = default_delay
|
||||
|
||||
# Clamp delay to valid range
|
||||
delay = max(1.0, min(10.0, delay))
|
||||
|
||||
return enabled, delay
|
||||
|
||||
async def add_message(
|
||||
self,
|
||||
bot_uuid: str,
|
||||
launcher_type: provider_session.LauncherTypes,
|
||||
launcher_id: typing.Union[int, str],
|
||||
sender_id: typing.Union[int, str],
|
||||
message_event: platform_events.MessageEvent,
|
||||
message_chain: platform_message.MessageChain,
|
||||
adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
|
||||
pipeline_uuid: typing.Optional[str] = None,
|
||||
) -> None:
|
||||
"""Add a message to the aggregation buffer
|
||||
|
||||
If aggregation is disabled for the pipeline, the message is sent
|
||||
directly to the query pool. Otherwise, it's buffered and will be
|
||||
merged with other messages from the same session.
|
||||
"""
|
||||
enabled, delay = await self._get_aggregation_config(pipeline_uuid)
|
||||
|
||||
if not enabled:
|
||||
# Aggregation disabled, send directly to query pool
|
||||
await self.ap.query_pool.add_query(
|
||||
bot_uuid=bot_uuid,
|
||||
launcher_type=launcher_type,
|
||||
launcher_id=launcher_id,
|
||||
sender_id=sender_id,
|
||||
message_event=message_event,
|
||||
message_chain=message_chain,
|
||||
adapter=adapter,
|
||||
pipeline_uuid=pipeline_uuid,
|
||||
)
|
||||
return
|
||||
|
||||
session_id = self._get_session_id(bot_uuid, launcher_type, launcher_id)
|
||||
|
||||
pending_msg = PendingMessage(
|
||||
bot_uuid=bot_uuid,
|
||||
launcher_type=launcher_type,
|
||||
launcher_id=launcher_id,
|
||||
sender_id=sender_id,
|
||||
message_event=message_event,
|
||||
message_chain=message_chain,
|
||||
adapter=adapter,
|
||||
pipeline_uuid=pipeline_uuid,
|
||||
)
|
||||
|
||||
force_flush = False
|
||||
async with self.lock:
|
||||
if session_id in self.buffers:
|
||||
buffer = self.buffers[session_id]
|
||||
# Cancel existing timer (just cancel, don't await inside lock)
|
||||
if buffer.timer_task and not buffer.timer_task.done():
|
||||
buffer.timer_task.cancel()
|
||||
buffer.messages.append(pending_msg)
|
||||
else:
|
||||
buffer = SessionBuffer(
|
||||
session_id=session_id,
|
||||
messages=[pending_msg],
|
||||
)
|
||||
self.buffers[session_id] = buffer
|
||||
|
||||
buffer.last_message_time = time.time()
|
||||
|
||||
# Check if buffer reached max capacity
|
||||
if len(buffer.messages) >= MAX_BUFFER_MESSAGES:
|
||||
force_flush = True
|
||||
else:
|
||||
# Start new timer
|
||||
buffer.timer_task = asyncio.create_task(self._delayed_flush(session_id, delay))
|
||||
|
||||
if force_flush:
|
||||
await self._flush_buffer(session_id)
|
||||
|
||||
async def _delayed_flush(self, session_id: str, delay: float) -> None:
|
||||
"""Wait for delay then flush the buffer"""
|
||||
try:
|
||||
await asyncio.sleep(delay)
|
||||
await self._flush_buffer(session_id)
|
||||
except asyncio.CancelledError:
|
||||
# Timer was cancelled, new message arrived
|
||||
pass
|
||||
|
||||
async def _flush_buffer(self, session_id: str) -> None:
|
||||
"""Flush the buffer for a session, merging all messages"""
|
||||
async with self.lock:
|
||||
buffer = self.buffers.pop(session_id, None)
|
||||
|
||||
if buffer is None or not buffer.messages:
|
||||
return
|
||||
|
||||
if len(buffer.messages) == 1:
|
||||
# Only one message, no need to merge
|
||||
msg = buffer.messages[0]
|
||||
await self.ap.query_pool.add_query(
|
||||
bot_uuid=msg.bot_uuid,
|
||||
launcher_type=msg.launcher_type,
|
||||
launcher_id=msg.launcher_id,
|
||||
sender_id=msg.sender_id,
|
||||
message_event=msg.message_event,
|
||||
message_chain=msg.message_chain,
|
||||
adapter=msg.adapter,
|
||||
pipeline_uuid=msg.pipeline_uuid,
|
||||
)
|
||||
return
|
||||
|
||||
# Merge multiple messages
|
||||
merged_msg = self._merge_messages(buffer.messages)
|
||||
await self.ap.query_pool.add_query(
|
||||
bot_uuid=merged_msg.bot_uuid,
|
||||
launcher_type=merged_msg.launcher_type,
|
||||
launcher_id=merged_msg.launcher_id,
|
||||
sender_id=merged_msg.sender_id,
|
||||
message_event=merged_msg.message_event,
|
||||
message_chain=merged_msg.message_chain,
|
||||
adapter=merged_msg.adapter,
|
||||
pipeline_uuid=merged_msg.pipeline_uuid,
|
||||
)
|
||||
|
||||
def _merge_messages(self, messages: list[PendingMessage]) -> PendingMessage:
|
||||
"""Merge multiple messages into one
|
||||
|
||||
The merged message uses the first message as base and combines
|
||||
all message chains with newline separators.
|
||||
The original message_event is kept unmodified to preserve
|
||||
message metadata (message_id, etc.) for reply/quote.
|
||||
"""
|
||||
if len(messages) == 1:
|
||||
return messages[0]
|
||||
|
||||
base_msg = messages[0]
|
||||
|
||||
# Build merged message chain
|
||||
merged_chain = platform_message.MessageChain([])
|
||||
|
||||
for i, msg in enumerate(messages):
|
||||
if i > 0:
|
||||
# Add newline separator between messages
|
||||
merged_chain.append(platform_message.Plain(text='\n'))
|
||||
|
||||
# Copy all components from this message
|
||||
for component in msg.message_chain:
|
||||
merged_chain.append(component)
|
||||
|
||||
# Keep message_event unmodified (preserves original message_id and
|
||||
# metadata for reply/quote), only pass merged chain separately
|
||||
return PendingMessage(
|
||||
bot_uuid=base_msg.bot_uuid,
|
||||
launcher_type=base_msg.launcher_type,
|
||||
launcher_id=base_msg.launcher_id,
|
||||
sender_id=base_msg.sender_id,
|
||||
message_event=base_msg.message_event,
|
||||
message_chain=merged_chain,
|
||||
adapter=base_msg.adapter,
|
||||
pipeline_uuid=base_msg.pipeline_uuid,
|
||||
)
|
||||
|
||||
async def flush_all(self) -> None:
|
||||
"""Flush all pending buffers immediately
|
||||
|
||||
This is useful during shutdown to ensure no messages are lost.
|
||||
"""
|
||||
# Snapshot session IDs and cancel all timers under lock
|
||||
async with self.lock:
|
||||
session_ids = list(self.buffers.keys())
|
||||
for sid in session_ids:
|
||||
buffer = self.buffers.get(sid)
|
||||
if buffer and buffer.timer_task and not buffer.timer_task.done():
|
||||
buffer.timer_task.cancel()
|
||||
|
||||
# Flush each buffer outside the lock
|
||||
for session_id in session_ids:
|
||||
await self._flush_buffer(session_id)
|
||||
@@ -1,9 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .. import entities
|
||||
from .. import filter as filter_model
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
from langbot.pkg.utils import httpclient
|
||||
|
||||
BAIDU_EXAMINE_URL = 'https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}'
|
||||
BAIDU_EXAMINE_TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token'
|
||||
@@ -14,50 +15,50 @@ class BaiduCloudExamine(filter_model.ContentFilter):
|
||||
"""百度云内容审核"""
|
||||
|
||||
async def _get_token(self) -> str:
|
||||
session = httpclient.get_session()
|
||||
async with session.post(
|
||||
BAIDU_EXAMINE_TOKEN_URL,
|
||||
params={
|
||||
'grant_type': 'client_credentials',
|
||||
'client_id': self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-key'],
|
||||
'client_secret': self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-secret'],
|
||||
},
|
||||
) as resp:
|
||||
return (await resp.json())['access_token']
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
BAIDU_EXAMINE_TOKEN_URL,
|
||||
params={
|
||||
'grant_type': 'client_credentials',
|
||||
'client_id': self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-key'],
|
||||
'client_secret': self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-secret'],
|
||||
},
|
||||
) as resp:
|
||||
return (await resp.json())['access_token']
|
||||
|
||||
async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
|
||||
session = httpclient.get_session()
|
||||
async with session.post(
|
||||
BAIDU_EXAMINE_URL.format(await self._get_token()),
|
||||
headers={
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
'Accept': 'application/json',
|
||||
},
|
||||
data=f'text={message}'.encode('utf-8'),
|
||||
) as resp:
|
||||
result = await resp.json()
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
BAIDU_EXAMINE_URL.format(await self._get_token()),
|
||||
headers={
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
'Accept': 'application/json',
|
||||
},
|
||||
data=f'text={message}'.encode('utf-8'),
|
||||
) as resp:
|
||||
result = await resp.json()
|
||||
|
||||
if 'error_code' in result:
|
||||
return entities.FilterResult(
|
||||
level=entities.ResultLevel.BLOCK,
|
||||
replacement=message,
|
||||
user_notice='',
|
||||
console_notice=f'百度云判定出错,错误信息:{result["error_msg"]}',
|
||||
)
|
||||
else:
|
||||
conclusion = result['conclusion']
|
||||
|
||||
if conclusion in ('合规'):
|
||||
return entities.FilterResult(
|
||||
level=entities.ResultLevel.PASS,
|
||||
replacement=message,
|
||||
user_notice='',
|
||||
console_notice=f'百度云判定结果:{conclusion}',
|
||||
)
|
||||
else:
|
||||
if 'error_code' in result:
|
||||
return entities.FilterResult(
|
||||
level=entities.ResultLevel.BLOCK,
|
||||
replacement=message,
|
||||
user_notice='消息中存在不合适的内容, 请修改',
|
||||
console_notice=f'百度云判定结果:{conclusion}',
|
||||
user_notice='',
|
||||
console_notice=f'百度云判定出错,错误信息:{result["error_msg"]}',
|
||||
)
|
||||
else:
|
||||
conclusion = result['conclusion']
|
||||
|
||||
if conclusion in ('合规'):
|
||||
return entities.FilterResult(
|
||||
level=entities.ResultLevel.PASS,
|
||||
replacement=message,
|
||||
user_notice='',
|
||||
console_notice=f'百度云判定结果:{conclusion}',
|
||||
)
|
||||
else:
|
||||
return entities.FilterResult(
|
||||
level=entities.ResultLevel.BLOCK,
|
||||
replacement=message,
|
||||
user_notice='消息中存在不合适的内容, 请修改',
|
||||
console_notice=f'百度云判定结果:{conclusion}',
|
||||
)
|
||||
|
||||
@@ -1,105 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# metadata type -> coercion function
|
||||
_COERCE_MAP = {
|
||||
'integer': lambda v: int(v),
|
||||
'number': lambda v: float(v),
|
||||
'float': lambda v: float(v),
|
||||
}
|
||||
|
||||
|
||||
def _coerce_bool(v):
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
if isinstance(v, str):
|
||||
if v.lower() == 'true':
|
||||
return True
|
||||
if v.lower() == 'false':
|
||||
return False
|
||||
raise ValueError(f'Cannot convert string {v!r} to bool')
|
||||
return bool(v)
|
||||
|
||||
|
||||
def _coerce_value(value, expected_type: str):
|
||||
"""Convert a single value to the expected type.
|
||||
|
||||
Returns the converted value, or the original value if no conversion needed.
|
||||
"""
|
||||
if value is None:
|
||||
return value
|
||||
|
||||
if expected_type == 'boolean':
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
return _coerce_bool(value)
|
||||
|
||||
coerce_fn = _COERCE_MAP.get(expected_type)
|
||||
if coerce_fn is None:
|
||||
return value
|
||||
|
||||
# Already the correct type
|
||||
if expected_type == 'integer' and isinstance(value, int) and not isinstance(value, bool):
|
||||
return value
|
||||
if expected_type in ('number', 'float') and isinstance(value, (int, float)) and not isinstance(value, bool):
|
||||
return float(value)
|
||||
|
||||
return coerce_fn(value)
|
||||
|
||||
|
||||
def coerce_pipeline_config(
|
||||
config: dict,
|
||||
*metadata_list: dict,
|
||||
) -> None:
|
||||
"""Coerce pipeline config values according to metadata type definitions.
|
||||
|
||||
Walks each metadata dict (trigger, safety, ai, output) and converts
|
||||
config values in-place so that strings coming from the JSON column are
|
||||
cast to their declared types (integer, number/float, boolean).
|
||||
|
||||
Args:
|
||||
config: The pipeline config dict to modify in-place.
|
||||
*metadata_list: Metadata dicts loaded from the YAML templates.
|
||||
"""
|
||||
for meta in metadata_list:
|
||||
section_name = meta.get('name')
|
||||
if not section_name or section_name not in config:
|
||||
continue
|
||||
|
||||
section = config[section_name]
|
||||
if not isinstance(section, dict):
|
||||
continue
|
||||
|
||||
for stage_def in meta.get('stages', []):
|
||||
stage_name = stage_def.get('name')
|
||||
if not stage_name or stage_name not in section:
|
||||
continue
|
||||
|
||||
stage_config = section[stage_name]
|
||||
if not isinstance(stage_config, dict):
|
||||
continue
|
||||
|
||||
for field_def in stage_def.get('config', []):
|
||||
field_name = field_def.get('name')
|
||||
field_type = field_def.get('type')
|
||||
if not field_name or not field_type or field_name not in stage_config:
|
||||
continue
|
||||
|
||||
old_value = stage_config[field_name]
|
||||
try:
|
||||
new_value = _coerce_value(old_value, field_type)
|
||||
if new_value is not old_value:
|
||||
stage_config[field_name] = new_value
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.warning(
|
||||
'Failed to coerce config %s.%s.%s (%r) to %s: %s',
|
||||
section_name,
|
||||
stage_name,
|
||||
field_name,
|
||||
old_value,
|
||||
field_type,
|
||||
e,
|
||||
)
|
||||
@@ -34,15 +34,6 @@ class MonitoringHelper:
|
||||
# Check if session exists, if not, record session start
|
||||
session_id = f'{query.launcher_type}_{query.launcher_id}'
|
||||
|
||||
# Get sender name from message event
|
||||
sender_name = None
|
||||
if hasattr(query, 'message_event'):
|
||||
if hasattr(query.message_event, 'sender'):
|
||||
if hasattr(query.message_event.sender, 'nickname'):
|
||||
sender_name = query.message_event.sender.nickname
|
||||
elif hasattr(query.message_event.sender, 'member_name'):
|
||||
sender_name = query.message_event.sender.member_name
|
||||
|
||||
# Try to record message
|
||||
# Use JSON serialization to preserve message chain structure (including image URLs, etc.)
|
||||
if hasattr(query, 'message_chain') and hasattr(query.message_chain, 'model_dump'):
|
||||
@@ -66,7 +57,6 @@ class MonitoringHelper:
|
||||
if hasattr(query.launcher_type, 'value')
|
||||
else str(query.launcher_type),
|
||||
user_id=query.sender_id,
|
||||
user_name=sender_name,
|
||||
runner_name=runner_name,
|
||||
variables=None, # Will be updated in record_query_success
|
||||
)
|
||||
@@ -90,7 +80,6 @@ class MonitoringHelper:
|
||||
if hasattr(query.launcher_type, 'value')
|
||||
else str(query.launcher_type),
|
||||
user_id=query.sender_id,
|
||||
user_name=sender_name,
|
||||
)
|
||||
|
||||
return message_id
|
||||
@@ -125,70 +114,6 @@ class MonitoringHelper:
|
||||
except Exception as e:
|
||||
ap.logger.error(f'Failed to record query success: {e}')
|
||||
|
||||
@staticmethod
|
||||
async def record_query_response(
|
||||
ap: app.Application,
|
||||
query: pipeline_query.Query,
|
||||
bot_id: str,
|
||||
bot_name: str,
|
||||
pipeline_id: str,
|
||||
pipeline_name: str,
|
||||
runner_name: str | None = None,
|
||||
):
|
||||
"""Record bot response message to monitoring"""
|
||||
try:
|
||||
session_id = f'{query.launcher_type}_{query.launcher_id}'
|
||||
|
||||
# Get sender name from message event
|
||||
sender_name = None
|
||||
if hasattr(query, 'message_event'):
|
||||
if hasattr(query.message_event, 'sender'):
|
||||
if hasattr(query.message_event.sender, 'nickname'):
|
||||
sender_name = query.message_event.sender.nickname
|
||||
elif hasattr(query.message_event.sender, 'member_name'):
|
||||
sender_name = query.message_event.sender.member_name
|
||||
|
||||
# Extract response content from resp_message_chain
|
||||
if hasattr(query, 'resp_message_chain') and query.resp_message_chain:
|
||||
# Serialize the last response message chain
|
||||
last_resp = query.resp_message_chain[-1]
|
||||
if hasattr(last_resp, 'model_dump'):
|
||||
message_content = json.dumps(last_resp.model_dump(), ensure_ascii=False)
|
||||
else:
|
||||
message_content = str(last_resp)
|
||||
elif hasattr(query, 'resp_messages') and query.resp_messages:
|
||||
last_resp = query.resp_messages[-1]
|
||||
if hasattr(last_resp, 'get_content_platform_message_chain'):
|
||||
chain = last_resp.get_content_platform_message_chain()
|
||||
if hasattr(chain, 'model_dump'):
|
||||
message_content = json.dumps(chain.model_dump(), ensure_ascii=False)
|
||||
else:
|
||||
message_content = str(chain)
|
||||
else:
|
||||
message_content = str(last_resp)
|
||||
else:
|
||||
return # No response to record
|
||||
|
||||
await ap.monitoring_service.record_message(
|
||||
bot_id=bot_id,
|
||||
bot_name=bot_name,
|
||||
pipeline_id=pipeline_id,
|
||||
pipeline_name=pipeline_name,
|
||||
message_content=message_content,
|
||||
session_id=session_id,
|
||||
status='success',
|
||||
level='info',
|
||||
platform=query.launcher_type.value
|
||||
if hasattr(query.launcher_type, 'value')
|
||||
else str(query.launcher_type),
|
||||
user_id=query.sender_id,
|
||||
user_name=sender_name,
|
||||
runner_name=runner_name,
|
||||
role='assistant',
|
||||
)
|
||||
except Exception as e:
|
||||
ap.logger.error(f'Failed to record query response: {e}')
|
||||
|
||||
@staticmethod
|
||||
async def record_query_error(
|
||||
ap: app.Application,
|
||||
@@ -204,15 +129,6 @@ class MonitoringHelper:
|
||||
try:
|
||||
session_id = f'{query.launcher_type}_{query.launcher_id}'
|
||||
|
||||
# Get sender name from message event
|
||||
sender_name = None
|
||||
if hasattr(query, 'message_event'):
|
||||
if hasattr(query.message_event, 'sender'):
|
||||
if hasattr(query.message_event.sender, 'nickname'):
|
||||
sender_name = query.message_event.sender.nickname
|
||||
elif hasattr(query.message_event.sender, 'member_name'):
|
||||
sender_name = query.message_event.sender.member_name
|
||||
|
||||
# Record error message
|
||||
message_id = await ap.monitoring_service.record_message(
|
||||
bot_id=bot_id,
|
||||
@@ -227,7 +143,6 @@ class MonitoringHelper:
|
||||
if hasattr(query.launcher_type, 'value')
|
||||
else str(query.launcher_type),
|
||||
user_id=query.sender_id,
|
||||
user_name=sender_name,
|
||||
runner_name=runner_name,
|
||||
)
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.events as events
|
||||
from ..utils import importutil
|
||||
from .config_coercion import coerce_pipeline_config
|
||||
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
@@ -340,20 +339,6 @@ class RuntimePipeline:
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Failed to record query success: {e}')
|
||||
|
||||
# Record bot response message
|
||||
try:
|
||||
await monitoring_helper.MonitoringHelper.record_query_response(
|
||||
ap=self.ap,
|
||||
query=query,
|
||||
bot_id=query.bot_uuid or 'unknown',
|
||||
bot_name=bot_name,
|
||||
pipeline_id=self.pipeline_entity.uuid,
|
||||
pipeline_name=pipeline_name,
|
||||
runner_name=runner_name,
|
||||
)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Failed to record query response: {e}')
|
||||
|
||||
except Exception as e:
|
||||
inst_name = query.current_stage_name if query.current_stage_name else 'unknown'
|
||||
self.ap.logger.error(f'Error processing query {query.query_id} stage={inst_name} : {e}')
|
||||
@@ -384,6 +369,8 @@ class RuntimePipeline:
|
||||
class PipelineManager:
|
||||
"""流水线管理器"""
|
||||
|
||||
# ====== 4.0 ======
|
||||
|
||||
ap: app.Application
|
||||
|
||||
pipelines: list[RuntimePipeline]
|
||||
@@ -421,14 +408,6 @@ class PipelineManager:
|
||||
elif isinstance(pipeline_entity, dict):
|
||||
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity)
|
||||
|
||||
coerce_pipeline_config(
|
||||
pipeline_entity.config,
|
||||
getattr(self.ap, 'pipeline_config_meta_trigger', {'name': 'trigger', 'stages': []}),
|
||||
getattr(self.ap, 'pipeline_config_meta_safety', {'name': 'safety', 'stages': []}),
|
||||
getattr(self.ap, 'pipeline_config_meta_ai', {'name': 'ai', 'stages': []}),
|
||||
getattr(self.ap, 'pipeline_config_meta_output', {'name': 'output', 'stages': []}),
|
||||
)
|
||||
|
||||
# initialize stage containers according to pipeline_entity.stages
|
||||
stage_containers: list[StageInstContainer] = []
|
||||
for stage_name in pipeline_entity.stages:
|
||||
|
||||
@@ -36,36 +36,17 @@ class PreProcessor(stage.PipelineStage):
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
|
||||
# When not local-agent, llm_model is None
|
||||
llm_model = None
|
||||
if selected_runner == 'local-agent':
|
||||
# Read model config — new format is { primary: str, fallbacks: [str] },
|
||||
# but handle legacy plain string for backward compatibility
|
||||
model_config = query.pipeline_config['ai']['local-agent'].get('model', {})
|
||||
if isinstance(model_config, str):
|
||||
# Legacy format: plain UUID string
|
||||
primary_uuid = model_config
|
||||
fallback_uuids = []
|
||||
else:
|
||||
primary_uuid = model_config.get('primary', '')
|
||||
fallback_uuids = model_config.get('fallbacks', [])
|
||||
|
||||
if primary_uuid:
|
||||
try:
|
||||
llm_model = await self.ap.model_mgr.get_model_by_uuid(primary_uuid)
|
||||
except ValueError:
|
||||
self.ap.logger.warning(f'LLM model {primary_uuid} not found or not configured')
|
||||
|
||||
# Resolve fallback model UUIDs
|
||||
if fallback_uuids:
|
||||
valid_fallbacks = []
|
||||
for fb_uuid in fallback_uuids:
|
||||
try:
|
||||
await self.ap.model_mgr.get_model_by_uuid(fb_uuid)
|
||||
valid_fallbacks.append(fb_uuid)
|
||||
except ValueError:
|
||||
self.ap.logger.warning(f'Fallback model {fb_uuid} not found, skipping')
|
||||
if valid_fallbacks:
|
||||
query.variables['_fallback_model_uuids'] = valid_fallbacks
|
||||
try:
|
||||
llm_model = (
|
||||
await self.ap.model_mgr.get_model_by_uuid(query.pipeline_config['ai']['local-agent']['model'])
|
||||
if selected_runner == 'local-agent'
|
||||
else None
|
||||
)
|
||||
except ValueError:
|
||||
self.ap.logger.warning(
|
||||
f'LLM model {query.pipeline_config["ai"]["local-agent"]["model"] + " "}not found or not configured'
|
||||
)
|
||||
llm_model = None
|
||||
|
||||
conversation = await self.ap.sess_mgr.get_conversation(
|
||||
query,
|
||||
@@ -80,28 +61,20 @@ class PreProcessor(stage.PipelineStage):
|
||||
query.prompt = conversation.prompt.copy()
|
||||
query.messages = conversation.messages.copy()
|
||||
|
||||
if selected_runner == 'local-agent':
|
||||
if selected_runner == 'local-agent' and llm_model:
|
||||
query.use_funcs = []
|
||||
if llm_model:
|
||||
query.use_llm_model_uuid = llm_model.model_entity.uuid
|
||||
query.use_llm_model_uuid = llm_model.model_entity.uuid
|
||||
|
||||
if llm_model.model_entity.abilities.__contains__('func_call'):
|
||||
# Get bound plugins and MCP servers for filtering tools
|
||||
bound_plugins = query.variables.get('_pipeline_bound_plugins', None)
|
||||
bound_mcp_servers = query.variables.get('_pipeline_bound_mcp_servers', None)
|
||||
query.use_funcs = await self.ap.tool_mgr.get_all_tools(bound_plugins, bound_mcp_servers)
|
||||
|
||||
self.ap.logger.debug(f'Bound plugins: {bound_plugins}')
|
||||
self.ap.logger.debug(f'Bound MCP servers: {bound_mcp_servers}')
|
||||
self.ap.logger.debug(f'Use funcs: {query.use_funcs}')
|
||||
|
||||
# If primary model doesn't support func_call but fallback models exist,
|
||||
# load tools anyway since fallback models may support them
|
||||
if not query.use_funcs and query.variables.get('_fallback_model_uuids'):
|
||||
if llm_model.model_entity.abilities.__contains__('func_call'):
|
||||
# Get bound plugins and MCP servers for filtering tools
|
||||
bound_plugins = query.variables.get('_pipeline_bound_plugins', None)
|
||||
bound_mcp_servers = query.variables.get('_pipeline_bound_mcp_servers', None)
|
||||
query.use_funcs = await self.ap.tool_mgr.get_all_tools(bound_plugins, bound_mcp_servers)
|
||||
|
||||
self.ap.logger.debug(f'Bound plugins: {bound_plugins}')
|
||||
self.ap.logger.debug(f'Bound MCP servers: {bound_mcp_servers}')
|
||||
self.ap.logger.debug(f'Use funcs: {query.use_funcs}')
|
||||
|
||||
sender_name = ''
|
||||
|
||||
if isinstance(query.message_event, platform_events.GroupMessage):
|
||||
|
||||
@@ -12,7 +12,7 @@ from ... import entities
|
||||
from ....provider import runner as runner_module
|
||||
|
||||
import langbot_plugin.api.entities.events as events
|
||||
from ....utils import importutil, constants, runner as runner_utils
|
||||
from ....utils import importutil, constants
|
||||
from ....provider import runners
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
@@ -149,19 +149,12 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
self.ap.logger.error(f'Conversation({query.query_id}) Request Failed: {error_info}')
|
||||
traceback.print_exc()
|
||||
|
||||
exception_handling = query.pipeline_config['output']['misc'].get('exception-handling', 'show-hint')
|
||||
|
||||
if exception_handling == 'show-error':
|
||||
user_notice = f'{e}'
|
||||
elif exception_handling == 'show-hint':
|
||||
user_notice = query.pipeline_config['output']['misc'].get('failure-hint', 'Request failed.')
|
||||
else: # hide
|
||||
user_notice = None
|
||||
hide_exception_info = query.pipeline_config['output']['misc']['hide-exception']
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query,
|
||||
user_notice=user_notice,
|
||||
user_notice='请求失败' if hide_exception_info else f'{e}',
|
||||
error_notice=f'{e}',
|
||||
debug_notice=traceback.format_exc(),
|
||||
)
|
||||
@@ -192,15 +185,10 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
|
||||
pipeline_plugins = query.variables.get('_pipeline_bound_plugins', None)
|
||||
|
||||
runner_category = runner_utils.get_runner_category_from_runner(
|
||||
runner_name, runner, query.pipeline_config
|
||||
)
|
||||
|
||||
payload = {
|
||||
'query_id': query.query_id,
|
||||
'adapter': adapter_name,
|
||||
'runner': runner_name,
|
||||
'runner_category': runner_category,
|
||||
'duration_ms': duration_ms,
|
||||
'model_name': model_name,
|
||||
'version': constants.semantic_version,
|
||||
@@ -212,11 +200,6 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
|
||||
# Send telemetry asynchronously and do not block pipeline via app's telemetry manager
|
||||
await self.ap.telemetry.start_send_task(payload)
|
||||
|
||||
# Trigger survey event on first successful non-WebSocket response
|
||||
if not locals().get('error_info') and adapter_name and 'WebSocket' not in adapter_name:
|
||||
if self.ap.survey:
|
||||
await self.ap.survey.trigger_event('first_bot_response_success')
|
||||
except Exception as ex:
|
||||
# Ensure telemetry issues do not affect normal flow
|
||||
self.ap.logger.warning(f'Failed to send telemetry: {ex}')
|
||||
|
||||
@@ -18,6 +18,7 @@ import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||
import langbot_plugin.api.entities.events as events
|
||||
|
||||
|
||||
class RuntimeBot:
|
||||
@@ -82,7 +83,7 @@ class RuntimeBot:
|
||||
if custom_launcher_id:
|
||||
launcher_id = custom_launcher_id
|
||||
|
||||
await self.ap.msg_aggregator.add_message(
|
||||
await self.ap.query_pool.add_query(
|
||||
bot_uuid=self.bot_entity.uuid,
|
||||
launcher_type=provider_session.LauncherTypes.PERSON,
|
||||
launcher_id=launcher_id,
|
||||
@@ -125,7 +126,7 @@ class RuntimeBot:
|
||||
if custom_launcher_id:
|
||||
launcher_id = custom_launcher_id
|
||||
|
||||
await self.ap.msg_aggregator.add_message(
|
||||
await self.ap.query_pool.add_query(
|
||||
bot_uuid=self.bot_entity.uuid,
|
||||
launcher_type=provider_session.LauncherTypes.GROUP,
|
||||
launcher_id=launcher_id,
|
||||
@@ -141,6 +142,56 @@ class RuntimeBot:
|
||||
self.adapter.register_listener(platform_events.FriendMessage, on_friend_message)
|
||||
self.adapter.register_listener(platform_events.GroupMessage, on_group_message)
|
||||
|
||||
async def on_notice(
|
||||
event: platform_events.NoticeEvent,
|
||||
adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
|
||||
):
|
||||
await self.logger.info(f'Notice event: {event.notice_type} {event.sub_type}')
|
||||
|
||||
try:
|
||||
event_obj = events.NoticeReceived(
|
||||
notice_type=event.notice_type,
|
||||
sub_type=event.sub_type,
|
||||
group_id=event.group_id,
|
||||
user_id=event.user_id,
|
||||
operator_id=event.operator_id,
|
||||
target_id=event.target_id,
|
||||
message_id=event.message_id,
|
||||
duration=event.duration,
|
||||
file=event.file,
|
||||
honor_type=event.honor_type,
|
||||
)
|
||||
|
||||
if hasattr(self.ap, 'plugin_connector') and self.ap.plugin_connector:
|
||||
await self.ap.plugin_connector.emit_event(event_obj)
|
||||
except Exception:
|
||||
await self.logger.error(f'Error emitting notice event: {traceback.format_exc()}')
|
||||
|
||||
self.adapter.register_listener(platform_events.NoticeEvent, on_notice)
|
||||
|
||||
async def on_request(
|
||||
event: platform_events.RequestEvent,
|
||||
adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
|
||||
):
|
||||
await self.logger.info(f'Request event: {event.request_type} {event.sub_type}')
|
||||
|
||||
try:
|
||||
event_obj = events.RequestReceived(
|
||||
request_type=event.request_type,
|
||||
sub_type=event.sub_type,
|
||||
user_id=event.user_id,
|
||||
group_id=event.group_id,
|
||||
comment=event.comment,
|
||||
flag=event.flag,
|
||||
)
|
||||
|
||||
if hasattr(self.ap, 'plugin_connector') and self.ap.plugin_connector:
|
||||
await self.ap.plugin_connector.emit_event(event_obj)
|
||||
except Exception:
|
||||
await self.logger.error(f'Error emitting request event: {traceback.format_exc()}')
|
||||
|
||||
self.adapter.register_listener(platform_events.RequestEvent, on_request)
|
||||
|
||||
async def run(self):
|
||||
async def exception_wrapper():
|
||||
try:
|
||||
@@ -282,8 +333,6 @@ class PlatformManager:
|
||||
return runtime_bot
|
||||
|
||||
async def get_bot_by_uuid(self, bot_uuid: str) -> RuntimeBot | None:
|
||||
if self.websocket_proxy_bot and self.websocket_proxy_bot.bot_entity.uuid == bot_uuid:
|
||||
return self.websocket_proxy_bot
|
||||
for bot in self.bots:
|
||||
if bot.bot_entity.uuid == bot_uuid:
|
||||
return bot
|
||||
|
||||
@@ -306,9 +306,8 @@ class AiocqhttpEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
|
||||
@staticmethod
|
||||
async def target2yiri(event: aiocqhttp.Event, bot=None):
|
||||
yiri_chain = await AiocqhttpMessageConverter.target2yiri(event.message, event.message_id, bot)
|
||||
|
||||
if event.message_type == 'group':
|
||||
yiri_chain = await AiocqhttpMessageConverter.target2yiri(event.message, event.message_id, bot)
|
||||
permission = 'MEMBER'
|
||||
|
||||
if 'role' in event.sender:
|
||||
@@ -334,6 +333,7 @@ class AiocqhttpEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
)
|
||||
return converted_event
|
||||
elif event.message_type == 'private':
|
||||
yiri_chain = await AiocqhttpMessageConverter.target2yiri(event.message, event.message_id, bot)
|
||||
return platform_events.FriendMessage(
|
||||
sender=platform_entities.Friend(
|
||||
id=event.sender['user_id'],
|
||||
@@ -344,6 +344,57 @@ class AiocqhttpEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
time=event.time,
|
||||
source_platform_object=event,
|
||||
)
|
||||
elif event.post_type == 'notice':
|
||||
yiri_chain = platform_message.MessageChain(
|
||||
[
|
||||
platform_message.Source(id=-1, time=datetime.datetime.now()),
|
||||
platform_message.Notice(
|
||||
notice_type=event.get('notice_type', ''),
|
||||
sub_type=event.get('sub_type', ''),
|
||||
user_id=event.get('user_id', None),
|
||||
target_id=event.get('target_id', None),
|
||||
group_id=event.get('group_id', None),
|
||||
operator_id=event.get('operator_id', None),
|
||||
message_id=event.get('message_id', None),
|
||||
duration=event.get('duration', None),
|
||||
file=event.get('file', None),
|
||||
honor_type=event.get('honor_type', None),
|
||||
),
|
||||
]
|
||||
)
|
||||
return platform_events.NoticeEvent(
|
||||
notice_type=event.get('notice_type', ''),
|
||||
sub_type=event.get('sub_type', ''),
|
||||
user_id=event.get('user_id', None),
|
||||
target_id=event.get('target_id', None),
|
||||
group_id=event.get('group_id', None),
|
||||
time=event.time,
|
||||
source_platform_object=event,
|
||||
)
|
||||
elif event.post_type == 'request':
|
||||
yiri_chain = platform_message.MessageChain(
|
||||
[
|
||||
platform_message.Source(id=-1, time=datetime.datetime.now()),
|
||||
platform_message.Request(
|
||||
request_type=event.get('request_type', ''),
|
||||
sub_type=event.get('sub_type', ''),
|
||||
user_id=event.get('user_id', None),
|
||||
group_id=event.get('group_id', None),
|
||||
comment=event.get('comment', ''),
|
||||
flag=event.get('flag', ''),
|
||||
),
|
||||
]
|
||||
)
|
||||
return platform_events.RequestEvent(
|
||||
request_type=event.get('request_type', ''),
|
||||
sub_type=event.get('sub_type', ''),
|
||||
user_id=event.get('user_id', None),
|
||||
group_id=event.get('group_id', None),
|
||||
comment=event.get('comment', ''),
|
||||
flag=event.get('flag', ''),
|
||||
time=event.time,
|
||||
source_platform_object=event,
|
||||
)
|
||||
|
||||
|
||||
class AiocqhttpAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
@@ -375,18 +426,6 @@ class AiocqhttpAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
self.bot = aiocqhttp.CQHttp()
|
||||
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
# Check if message contains a Forward component
|
||||
forward_msg = message.get_first(platform_message.Forward)
|
||||
if forward_msg:
|
||||
if target_type == 'group':
|
||||
# Send as merged forward message via OneBot API
|
||||
await self._send_forward_message(int(target_id), forward_msg)
|
||||
return
|
||||
else:
|
||||
await self.logger.warning(
|
||||
f'Forward message is only supported for group targets, got target_type={target_type}. Falling through to normal send.'
|
||||
)
|
||||
|
||||
aiocq_msg = (await AiocqhttpMessageConverter.yiri2target(message))[0]
|
||||
|
||||
if target_type == 'group':
|
||||
@@ -394,90 +433,6 @@ class AiocqhttpAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
elif target_type == 'person':
|
||||
await self.bot.send_private_msg(user_id=int(target_id), message=aiocq_msg)
|
||||
|
||||
async def _send_forward_message(self, group_id: int, forward: platform_message.Forward):
|
||||
"""Send a merged forward message to a group using NapCat extended API."""
|
||||
messages = []
|
||||
|
||||
for node in forward.node_list:
|
||||
# Build content for each node
|
||||
content = []
|
||||
if node.message_chain:
|
||||
for component in node.message_chain:
|
||||
if isinstance(component, platform_message.Plain):
|
||||
if component.text:
|
||||
content.append({'type': 'text', 'data': {'text': component.text}})
|
||||
elif isinstance(component, platform_message.Image):
|
||||
img_data = {}
|
||||
if component.base64:
|
||||
b64 = component.base64
|
||||
if b64.startswith('data:'):
|
||||
b64 = b64.split(',', 1)[-1] if ',' in b64 else b64
|
||||
img_data['file'] = f'base64://{b64}'
|
||||
elif component.url:
|
||||
img_data['file'] = component.url
|
||||
elif component.path:
|
||||
img_data['file'] = str(component.path)
|
||||
|
||||
if img_data:
|
||||
content.append({'type': 'image', 'data': img_data})
|
||||
|
||||
if not content:
|
||||
continue
|
||||
|
||||
# Build node data - use user_id and nickname format for NapCat
|
||||
user_id = str(node.sender_id) if node.sender_id else str(self.bot_account_id or '10000')
|
||||
node_data = {
|
||||
'type': 'node',
|
||||
'data': {
|
||||
'user_id': user_id,
|
||||
'nickname': node.sender_name or '未知',
|
||||
'content': content,
|
||||
},
|
||||
}
|
||||
|
||||
messages.append(node_data)
|
||||
|
||||
if not messages:
|
||||
return
|
||||
|
||||
# Build the full message payload for NapCat's send_forward_msg API
|
||||
# This matches the format used by GiveMeSetuPlugin
|
||||
bot_id = str(self.bot_account_id) if self.bot_account_id else '10000'
|
||||
payload = {
|
||||
'group_id': group_id,
|
||||
'user_id': bot_id, # Required by NapCat for display
|
||||
'messages': messages,
|
||||
}
|
||||
|
||||
# Add display settings if available
|
||||
if forward.display:
|
||||
if forward.display.title:
|
||||
payload['news'] = [{'text': forward.display.title}]
|
||||
if forward.display.brief:
|
||||
payload['prompt'] = forward.display.brief
|
||||
if forward.display.summary:
|
||||
payload['summary'] = forward.display.summary
|
||||
if forward.display.source:
|
||||
payload['source'] = forward.display.source
|
||||
|
||||
try:
|
||||
# Use send_forward_msg (NapCat extended API) instead of send_group_forward_msg
|
||||
await self.logger.info(
|
||||
f'Sending forward message to group {group_id} with {len(messages)} nodes, payload keys: {list(payload.keys())}'
|
||||
)
|
||||
result = await self.bot.call_action('send_forward_msg', **payload)
|
||||
await self.logger.info(f'Forward message sent to group {group_id}, result: {result}')
|
||||
except Exception as e:
|
||||
await self.logger.error(f'Failed to send forward message to group {group_id}: {e}')
|
||||
# Fallback: try standard OneBot API with integer group_id
|
||||
try:
|
||||
await self.logger.info('Trying fallback API send_group_forward_msg')
|
||||
await self.bot.call_action('send_group_forward_msg', group_id=group_id, messages=messages)
|
||||
await self.logger.info(f'Forward message sent via fallback API to group {group_id}')
|
||||
except Exception as e2:
|
||||
await self.logger.error(f'Fallback also failed: {e2}')
|
||||
raise
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
@@ -509,12 +464,31 @@ class AiocqhttpAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
await self.logger.error(f'Error in on_message: {traceback.format_exc()}')
|
||||
traceback.print_exc()
|
||||
|
||||
async def on_notice(event: aiocqhttp.Event):
|
||||
self.bot_account_id = event.self_id
|
||||
try:
|
||||
return await callback(await self.event_converter.target2yiri(event, self.bot), self)
|
||||
except Exception:
|
||||
await self.logger.error(f'Error in on_notice: {traceback.format_exc()}')
|
||||
traceback.print_exc()
|
||||
|
||||
async def on_request(event: aiocqhttp.Event):
|
||||
self.bot_account_id = event.self_id
|
||||
try:
|
||||
return await callback(await self.event_converter.target2yiri(event, self.bot), self)
|
||||
except Exception:
|
||||
await self.logger.error(f'Error in on_request: {traceback.format_exc()}')
|
||||
traceback.print_exc()
|
||||
|
||||
if event_type == platform_events.GroupMessage:
|
||||
self.bot.on_message('group')(on_message)
|
||||
# self.bot.on_notice()(on_message)
|
||||
elif event_type == platform_events.FriendMessage:
|
||||
self.bot.on_message('private')(on_message)
|
||||
# self.bot.on_notice()(on_message)
|
||||
elif event_type == platform_events.NoticeEvent:
|
||||
self.bot.on_notice()(on_notice)
|
||||
elif event_type == platform_events.RequestEvent:
|
||||
self.bot.on_request()(on_request)
|
||||
# print(event_type)
|
||||
|
||||
async def on_websocket_connection(event: aiocqhttp.Event):
|
||||
|
||||
@@ -14,7 +14,7 @@ import io
|
||||
import asyncio
|
||||
from enum import Enum
|
||||
|
||||
from langbot.pkg.utils import httpclient
|
||||
import aiohttp
|
||||
import pydantic
|
||||
|
||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||
@@ -622,23 +622,23 @@ class DiscordMessageConverter(abstract_platform_adapter.AbstractMessageConverter
|
||||
image_bytes = base64.b64decode(base64_data)
|
||||
elif ele.url:
|
||||
# 从URL下载图片
|
||||
session = httpclient.get_session()
|
||||
async with session.get(ele.url) as response:
|
||||
image_bytes = await response.read()
|
||||
# 从URL或Content-Type推断文件类型
|
||||
content_type = response.headers.get('Content-Type', '')
|
||||
if 'jpeg' in content_type or 'jpg' in content_type:
|
||||
filename = f'{uuid.uuid4()}.jpg'
|
||||
elif 'gif' in content_type:
|
||||
filename = f'{uuid.uuid4()}.gif'
|
||||
elif 'webp' in content_type:
|
||||
filename = f'{uuid.uuid4()}.webp'
|
||||
elif ele.url.lower().endswith(('.jpg', '.jpeg')):
|
||||
filename = f'{uuid.uuid4()}.jpg'
|
||||
elif ele.url.lower().endswith('.gif'):
|
||||
filename = f'{uuid.uuid4()}.gif'
|
||||
elif ele.url.lower().endswith('.webp'):
|
||||
filename = f'{uuid.uuid4()}.webp'
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(ele.url) as response:
|
||||
image_bytes = await response.read()
|
||||
# 从URL或Content-Type推断文件类型
|
||||
content_type = response.headers.get('Content-Type', '')
|
||||
if 'jpeg' in content_type or 'jpg' in content_type:
|
||||
filename = f'{uuid.uuid4()}.jpg'
|
||||
elif 'gif' in content_type:
|
||||
filename = f'{uuid.uuid4()}.gif'
|
||||
elif 'webp' in content_type:
|
||||
filename = f'{uuid.uuid4()}.webp'
|
||||
elif ele.url.lower().endswith(('.jpg', '.jpeg')):
|
||||
filename = f'{uuid.uuid4()}.jpg'
|
||||
elif ele.url.lower().endswith('.gif'):
|
||||
filename = f'{uuid.uuid4()}.gif'
|
||||
elif ele.url.lower().endswith('.webp'):
|
||||
filename = f'{uuid.uuid4()}.webp'
|
||||
elif ele.path:
|
||||
# 从文件路径读取图片
|
||||
# 确保路径没有空字节
|
||||
@@ -702,9 +702,9 @@ class DiscordMessageConverter(abstract_platform_adapter.AbstractMessageConverter
|
||||
file_base64 = ele.base64.split(',')[-1]
|
||||
file_bytes = base64.b64decode(file_base64)
|
||||
elif ele.url:
|
||||
session = httpclient.get_session()
|
||||
async with session.get(ele.url) as response:
|
||||
file_bytes = await response.read()
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(ele.url) as response:
|
||||
file_bytes = await response.read()
|
||||
if file_bytes:
|
||||
files.append(discord.File(fp=io.BytesIO(file_bytes), filename=filename))
|
||||
elif isinstance(ele, platform_message.File):
|
||||
@@ -717,9 +717,9 @@ class DiscordMessageConverter(abstract_platform_adapter.AbstractMessageConverter
|
||||
else:
|
||||
file_bytes = base64.b64decode(ele.base64)
|
||||
elif ele.url:
|
||||
session = httpclient.get_session()
|
||||
async with session.get(ele.url) as response:
|
||||
file_bytes = await response.read()
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(ele.url) as response:
|
||||
file_bytes = await response.read()
|
||||
if file_bytes:
|
||||
files.append(discord.File(fp=io.BytesIO(file_bytes), filename=filename))
|
||||
elif isinstance(ele, platform_message.Forward):
|
||||
@@ -775,12 +775,12 @@ class DiscordMessageConverter(abstract_platform_adapter.AbstractMessageConverter
|
||||
|
||||
# attachments
|
||||
for attachment in message.attachments:
|
||||
session = httpclient.get_session(trust_env=True)
|
||||
async with session.get(attachment.url) as response:
|
||||
image_data = await response.read()
|
||||
image_base64 = base64.b64encode(image_data).decode('utf-8')
|
||||
image_format = response.headers['Content-Type']
|
||||
element_list.append(platform_message.Image(base64=f'data:{image_format};base64,{image_base64}'))
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(attachment.url) as response:
|
||||
image_data = await response.read()
|
||||
image_base64 = base64.b64encode(image_data).decode('utf-8')
|
||||
image_format = response.headers['Content-Type']
|
||||
element_list.append(platform_message.Image(base64=f'data:{image_format};base64,{image_base64}'))
|
||||
|
||||
return platform_message.MessageChain(element_list)
|
||||
|
||||
|
||||
@@ -9,8 +9,6 @@ import traceback
|
||||
import time
|
||||
|
||||
import aiohttp
|
||||
|
||||
from langbot.pkg.utils import httpclient
|
||||
import websockets
|
||||
import pydantic
|
||||
|
||||
@@ -122,16 +120,16 @@ class KookMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
if content:
|
||||
# Download image and convert to base64
|
||||
try:
|
||||
session = httpclient.get_session()
|
||||
async with session.get(content) as response:
|
||||
if response.status == 200:
|
||||
image_bytes = await response.read()
|
||||
image_base64 = base64.b64encode(image_bytes).decode('utf-8')
|
||||
# Detect image format
|
||||
content_type = response.headers.get('Content-Type', 'image/png')
|
||||
components.append(
|
||||
platform_message.Image(base64=f'data:{content_type};base64,{image_base64}')
|
||||
)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(content) as response:
|
||||
if response.status == 200:
|
||||
image_bytes = await response.read()
|
||||
image_base64 = base64.b64encode(image_bytes).decode('utf-8')
|
||||
# Detect image format
|
||||
content_type = response.headers.get('Content-Type', 'image/png')
|
||||
components.append(
|
||||
platform_message.Image(base64=f'data:{content_type};base64,{image_base64}')
|
||||
)
|
||||
except Exception:
|
||||
# If download fails, just add as plain text
|
||||
components.append(platform_message.Plain(text=f'[Image: {content}]'))
|
||||
@@ -297,17 +295,17 @@ class KookAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
'Authorization': f'Bot {self.config["token"]}',
|
||||
}
|
||||
|
||||
session = httpclient.get_session()
|
||||
async with session.get(base_url, params=params, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
if data.get('code') == 0:
|
||||
gateway_url = data['data']['url']
|
||||
return gateway_url
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(base_url, params=params, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
if data.get('code') == 0:
|
||||
gateway_url = data['data']['url']
|
||||
return gateway_url
|
||||
else:
|
||||
raise Exception(f'Failed to get gateway URL: {data.get("message")}')
|
||||
else:
|
||||
raise Exception(f'Failed to get gateway URL: {data.get("message")}')
|
||||
else:
|
||||
raise Exception(f'Failed to get gateway URL: HTTP {response.status}')
|
||||
raise Exception(f'Failed to get gateway URL: HTTP {response.status}')
|
||||
|
||||
async def _get_bot_user_info(self) -> dict:
|
||||
"""Get bot's own user information from KOOK API"""
|
||||
@@ -317,17 +315,17 @@ class KookAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
'Authorization': f'Bot {self.config["token"]}',
|
||||
}
|
||||
|
||||
session = httpclient.get_session()
|
||||
async with session.get(base_url, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
if data.get('code') == 0:
|
||||
user_info = data['data']
|
||||
return user_info
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(base_url, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
if data.get('code') == 0:
|
||||
user_info = data['data']
|
||||
return user_info
|
||||
else:
|
||||
raise Exception(f'Failed to get bot user info: {data.get("message")}')
|
||||
else:
|
||||
raise Exception(f'Failed to get bot user info: {data.get("message")}')
|
||||
else:
|
||||
raise Exception(f'Failed to get bot user info: HTTP {response.status}')
|
||||
raise Exception(f'Failed to get bot user info: HTTP {response.status}')
|
||||
|
||||
async def _handle_hello(self, data: dict):
|
||||
"""Handle HELLO signal (signal 1)"""
|
||||
@@ -512,7 +510,7 @@ class KookAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
|
||||
try:
|
||||
if not self.http_session:
|
||||
self.http_session = httpclient.get_session()
|
||||
self.http_session = aiohttp.ClientSession()
|
||||
|
||||
async with self.http_session.post(url, json=payload, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
@@ -578,7 +576,7 @@ class KookAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
|
||||
try:
|
||||
if not self.http_session:
|
||||
self.http_session = httpclient.get_session()
|
||||
self.http_session = aiohttp.ClientSession()
|
||||
|
||||
async with self.http_session.post(url, json=payload, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
@@ -626,7 +624,7 @@ class KookAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
|
||||
try:
|
||||
# Create HTTP session
|
||||
self.http_session = httpclient.get_session()
|
||||
self.http_session = aiohttp.ClientSession()
|
||||
|
||||
await self.logger.info('Starting KOOK adapter')
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import lark_oapi
|
||||
from lark_oapi.api.im.v1 import CreateImageRequest, CreateImageRequestBody, CreateFileRequest, CreateFileRequestBody
|
||||
from lark_oapi.api.im.v1 import CreateImageRequest, CreateImageRequestBody
|
||||
import traceback
|
||||
import typing
|
||||
import asyncio
|
||||
@@ -17,7 +17,7 @@ import tempfile
|
||||
import os
|
||||
import mimetypes
|
||||
|
||||
from langbot.pkg.utils import httpclient
|
||||
import aiohttp
|
||||
import lark_oapi.ws.exception
|
||||
import quart
|
||||
from lark_oapi.api.im.v1 import *
|
||||
@@ -78,13 +78,13 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
return None
|
||||
elif msg.url:
|
||||
try:
|
||||
session = httpclient.get_session()
|
||||
async with session.get(msg.url) as response:
|
||||
if response.status == 200:
|
||||
image_bytes = await response.read()
|
||||
else:
|
||||
print(f'Failed to download image from {msg.url}: HTTP {response.status}')
|
||||
return None
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(msg.url) as response:
|
||||
if response.status == 200:
|
||||
image_bytes = await response.read()
|
||||
else:
|
||||
print(f'Failed to download image from {msg.url}: HTTP {response.status}')
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f'Failed to download image from {msg.url}: {e}')
|
||||
traceback.print_exc()
|
||||
@@ -141,88 +141,6 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def upload_file_to_lark(
|
||||
file_bytes: bytes,
|
||||
api_client: lark_oapi.Client,
|
||||
file_type: str,
|
||||
file_name: str = 'file',
|
||||
duration: typing.Optional[int] = None,
|
||||
) -> typing.Optional[str]:
|
||||
"""Upload a file to Lark and return the file_key, or None if upload fails.
|
||||
|
||||
Args:
|
||||
file_bytes: Raw file bytes.
|
||||
api_client: Lark API client.
|
||||
file_type: Lark file type, e.g. 'opus', 'mp4', 'pdf', 'doc', etc.
|
||||
file_name: Display name for the file.
|
||||
duration: Duration in milliseconds (for audio files).
|
||||
"""
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
|
||||
temp_file.write(file_bytes)
|
||||
temp_file_path = temp_file.name
|
||||
|
||||
try:
|
||||
body_builder = (
|
||||
CreateFileRequestBody.builder()
|
||||
.file_type(file_type)
|
||||
.file_name(file_name)
|
||||
.file(open(temp_file_path, 'rb'))
|
||||
)
|
||||
if duration is not None:
|
||||
body_builder = body_builder.duration(duration)
|
||||
|
||||
request = CreateFileRequest.builder().request_body(body_builder.build()).build()
|
||||
|
||||
response = await api_client.im.v1.file.acreate(request)
|
||||
|
||||
if not response.success():
|
||||
print(
|
||||
f'client.im.v1.file.create failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}'
|
||||
)
|
||||
return None
|
||||
|
||||
return response.data.file_key
|
||||
finally:
|
||||
os.unlink(temp_file_path)
|
||||
except Exception as e:
|
||||
print(f'Failed to upload file to Lark: {e}')
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def _get_media_bytes(
|
||||
msg: typing.Union[platform_message.Voice, platform_message.File],
|
||||
) -> typing.Optional[bytes]:
|
||||
"""Get bytes from a Voice or File message (base64, url, or path)."""
|
||||
data = None
|
||||
|
||||
if msg.base64:
|
||||
try:
|
||||
base64_str = msg.base64
|
||||
if ',' in base64_str:
|
||||
base64_str = base64_str.split(',', 1)[1]
|
||||
data = base64.b64decode(base64_str)
|
||||
except Exception:
|
||||
pass
|
||||
elif msg.url:
|
||||
try:
|
||||
session = httpclient.get_session()
|
||||
async with session.get(msg.url) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.read()
|
||||
except Exception:
|
||||
pass
|
||||
elif msg.path:
|
||||
try:
|
||||
with open(msg.path, 'rb') as f:
|
||||
data = f.read()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
async def yiri2target(
|
||||
message_chain: platform_message.MessageChain, api_client: lark_oapi.Client
|
||||
@@ -232,10 +150,10 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
Returns:
|
||||
Tuple of (text_elements, image_keys):
|
||||
- text_elements: List of paragraphs for post message format
|
||||
- media_items: List of dicts with 'msg_type' and 'content' for separate media messages
|
||||
- image_keys: List of image_key strings for separate image messages
|
||||
"""
|
||||
message_elements = []
|
||||
media_items = []
|
||||
image_keys = []
|
||||
pending_paragraph = []
|
||||
|
||||
# Regex pattern to match Markdown image syntax: 
|
||||
@@ -278,77 +196,40 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
# Check for and extract Markdown images from text
|
||||
cleaned_text, extracted_urls = await process_text_with_images(text)
|
||||
|
||||
# Split by blank lines to create separate paragraphs for Lark post format.
|
||||
# Lark truncates md elements at the first \n\n, so we must use the
|
||||
# post format's native paragraph structure instead.
|
||||
# Add cleaned text if not empty
|
||||
if cleaned_text:
|
||||
segments = re.split(r'\n\s*\n', cleaned_text)
|
||||
for i, segment in enumerate(segments):
|
||||
segment = segment.strip()
|
||||
if not segment:
|
||||
continue
|
||||
if i > 0 and pending_paragraph:
|
||||
message_elements.append(pending_paragraph)
|
||||
pending_paragraph = []
|
||||
pending_paragraph.append({'tag': 'md', 'text': segment})
|
||||
pending_paragraph.append({'tag': 'md', 'text': cleaned_text})
|
||||
|
||||
# Process extracted image URLs
|
||||
for url in extracted_urls:
|
||||
# Create a temporary Image message to upload
|
||||
temp_image = platform_message.Image(url=url)
|
||||
image_key = await LarkMessageConverter.upload_image_to_lark(temp_image, api_client)
|
||||
if image_key:
|
||||
media_items.append({'msg_type': 'image', 'content': {'image_key': image_key}})
|
||||
image_keys.append(image_key)
|
||||
|
||||
elif isinstance(msg, platform_message.At):
|
||||
pending_paragraph.append({'tag': 'at', 'user_id': msg.target, 'style': []})
|
||||
elif isinstance(msg, platform_message.AtAll):
|
||||
pending_paragraph.append({'tag': 'at', 'user_id': 'all', 'style': []})
|
||||
elif isinstance(msg, platform_message.Image):
|
||||
# Upload image and get image_key
|
||||
image_key = await LarkMessageConverter.upload_image_to_lark(msg, api_client)
|
||||
if image_key:
|
||||
media_items.append({'msg_type': 'image', 'content': {'image_key': image_key}})
|
||||
elif isinstance(msg, platform_message.Voice):
|
||||
data = await LarkMessageConverter._get_media_bytes(msg)
|
||||
if data:
|
||||
duration = int(msg.length * 1000) if msg.length else None
|
||||
file_key = await LarkMessageConverter.upload_file_to_lark(
|
||||
data, api_client, file_type='opus', file_name='voice.opus', duration=duration
|
||||
)
|
||||
if file_key:
|
||||
media_items.append({'msg_type': 'audio', 'content': {'file_key': file_key}})
|
||||
elif isinstance(msg, platform_message.File):
|
||||
data = await LarkMessageConverter._get_media_bytes(msg)
|
||||
if data:
|
||||
file_name = msg.name or 'file'
|
||||
# Guess file_type from extension
|
||||
ext = os.path.splitext(file_name)[1].lstrip('.').lower() if file_name else ''
|
||||
file_type_map = {
|
||||
'opus': 'opus',
|
||||
'mp4': 'mp4',
|
||||
'pdf': 'pdf',
|
||||
'doc': 'doc',
|
||||
'docx': 'doc',
|
||||
'xls': 'xls',
|
||||
'xlsx': 'xls',
|
||||
'ppt': 'ppt',
|
||||
'pptx': 'ppt',
|
||||
}
|
||||
file_type = file_type_map.get(ext, 'stream')
|
||||
file_key = await LarkMessageConverter.upload_file_to_lark(
|
||||
data, api_client, file_type=file_type, file_name=file_name
|
||||
)
|
||||
if file_key:
|
||||
media_items.append({'msg_type': 'file', 'content': {'file_key': file_key}})
|
||||
# Store image_key for separate image message
|
||||
image_keys.append(image_key)
|
||||
elif isinstance(msg, platform_message.Forward):
|
||||
for node in msg.node_list:
|
||||
sub_elements, sub_media = await LarkMessageConverter.yiri2target(node.message_chain, api_client)
|
||||
sub_elements, sub_image_keys = await LarkMessageConverter.yiri2target(
|
||||
node.message_chain, api_client
|
||||
)
|
||||
message_elements.extend(sub_elements)
|
||||
media_items.extend(sub_media)
|
||||
image_keys.extend(sub_image_keys)
|
||||
|
||||
if pending_paragraph:
|
||||
message_elements.append(pending_paragraph)
|
||||
|
||||
return message_elements, media_items
|
||||
return message_elements, image_keys
|
||||
|
||||
@staticmethod
|
||||
async def target2yiri(
|
||||
@@ -575,127 +456,6 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
|
||||
|
||||
class LarkEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
_processed_thread_quote_cache: typing.ClassVar[dict[str, float]] = {}
|
||||
_processed_thread_quote_cache_max_size: typing.ClassVar[int] = 4096
|
||||
_processed_thread_quote_cache_ttl_seconds: typing.ClassVar[int] = 86400
|
||||
|
||||
@classmethod
|
||||
def _prune_processed_thread_quote_cache(cls, now: typing.Optional[float] = None) -> None:
|
||||
if now is None:
|
||||
now = time.time()
|
||||
|
||||
expire_before = now - cls._processed_thread_quote_cache_ttl_seconds
|
||||
while cls._processed_thread_quote_cache:
|
||||
oldest_key, oldest_ts = next(iter(cls._processed_thread_quote_cache.items()))
|
||||
if oldest_ts >= expire_before:
|
||||
break
|
||||
cls._processed_thread_quote_cache.pop(oldest_key, None)
|
||||
|
||||
while len(cls._processed_thread_quote_cache) > cls._processed_thread_quote_cache_max_size:
|
||||
oldest_key = next(iter(cls._processed_thread_quote_cache))
|
||||
cls._processed_thread_quote_cache.pop(oldest_key, None)
|
||||
|
||||
@classmethod
|
||||
def _mark_thread_quote_processed(cls, thread_id: str) -> None:
|
||||
now = time.time()
|
||||
cls._prune_processed_thread_quote_cache(now)
|
||||
cls._processed_thread_quote_cache[thread_id] = now
|
||||
|
||||
@classmethod
|
||||
def _extract_quote_message_id(cls, message: EventMessage) -> typing.Optional[str]:
|
||||
"""
|
||||
Extract the message ID to quote from the given message.
|
||||
|
||||
Rules:
|
||||
- First thread reply in a topic: return parent_id and mark topic as processed
|
||||
- Follow-up thread replies in the same topic: return None
|
||||
- Non-thread message: return parent_id if valid (non-empty, different from message_id)
|
||||
|
||||
Thread reply state is kept in a bounded TTL cache to avoid unbounded memory growth.
|
||||
"""
|
||||
parent_id = getattr(message, 'parent_id', None)
|
||||
if not parent_id:
|
||||
return None
|
||||
|
||||
message_id = getattr(message, 'message_id', None)
|
||||
if parent_id == message_id:
|
||||
return None
|
||||
|
||||
thread_id = getattr(message, 'thread_id', None)
|
||||
if thread_id:
|
||||
cls._prune_processed_thread_quote_cache()
|
||||
if thread_id in cls._processed_thread_quote_cache:
|
||||
return None
|
||||
cls._mark_thread_quote_processed(thread_id)
|
||||
|
||||
return parent_id
|
||||
|
||||
@staticmethod
|
||||
def _build_event_message_from_message_item(message_item: Message) -> typing.Optional[EventMessage]:
|
||||
"""
|
||||
Build EventMessage from SDK typed Message item.
|
||||
|
||||
Returns None if body or content is missing.
|
||||
"""
|
||||
body = getattr(message_item, 'body', None)
|
||||
if not body:
|
||||
return None
|
||||
|
||||
content = getattr(body, 'content', None)
|
||||
if not content:
|
||||
return None
|
||||
|
||||
event_data = {
|
||||
'message_id': message_item.message_id,
|
||||
'message_type': message_item.msg_type,
|
||||
'content': content,
|
||||
'create_time': message_item.create_time,
|
||||
'mentions': getattr(message_item, 'mentions', []) or [],
|
||||
}
|
||||
|
||||
# Preserve thread-related fields
|
||||
if hasattr(message_item, 'parent_id') and message_item.parent_id:
|
||||
event_data['parent_id'] = message_item.parent_id
|
||||
if hasattr(message_item, 'root_id') and message_item.root_id:
|
||||
event_data['root_id'] = message_item.root_id
|
||||
if hasattr(message_item, 'thread_id') and message_item.thread_id:
|
||||
event_data['thread_id'] = message_item.thread_id
|
||||
if hasattr(message_item, 'chat_id') and message_item.chat_id:
|
||||
event_data['chat_id'] = message_item.chat_id
|
||||
|
||||
return EventMessage(event_data)
|
||||
|
||||
@staticmethod
|
||||
async def _fetch_quoted_message(
|
||||
quote_message_id: str,
|
||||
api_client: lark_oapi.Client,
|
||||
) -> typing.Optional[platform_message.MessageChain]:
|
||||
"""
|
||||
Fetch the quoted message and convert to MessageChain.
|
||||
|
||||
Returns None if:
|
||||
- API call fails
|
||||
- Response items is empty
|
||||
- Message item normalization fails
|
||||
"""
|
||||
request = GetMessageRequest.builder().message_id(quote_message_id).build()
|
||||
response = await api_client.im.v1.message.aget(request)
|
||||
|
||||
if not response.success():
|
||||
return None
|
||||
|
||||
items = getattr(response.data, 'items', None)
|
||||
if not items:
|
||||
return None
|
||||
|
||||
message_item = items[0]
|
||||
event_message = LarkEventConverter._build_event_message_from_message_item(message_item)
|
||||
if event_message is None:
|
||||
return None
|
||||
|
||||
quote_chain = await LarkMessageConverter.target2yiri(event_message, api_client)
|
||||
return quote_chain
|
||||
|
||||
@staticmethod
|
||||
async def yiri2target(
|
||||
event: platform_events.MessageEvent,
|
||||
@@ -708,23 +468,6 @@ class LarkEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
) -> platform_events.Event:
|
||||
message_chain = await LarkMessageConverter.target2yiri(event.event.message, api_client)
|
||||
|
||||
# Check for quote/reply message
|
||||
quote_message_id = LarkEventConverter._extract_quote_message_id(event.event.message)
|
||||
if quote_message_id:
|
||||
quote_chain = await LarkEventConverter._fetch_quoted_message(quote_message_id, api_client)
|
||||
if quote_chain:
|
||||
# Filter out Source component from quoted chain, keep only content
|
||||
quote_origin = platform_message.MessageChain(
|
||||
[comp for comp in quote_chain if not isinstance(comp, platform_message.Source)]
|
||||
)
|
||||
if quote_origin:
|
||||
message_chain.append(
|
||||
platform_message.Quote(
|
||||
message_id=quote_message_id,
|
||||
origin=quote_origin,
|
||||
)
|
||||
)
|
||||
|
||||
if event.event.message.chat_type == 'p2p':
|
||||
return platform_events.FriendMessage(
|
||||
sender=platform_entities.Friend(
|
||||
@@ -908,32 +651,6 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
self.request_tenant_access_token(tenant_key)
|
||||
return self.tenant_access_tokens.get(tenant_key)['token'] if self.tenant_access_tokens.get(tenant_key) else None
|
||||
|
||||
def get_launcher_id(self, event: platform_events.MessageEvent) -> str | None:
|
||||
"""
|
||||
Get topic-scoped launcher_id for thread-aware session isolation.
|
||||
|
||||
For group thread messages, returns "{group_id}_{thread_id}"
|
||||
to ensure conversation context stays stable per topic.
|
||||
|
||||
Returns None for non-thread messages or P2P messages.
|
||||
"""
|
||||
source_event = getattr(event.source_platform_object, 'event', None)
|
||||
if not source_event:
|
||||
return None
|
||||
|
||||
message = getattr(source_event, 'message', None)
|
||||
if not message:
|
||||
return None
|
||||
|
||||
thread_id = getattr(message, 'thread_id', None)
|
||||
if not thread_id:
|
||||
return None
|
||||
|
||||
if isinstance(event, platform_events.GroupMessage):
|
||||
return f'{event.group.id}_{thread_id}'
|
||||
|
||||
return None
|
||||
|
||||
def build_api_client(self, config):
|
||||
app_id = config['app_id']
|
||||
app_secret = config['app_secret']
|
||||
@@ -1200,40 +917,23 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
):
|
||||
# 不再需要了,因为message_id已经被包含到message_chain中
|
||||
# lark_event = await self.event_converter.yiri2target(message_source)
|
||||
text_elements, media_items = await self.message_converter.yiri2target(message, self.api_client)
|
||||
text_elements, image_keys = await self.message_converter.yiri2target(message, self.api_client)
|
||||
|
||||
# Send text message if there are text elements
|
||||
if text_elements:
|
||||
# Determine msg_type based on content: use 'post' if at mentions
|
||||
# are present (requires post paragraph structure), otherwise 'text'
|
||||
needs_post = any(ele['tag'] == 'at' for paragraph in text_elements for ele in paragraph)
|
||||
|
||||
if needs_post:
|
||||
msg_type = 'post'
|
||||
final_content = json.dumps(
|
||||
{
|
||||
'zh_Hans': {
|
||||
'title': '',
|
||||
'content': text_elements,
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
msg_type = 'text'
|
||||
parts = []
|
||||
for paragraph in text_elements:
|
||||
para_text = ''.join(ele.get('text', '') for ele in paragraph)
|
||||
if para_text:
|
||||
parts.append(para_text)
|
||||
final_content = json.dumps({'text': '\n\n'.join(parts)})
|
||||
|
||||
final_content = {
|
||||
'zh_Hans': {
|
||||
'title': '',
|
||||
'content': text_elements,
|
||||
},
|
||||
}
|
||||
request: ReplyMessageRequest = (
|
||||
ReplyMessageRequest.builder()
|
||||
.message_id(message_source.message_chain.message_id)
|
||||
.request_body(
|
||||
ReplyMessageRequestBody.builder()
|
||||
.content(final_content)
|
||||
.msg_type(msg_type)
|
||||
.content(json.dumps(final_content))
|
||||
.msg_type('post')
|
||||
.reply_in_thread(False)
|
||||
.uuid(str(uuid.uuid4()))
|
||||
.build()
|
||||
@@ -1263,15 +963,17 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
f'client.im.v1.message.reply failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}'
|
||||
)
|
||||
|
||||
# Send media messages separately (image, audio, file, etc.)
|
||||
for media in media_items:
|
||||
# Send image messages separately using msg_type='image'
|
||||
for image_key in image_keys:
|
||||
image_content = json.dumps({'image_key': image_key})
|
||||
|
||||
request: ReplyMessageRequest = (
|
||||
ReplyMessageRequest.builder()
|
||||
.message_id(message_source.message_chain.message_id)
|
||||
.request_body(
|
||||
ReplyMessageRequestBody.builder()
|
||||
.content(json.dumps(media['content']))
|
||||
.msg_type(media['msg_type'])
|
||||
.content(image_content)
|
||||
.msg_type('image')
|
||||
.reply_in_thread(False)
|
||||
.uuid(str(uuid.uuid4()))
|
||||
.build()
|
||||
@@ -1298,7 +1000,7 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
|
||||
if not response.success():
|
||||
raise Exception(
|
||||
f'client.im.v1.message.reply ({media["msg_type"]}) failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}'
|
||||
f'client.im.v1.message.reply (image) failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}'
|
||||
)
|
||||
|
||||
async def reply_message_chunk(
|
||||
@@ -1316,16 +1018,15 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
message_id = bot_message.resp_message_id
|
||||
msg_seq = bot_message.msg_sequence
|
||||
if msg_seq % 8 == 0 or is_final:
|
||||
text_elements, media_items = await self.message_converter.yiri2target(message, self.api_client)
|
||||
text_elements, image_keys = await self.message_converter.yiri2target(message, self.api_client)
|
||||
|
||||
text_message = ''
|
||||
if text_elements:
|
||||
parts = []
|
||||
for paragraph in text_elements:
|
||||
para_text = ''.join(ele['text'] for ele in paragraph if ele['tag'] in ('text', 'md'))
|
||||
if para_text:
|
||||
parts.append(para_text)
|
||||
text_message = '\n\n'.join(parts)
|
||||
for ele in text_elements[0]:
|
||||
if ele['tag'] == 'text':
|
||||
text_message += ele['text']
|
||||
elif ele['tag'] == 'md':
|
||||
text_message += ele['text']
|
||||
|
||||
# content = {
|
||||
# 'type': 'card_json',
|
||||
@@ -1375,30 +1076,6 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
)
|
||||
return
|
||||
|
||||
# Send media messages when streaming is done
|
||||
if is_final and media_items:
|
||||
for media in media_items:
|
||||
media_request: ReplyMessageRequest = (
|
||||
ReplyMessageRequest.builder()
|
||||
.message_id(message_source.message_chain.message_id)
|
||||
.request_body(
|
||||
ReplyMessageRequestBody.builder()
|
||||
.content(json.dumps(media['content']))
|
||||
.msg_type(media['msg_type'])
|
||||
.reply_in_thread(False)
|
||||
.uuid(str(uuid.uuid4()))
|
||||
.build()
|
||||
)
|
||||
.build()
|
||||
)
|
||||
media_response: ReplyMessageResponse = await self.api_client.im.v1.message.areply(
|
||||
media_request, req_opt
|
||||
)
|
||||
if not media_response.success():
|
||||
raise Exception(
|
||||
f'client.im.v1.message.reply ({media["msg_type"]}) failed, code: {media_response.code}, msg: {media_response.msg}, log_id: {media_response.get_log_id()}'
|
||||
)
|
||||
|
||||
async def is_muted(self, group_id: int) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import copy
|
||||
import threading
|
||||
|
||||
import quart
|
||||
from langbot.pkg.utils import httpclient
|
||||
import aiohttp
|
||||
|
||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||
from ....core import app
|
||||
@@ -639,14 +639,14 @@ class GeWeChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
|
||||
async def run_async(self):
|
||||
if not self.config['token']:
|
||||
session = httpclient.get_session()
|
||||
async with session.post(
|
||||
f'{self.config["gewechat_url"]}/v2/api/tools/getTokenId',
|
||||
json={'app_id': self.config['app_id']},
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(f'获取gewechat token失败: {await response.text()}')
|
||||
self.config['token'] = (await response.json())['data']
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f'{self.config["gewechat_url"]}/v2/api/tools/getTokenId',
|
||||
json={'app_id': self.config['app_id']},
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(f'获取gewechat token失败: {await response.text()}')
|
||||
self.config['token'] = (await response.json())['data']
|
||||
|
||||
self.bot = gewechat_client.GewechatClient(f'{self.config["gewechat_url"]}/v2/api', self.config['token'])
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from __future__ import annotations
|
||||
import time
|
||||
|
||||
|
||||
import telegram
|
||||
@@ -10,9 +9,9 @@ import telegramify_markdown
|
||||
import typing
|
||||
import traceback
|
||||
import base64
|
||||
import aiohttp
|
||||
import pydantic
|
||||
|
||||
from langbot.pkg.utils import httpclient
|
||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
@@ -34,9 +33,9 @@ class TelegramMessageConverter(abstract_platform_adapter.AbstractMessageConverte
|
||||
if component.base64:
|
||||
photo_bytes = base64.b64decode(component.base64)
|
||||
elif component.url:
|
||||
session = httpclient.get_session()
|
||||
async with session.get(component.url) as response:
|
||||
photo_bytes = await response.read()
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(component.url) as response:
|
||||
photo_bytes = await response.read()
|
||||
elif component.path:
|
||||
with open(component.path, 'rb') as f:
|
||||
photo_bytes = f.read()
|
||||
@@ -75,9 +74,10 @@ class TelegramMessageConverter(abstract_platform_adapter.AbstractMessageConverte
|
||||
file_bytes = None
|
||||
file_format = ''
|
||||
|
||||
async with httpclient.get_session(trust_env=True).get(file.file_path) as response:
|
||||
file_bytes = await response.read()
|
||||
file_format = 'image/jpeg'
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(file.file_path) as response:
|
||||
file_bytes = await response.read()
|
||||
file_format = 'image/jpeg'
|
||||
|
||||
message_components.append(
|
||||
platform_message.Image(
|
||||
@@ -94,8 +94,9 @@ class TelegramMessageConverter(abstract_platform_adapter.AbstractMessageConverte
|
||||
file_bytes = None
|
||||
file_format = message.voice.mime_type or 'audio/ogg'
|
||||
|
||||
async with httpclient.get_session(trust_env=True).get(file.file_path) as response:
|
||||
file_bytes = await response.read()
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(file.file_path) as response:
|
||||
file_bytes = await response.read()
|
||||
|
||||
message_components.append(
|
||||
platform_message.Voice(
|
||||
@@ -193,31 +194,7 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
)
|
||||
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
components = await TelegramMessageConverter.yiri2target(message, self.bot)
|
||||
|
||||
chat_id_str, _, thread_id_str = str(target_id).partition('#')
|
||||
chat_id: int | str = int(chat_id_str) if chat_id_str.lstrip('-').isdigit() else chat_id_str
|
||||
message_thread_id = int(thread_id_str) if thread_id_str and thread_id_str.isdigit() else None
|
||||
|
||||
for component in components:
|
||||
component_type = component.get('type')
|
||||
args = {'chat_id': chat_id}
|
||||
if message_thread_id is not None:
|
||||
args['message_thread_id'] = message_thread_id
|
||||
|
||||
if component_type == 'text':
|
||||
text = component.get('text', '')
|
||||
if self.config['markdown_card'] is True:
|
||||
text = telegramify_markdown.markdownify(content=text)
|
||||
args['parse_mode'] = 'MarkdownV2'
|
||||
args['text'] = text
|
||||
await self.bot.send_message(**args)
|
||||
elif component_type == 'photo':
|
||||
photo = component.get('photo')
|
||||
if photo is None:
|
||||
continue
|
||||
args['photo'] = telegram.InputFile(photo)
|
||||
await self.bot.send_photo(**args)
|
||||
pass
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
@@ -251,39 +228,6 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
|
||||
await self.bot.send_message(**args)
|
||||
|
||||
def _process_markdown(self, text: str) -> str:
|
||||
if self.config.get('markdown_card', False):
|
||||
return telegramify_markdown.markdownify(content=text)
|
||||
return text
|
||||
|
||||
def _build_message_args(self, chat_id: int, text: str, message_thread_id: int = None, **extra_args) -> dict:
|
||||
args = {'chat_id': chat_id, 'text': self._process_markdown(text), **extra_args}
|
||||
if message_thread_id:
|
||||
args['message_thread_id'] = message_thread_id
|
||||
if self.config.get('markdown_card', False):
|
||||
args['parse_mode'] = 'MarkdownV2'
|
||||
return args
|
||||
|
||||
async def create_message_card(self, message_id, event):
|
||||
assert isinstance(event.source_platform_object, Update)
|
||||
update = event.source_platform_object
|
||||
chat_id = update.effective_chat.id
|
||||
chat_type = update.effective_chat.type
|
||||
message_thread_id = update.message.message_thread_id
|
||||
|
||||
if chat_type == 'private':
|
||||
draft_id = int(time.time() * 1000)
|
||||
self.msg_stream_id[message_id] = ('private', draft_id)
|
||||
|
||||
args = self._build_message_args(chat_id, 'Thinking...', message_thread_id, draft_id=draft_id)
|
||||
await self.bot.send_message_draft(**args)
|
||||
else:
|
||||
args = self._build_message_args(chat_id, 'Thinking...', message_thread_id)
|
||||
send_msg = await self.bot.send_message(**args)
|
||||
self.msg_stream_id[message_id] = ('group', send_msg.message_id)
|
||||
|
||||
return True
|
||||
|
||||
async def reply_message_chunk(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
@@ -292,47 +236,59 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
quote_origin: bool = False,
|
||||
is_final: bool = False,
|
||||
):
|
||||
message_id = bot_message.resp_message_id
|
||||
msg_seq = bot_message.msg_sequence
|
||||
assert isinstance(message_source.source_platform_object, Update)
|
||||
update = message_source.source_platform_object
|
||||
chat_id = update.effective_chat.id
|
||||
message_thread_id = update.message.message_thread_id
|
||||
if (msg_seq - 1) % 8 == 0 or is_final:
|
||||
assert isinstance(message_source.source_platform_object, Update)
|
||||
components = await TelegramMessageConverter.yiri2target(message, self.bot)
|
||||
args = {}
|
||||
message_id = message_source.source_platform_object.message.id
|
||||
|
||||
if message_id not in self.msg_stream_id:
|
||||
return
|
||||
component = components[0]
|
||||
if message_id not in self.msg_stream_id: # 当消息回复第一次时,发送新消息
|
||||
# time.sleep(0.6)
|
||||
if component['type'] == 'text':
|
||||
if self.config['markdown_card'] is True:
|
||||
content = telegramify_markdown.markdownify(
|
||||
content=component['text'],
|
||||
)
|
||||
else:
|
||||
content = component['text']
|
||||
args = {
|
||||
'chat_id': message_source.source_platform_object.effective_chat.id,
|
||||
'text': content,
|
||||
}
|
||||
if message_source.source_platform_object.message.message_thread_id:
|
||||
args['message_thread_id'] = message_source.source_platform_object.message.message_thread_id
|
||||
|
||||
chat_mode, draft_id = self.msg_stream_id[message_id]
|
||||
components = await TelegramMessageConverter.yiri2target(message, self.bot)
|
||||
if quote_origin:
|
||||
args['reply_to_message_id'] = message_source.source_platform_object.message.id
|
||||
|
||||
if not components or components[0]['type'] != 'text':
|
||||
if is_final and bot_message.tool_calls is None:
|
||||
self.msg_stream_id.pop(message_id)
|
||||
return
|
||||
if self.config['markdown_card'] is True:
|
||||
args['parse_mode'] = 'MarkdownV2'
|
||||
|
||||
content = components[0]['text']
|
||||
send_msg = await self.bot.send_message(**args)
|
||||
send_msg_id = send_msg.message_id
|
||||
self.msg_stream_id[message_id] = send_msg_id
|
||||
else: # 存在消息的时候直接编辑消息1
|
||||
if component['type'] == 'text':
|
||||
if self.config['markdown_card'] is True:
|
||||
content = telegramify_markdown.markdownify(
|
||||
content=component['text'],
|
||||
)
|
||||
else:
|
||||
content = component['text']
|
||||
args = {
|
||||
'message_id': self.msg_stream_id[message_id],
|
||||
'chat_id': message_source.source_platform_object.effective_chat.id,
|
||||
'text': content,
|
||||
}
|
||||
if self.config['markdown_card'] is True:
|
||||
args['parse_mode'] = 'MarkdownV2'
|
||||
|
||||
if chat_mode == 'private':
|
||||
args = self._build_message_args(chat_id, content, message_thread_id, draft_id=draft_id)
|
||||
await self.bot.send_message_draft(**args)
|
||||
if is_final and bot_message.tool_calls is None:
|
||||
del args['draft_id']
|
||||
await self.bot.send_message(**args)
|
||||
self.msg_stream_id.pop(message_id)
|
||||
else:
|
||||
stream_id = draft_id
|
||||
if (msg_seq - 1) % 8 == 0 or is_final:
|
||||
args = {
|
||||
'message_id': stream_id,
|
||||
'chat_id': chat_id,
|
||||
'text': self._process_markdown(content),
|
||||
}
|
||||
if self.config.get('markdown_card', False):
|
||||
args['parse_mode'] = 'MarkdownV2'
|
||||
await self.bot.edit_message_text(**args)
|
||||
|
||||
if is_final and bot_message.tool_calls is None:
|
||||
self.msg_stream_id.pop(message_id)
|
||||
# self.seq = 1 # 消息回复结束之后重置seq
|
||||
self.msg_stream_id.pop(message_id) # 消息回复结束之后删除流式消息id
|
||||
|
||||
def get_launcher_id(self, event: platform_events.MessageEvent) -> str | None:
|
||||
if not isinstance(event.source_platform_object, Update):
|
||||
|
||||
@@ -37,24 +37,16 @@ class WebSocketSession:
|
||||
id: str
|
||||
message_lists: dict[str, list[WebSocketMessage]] = {}
|
||||
"""消息列表 {pipeline_uuid: [messages]}"""
|
||||
stream_message_indexes: dict[str, dict[str, int]] = {}
|
||||
"""流式消息索引 {pipeline_uuid: {resp_message_id: message_index}}"""
|
||||
|
||||
def __init__(self, id: str):
|
||||
self.id = id
|
||||
self.message_lists = {}
|
||||
self.stream_message_indexes = {}
|
||||
|
||||
def get_message_list(self, pipeline_uuid: str) -> list[WebSocketMessage]:
|
||||
if pipeline_uuid not in self.message_lists:
|
||||
self.message_lists[pipeline_uuid] = []
|
||||
return self.message_lists[pipeline_uuid]
|
||||
|
||||
def get_stream_message_indexes(self, pipeline_uuid: str) -> dict[str, int]:
|
||||
if pipeline_uuid not in self.stream_message_indexes:
|
||||
self.stream_message_indexes[pipeline_uuid] = {}
|
||||
return self.stream_message_indexes[pipeline_uuid]
|
||||
|
||||
|
||||
class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
"""WebSocket适配器 - 支持双向实时通信"""
|
||||
@@ -97,46 +89,20 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
target_id: str,
|
||||
message: platform_message.MessageChain,
|
||||
) -> dict:
|
||||
"""发送消息 - 这里用于主动推送消息到前端
|
||||
"""发送消息 - 这里用于主动推送消息到前端"""
|
||||
message_data = {
|
||||
'type': 'bot_message',
|
||||
'target_type': target_type,
|
||||
'target_id': target_id,
|
||||
'content': str(message),
|
||||
'message_chain': [component.__dict__ for component in message],
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
对于 WebSocket 适配器,我们需要将消息广播到正确的 pipeline 连接。
|
||||
target_id 可能是 launcher_id(如 websocket_xxx)或 pipeline_uuid。
|
||||
我们需要尝试两种方式来确保消息能够送达。
|
||||
"""
|
||||
# 获取当前的 pipeline_uuid
|
||||
pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid
|
||||
session_type = 'group' if target_type == 'group' else 'person'
|
||||
# 推送到所有相关连接
|
||||
await self.outbound_message_queue.put(message_data)
|
||||
|
||||
# 选择会话
|
||||
session = self.websocket_group_session if session_type == 'group' else self.websocket_person_session
|
||||
|
||||
# 生成唯一消息ID
|
||||
msg_id = len(session.get_message_list(pipeline_uuid)) + 1
|
||||
|
||||
message_data = WebSocketMessage(
|
||||
id=msg_id,
|
||||
role='assistant',
|
||||
content=str(message),
|
||||
message_chain=[component.__dict__ for component in message],
|
||||
timestamp=datetime.now().isoformat(),
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
# 保存到历史记录
|
||||
session.get_message_list(pipeline_uuid).append(message_data)
|
||||
|
||||
# 直接广播到当前pipeline的连接
|
||||
await ws_connection_manager.broadcast_to_pipeline(
|
||||
pipeline_uuid,
|
||||
{
|
||||
'type': 'response',
|
||||
'session_type': session_type,
|
||||
'data': message_data.model_dump(),
|
||||
},
|
||||
session_type=session_type,
|
||||
)
|
||||
|
||||
return message_data.model_dump()
|
||||
return message_data
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
@@ -203,16 +169,10 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid
|
||||
session_type = 'group' if isinstance(message_source, platform_events.GroupMessage) else 'person'
|
||||
message_list = session.get_message_list(pipeline_uuid)
|
||||
stream_message_indexes = session.get_stream_message_indexes(pipeline_uuid)
|
||||
|
||||
# Streaming messages in LangBot have a stable resp_message_id during the same assistant reply.
|
||||
# Use it as the primary key to avoid overwriting an old card from a previous reply.
|
||||
resp_message_id = str(getattr(bot_message, 'resp_message_id', '') or '')
|
||||
existing_index = stream_message_indexes.get(resp_message_id) if resp_message_id else None
|
||||
|
||||
message_is_final = is_final and bot_message.tool_calls is None
|
||||
|
||||
if existing_index is None or existing_index >= len(message_list):
|
||||
# 检查是否是新的流式消息(通过bot_message对象判断)
|
||||
# 如果列表为空,或者最后一条消息已经is_final=True,则创建新消息
|
||||
if not message_list or message_list[-1].is_final:
|
||||
# 创建新消息
|
||||
msg_id = len(message_list) + 1
|
||||
message_data = WebSocketMessage(
|
||||
@@ -221,31 +181,27 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
content=str(message),
|
||||
message_chain=[component.__dict__ for component in message],
|
||||
timestamp=datetime.now().isoformat(),
|
||||
is_final=message_is_final,
|
||||
is_final=is_final and bot_message.tool_calls is None,
|
||||
)
|
||||
|
||||
# 立即添加到历史记录(即使is_final=False),以便后续块可以更新它
|
||||
message_list.append(message_data)
|
||||
if resp_message_id:
|
||||
stream_message_indexes[resp_message_id] = len(message_list) - 1
|
||||
# 只有在is_final时才保存到历史记录
|
||||
if is_final and bot_message.tool_calls is None:
|
||||
message_list.append(message_data)
|
||||
else:
|
||||
# 更新同一条流式消息
|
||||
old_message = message_list[existing_index]
|
||||
msg_id = old_message.id
|
||||
# 更新最后一条消息
|
||||
msg_id = message_list[-1].id
|
||||
message_data = WebSocketMessage(
|
||||
id=msg_id,
|
||||
role='assistant',
|
||||
content=str(message),
|
||||
message_chain=[component.__dict__ for component in message],
|
||||
timestamp=old_message.timestamp, # 保持原始时间戳
|
||||
is_final=message_is_final,
|
||||
timestamp=message_list[-1].timestamp, # 保持原始时间戳
|
||||
is_final=is_final and bot_message.tool_calls is None,
|
||||
)
|
||||
|
||||
# 更新历史记录中的对应消息
|
||||
message_list[existing_index] = message_data
|
||||
|
||||
if message_is_final and resp_message_id:
|
||||
stream_message_indexes.pop(resp_message_id, None)
|
||||
# 如果是final,更新历史记录中的最后一条
|
||||
if is_final and bot_message.tool_calls is None:
|
||||
message_list[-1] = message_data
|
||||
|
||||
# 直接广播到所有该pipeline的连接,包含session_type信息
|
||||
await ws_connection_manager.broadcast_to_pipeline(
|
||||
@@ -454,10 +410,6 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
if session_type == 'person':
|
||||
if pipeline_uuid in self.websocket_person_session.message_lists:
|
||||
self.websocket_person_session.message_lists[pipeline_uuid] = []
|
||||
if pipeline_uuid in self.websocket_person_session.stream_message_indexes:
|
||||
self.websocket_person_session.stream_message_indexes[pipeline_uuid] = {}
|
||||
else:
|
||||
if pipeline_uuid in self.websocket_group_session.message_lists:
|
||||
self.websocket_group_session.message_lists[pipeline_uuid] = []
|
||||
if pipeline_uuid in self.websocket_group_session.stream_message_indexes:
|
||||
self.websocket_group_session.stream_message_indexes[pipeline_uuid] = {}
|
||||
|
||||
@@ -11,7 +11,6 @@ import langbot_plugin.api.entities.builtin.platform.entities as platform_entitie
|
||||
from ..logger import EventLogger
|
||||
from langbot.libs.wecom_ai_bot_api.wecombotevent import WecomBotEvent
|
||||
from langbot.libs.wecom_ai_bot_api.api import WecomBotClient
|
||||
from langbot.libs.wecom_ai_bot_api.ws_client import WecomBotWsClient
|
||||
|
||||
|
||||
class WecomBotMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
@@ -177,42 +176,27 @@ class WecomBotEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
|
||||
|
||||
class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
bot: typing.Union[WecomBotClient, WecomBotWsClient]
|
||||
bot: WecomBotClient
|
||||
bot_account_id: str
|
||||
message_converter: WecomBotMessageConverter = WecomBotMessageConverter()
|
||||
event_converter: WecomBotEventConverter = WecomBotEventConverter()
|
||||
config: dict
|
||||
bot_uuid: str = None
|
||||
_ws_mode: bool = False
|
||||
|
||||
def __init__(self, config: dict, logger: EventLogger):
|
||||
enable_webhook = config.get('enable-webhook', False)
|
||||
required_keys = ['Token', 'EncodingAESKey', 'Corpid', 'BotId']
|
||||
missing_keys = [key for key in required_keys if key not in config]
|
||||
if missing_keys:
|
||||
raise Exception(f'WecomBot 缺少配置项: {missing_keys}')
|
||||
|
||||
if not enable_webhook:
|
||||
bot = WecomBotWsClient(
|
||||
bot_id=config['BotId'],
|
||||
secret=config['Secret'],
|
||||
logger=logger,
|
||||
encoding_aes_key=config.get('EncodingAESKey', ''),
|
||||
)
|
||||
ws_mode = True
|
||||
else:
|
||||
# Webhook callback mode
|
||||
required_keys = ['Token', 'EncodingAESKey', 'Corpid']
|
||||
missing_keys = [key for key in required_keys if key not in config or not config[key]]
|
||||
if missing_keys:
|
||||
raise Exception(f'WecomBot webhook mode missing config: {missing_keys}')
|
||||
|
||||
bot = WecomBotClient(
|
||||
Token=config['Token'],
|
||||
EnCodingAESKey=config['EncodingAESKey'],
|
||||
Corpid=config['Corpid'],
|
||||
logger=logger,
|
||||
unified_mode=True,
|
||||
)
|
||||
ws_mode = False
|
||||
|
||||
bot_account_id = config.get('BotId', '')
|
||||
bot = WecomBotClient(
|
||||
Token=config['Token'],
|
||||
EnCodingAESKey=config['EncodingAESKey'],
|
||||
Corpid=config['Corpid'],
|
||||
logger=logger,
|
||||
unified_mode=True,
|
||||
)
|
||||
bot_account_id = config['BotId']
|
||||
|
||||
super().__init__(
|
||||
config=config,
|
||||
@@ -220,7 +204,6 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
bot=bot,
|
||||
bot_account_id=bot_account_id,
|
||||
)
|
||||
self._ws_mode = ws_mode
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
@@ -229,15 +212,7 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
quote_origin: bool = False,
|
||||
):
|
||||
content = await self.message_converter.yiri2target(message)
|
||||
if self._ws_mode:
|
||||
event = message_source.source_platform_object
|
||||
req_id = event.get('req_id', '')
|
||||
if req_id:
|
||||
await self.bot.reply_text(req_id, content)
|
||||
else:
|
||||
await self.bot.set_message(event.message_id, content)
|
||||
else:
|
||||
await self.bot.set_message(message_source.source_platform_object.message_id, content)
|
||||
await self.bot.set_message(message_source.source_platform_object.message_id, content)
|
||||
|
||||
async def reply_message_chunk(
|
||||
self,
|
||||
@@ -247,22 +222,31 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
quote_origin: bool = False,
|
||||
is_final: bool = False,
|
||||
):
|
||||
"""将流水线增量输出写入企业微信 stream 会话。
|
||||
|
||||
Args:
|
||||
message_source: 流水线提供的原始消息事件。
|
||||
bot_message: 当前片段对应的模型元信息(未使用)。
|
||||
message: 需要回复的消息链。
|
||||
quote_origin: 是否引用原消息(企业微信暂不支持)。
|
||||
is_final: 标记当前片段是否为最终回复。
|
||||
|
||||
Returns:
|
||||
dict: 包含 `stream` 键,标识写入是否成功。
|
||||
|
||||
Example:
|
||||
在流水线 `reply_message_chunk` 调用中自动触发,无需手动调用。
|
||||
"""
|
||||
# 转换为纯文本(智能机器人当前协议仅支持文本流)
|
||||
content = await self.message_converter.yiri2target(message)
|
||||
msg_id = message_source.source_platform_object.message_id
|
||||
|
||||
if self._ws_mode:
|
||||
success = await self.bot.push_stream_chunk(msg_id, content, is_final=is_final)
|
||||
if not success and is_final:
|
||||
event = message_source.source_platform_object
|
||||
req_id = event.get('req_id', '')
|
||||
if req_id:
|
||||
await self.bot.reply_text(req_id, content)
|
||||
return {'stream': success}
|
||||
else:
|
||||
success = await self.bot.push_stream_chunk(msg_id, content, is_final=is_final)
|
||||
if not success and is_final:
|
||||
await self.bot.set_message(msg_id, content)
|
||||
return {'stream': success}
|
||||
# 将片段推送到 WecomBotClient 中的队列,返回值用于判断是否走降级逻辑
|
||||
success = await self.bot.push_stream_chunk(msg_id, content, is_final=is_final)
|
||||
if not success and is_final:
|
||||
# 未命中流式队列时使用旧有 set_message 兜底
|
||||
await self.bot.set_message(msg_id, content)
|
||||
return {'stream': success}
|
||||
|
||||
async def is_stream_output_supported(self) -> bool:
|
||||
"""智能机器人侧默认开启流式能力。
|
||||
@@ -275,11 +259,7 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
return True
|
||||
|
||||
async def send_message(self, target_type, target_id, message):
|
||||
if self._ws_mode:
|
||||
content = await self.message_converter.yiri2target(message)
|
||||
await self.bot.send_message(target_id, content)
|
||||
else:
|
||||
pass
|
||||
pass
|
||||
|
||||
def register_listener(
|
||||
self,
|
||||
@@ -308,25 +288,29 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
self.bot_uuid = bot_uuid
|
||||
|
||||
async def handle_unified_webhook(self, bot_uuid: str, path: str, request):
|
||||
if self._ws_mode:
|
||||
return None
|
||||
"""处理统一 webhook 请求。
|
||||
|
||||
Args:
|
||||
bot_uuid: Bot 的 UUID
|
||||
path: 子路径(如果有的话)
|
||||
request: Quart Request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
return await self.bot.handle_unified_webhook(request)
|
||||
|
||||
async def run_async(self):
|
||||
if self._ws_mode:
|
||||
await self.bot.connect()
|
||||
else:
|
||||
# 统一 webhook 模式下,不启动独立的 Quart 应用
|
||||
# 保持运行但不启动独立端口
|
||||
|
||||
async def keep_alive():
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
async def keep_alive():
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
await keep_alive()
|
||||
await keep_alive()
|
||||
|
||||
async def kill(self) -> bool:
|
||||
if self._ws_mode:
|
||||
await self.bot.disconnect()
|
||||
return True
|
||||
return False
|
||||
|
||||
async def unregister_listener(
|
||||
|
||||
@@ -11,64 +11,35 @@ metadata:
|
||||
icon: wecombot.png
|
||||
spec:
|
||||
config:
|
||||
- name: BotId
|
||||
label:
|
||||
en_US: BotId
|
||||
zh_Hans: 机器人ID (BotId)
|
||||
type: string
|
||||
required: true
|
||||
default: ""
|
||||
- name: enable-webhook
|
||||
label:
|
||||
en_US: Enable Webhook Mode
|
||||
zh_Hans: 启用Webhook模式
|
||||
description:
|
||||
en_US: If enabled, the bot will use webhook mode to receive messages. Otherwise, it will use WS long connection mode
|
||||
zh_Hans: 如果启用,机器人将使用 Webhook 模式接收消息。否则,将使用 WS 长连接模式
|
||||
type: boolean
|
||||
required: true
|
||||
default: false
|
||||
- name: Secret
|
||||
label:
|
||||
en_US: Secret
|
||||
zh_Hans: 机器人密钥 (Secret)
|
||||
description:
|
||||
en_US: Required for WebSocket long connection mode
|
||||
zh_Hans: 使用 WS 长连接模式时必填
|
||||
type: string
|
||||
required: false
|
||||
default: ""
|
||||
- name: Corpid
|
||||
label:
|
||||
en_US: Corpid
|
||||
zh_Hans: 企业ID
|
||||
description:
|
||||
en_US: Required for Webhook mode
|
||||
zh_Hans: 使用 Webhook 模式时必填
|
||||
type: string
|
||||
required: false
|
||||
required: true
|
||||
default: ""
|
||||
- name: Token
|
||||
label:
|
||||
en_US: Token
|
||||
zh_Hans: 令牌 (Token)
|
||||
description:
|
||||
en_US: Required for Webhook mode
|
||||
zh_Hans: 使用 Webhook 模式时必填
|
||||
type: string
|
||||
required: false
|
||||
required: true
|
||||
default: ""
|
||||
- name: EncodingAESKey
|
||||
label:
|
||||
en_US: EncodingAESKey
|
||||
zh_Hans: 消息加解密密钥 (EncodingAESKey)
|
||||
description:
|
||||
en_US: Required for Webhook mode. Optional for WebSocket mode (used for file decryption)
|
||||
zh_Hans: 使用 Webhook 模式时必填。WebSocket 模式下可选(用于文件解密)
|
||||
type: string
|
||||
required: true
|
||||
default: ""
|
||||
- name: BotId
|
||||
label:
|
||||
en_US: BotId
|
||||
zh_Hans: 机器人ID
|
||||
type: string
|
||||
required: false
|
||||
default: ""
|
||||
execution:
|
||||
python:
|
||||
path: ./wecombot.py
|
||||
attr: WecomBotAdapter
|
||||
attr: WecomBotAdapter
|
||||
@@ -81,33 +81,22 @@ class WecomEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
return event.source_platform_object
|
||||
|
||||
@staticmethod
|
||||
async def target2yiri(event: WecomCSEvent, bot: WecomCSClient = None):
|
||||
async def target2yiri(event: WecomCSEvent):
|
||||
"""
|
||||
将 WecomEvent 转换为平台的 FriendMessage 对象。
|
||||
|
||||
Args:
|
||||
event (WecomEvent): 企业微信客服事件。
|
||||
bot (WecomCSClient): 企业微信客服客户端,用于获取用户信息。
|
||||
|
||||
Returns:
|
||||
platform_events.FriendMessage: 转换后的 FriendMessage 对象。
|
||||
"""
|
||||
# Try to get customer nickname from WeChat API
|
||||
nickname = str(event.user_id)
|
||||
if bot and event.user_id:
|
||||
try:
|
||||
customer_info = await bot.get_customer_info(event.user_id)
|
||||
if customer_info and customer_info.get('nickname'):
|
||||
nickname = customer_info.get('nickname')
|
||||
except Exception:
|
||||
pass # Fall back to user_id as nickname
|
||||
|
||||
# 转换消息链
|
||||
if event.type == 'text':
|
||||
yiri_chain = await WecomMessageConverter.target2yiri(event.message, event.message_id)
|
||||
friend = platform_entities.Friend(
|
||||
id=f'u{event.user_id}',
|
||||
nickname=nickname,
|
||||
nickname=str(event.user_id),
|
||||
remark='',
|
||||
)
|
||||
|
||||
@@ -117,7 +106,7 @@ class WecomEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
elif event.type == 'image':
|
||||
friend = platform_entities.Friend(
|
||||
id=f'u{event.user_id}',
|
||||
nickname=nickname,
|
||||
nickname=str(event.user_id),
|
||||
remark='',
|
||||
)
|
||||
|
||||
@@ -198,7 +187,7 @@ class WecomCSAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
async def on_message(event: WecomCSEvent):
|
||||
self.bot_account_id = event.receiver_id
|
||||
try:
|
||||
return await callback(await self.event_converter.target2yiri(event, self.bot), self)
|
||||
return await callback(await self.event_converter.target2yiri(event), self)
|
||||
except Exception:
|
||||
await self.logger.error(f'Error in wecomcs callback: {traceback.format_exc()}')
|
||||
|
||||
|
||||
@@ -3,8 +3,6 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import aiohttp
|
||||
|
||||
from langbot.pkg.utils import httpclient
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -121,23 +119,23 @@ class WebhookPusher:
|
||||
dict | None: The response JSON if successful, None otherwise
|
||||
"""
|
||||
try:
|
||||
session = httpclient.get_session()
|
||||
async with session.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers={'Content-Type': 'application/json'},
|
||||
timeout=aiohttp.ClientTimeout(total=15),
|
||||
) as response:
|
||||
if response.status >= 400:
|
||||
self.logger.warning(f'Webhook {url} returned status {response.status}')
|
||||
return None
|
||||
else:
|
||||
self.logger.debug(f'Successfully pushed to webhook {url}')
|
||||
try:
|
||||
return await response.json()
|
||||
except Exception as json_error:
|
||||
self.logger.debug(f'Failed to parse JSON response from webhook {url}: {json_error}')
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers={'Content-Type': 'application/json'},
|
||||
timeout=aiohttp.ClientTimeout(total=15),
|
||||
) as response:
|
||||
if response.status >= 400:
|
||||
self.logger.warning(f'Webhook {url} returned status {response.status}')
|
||||
return None
|
||||
else:
|
||||
self.logger.debug(f'Successfully pushed to webhook {url}')
|
||||
try:
|
||||
return await response.json()
|
||||
except Exception as json_error:
|
||||
self.logger.debug(f'Failed to parse JSON response from webhook {url}: {json_error}')
|
||||
return None
|
||||
except asyncio.TimeoutError:
|
||||
self.logger.warning(f'Timeout pushing to webhook {url}')
|
||||
return None
|
||||
|
||||
@@ -7,6 +7,7 @@ import typing
|
||||
import os
|
||||
import sys
|
||||
import httpx
|
||||
import traceback
|
||||
import sqlalchemy
|
||||
from async_lru import alru_cache
|
||||
from langbot_plugin.api.entities.builtin.pipeline.query import provider_session
|
||||
@@ -101,6 +102,12 @@ class PluginRuntimeConnector:
|
||||
self.handler_task = asyncio.create_task(self.handler.run())
|
||||
_ = await self.handler.ping()
|
||||
self.ap.logger.info('Connected to plugin runtime.')
|
||||
# Sync polymorphic component instances after connection
|
||||
try:
|
||||
await self.sync_polymorphic_component_instances()
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
self.ap.logger.error(f'Failed to sync polymorphic component instances: {e}')
|
||||
await self.handler_task
|
||||
|
||||
task: asyncio.Task | None = None
|
||||
@@ -456,18 +463,30 @@ class PluginRuntimeConnector:
|
||||
|
||||
yield cmd_ret
|
||||
|
||||
# KnowledgeRetriever methods
|
||||
async def list_knowledge_retrievers(self, bound_plugins: list[str] | None = None) -> list[dict[str, Any]]:
|
||||
"""List all available KnowledgeRetriever components."""
|
||||
if not self.is_enable_plugin:
|
||||
return []
|
||||
|
||||
retrievers_data = await self.handler.list_knowledge_retrievers(include_plugins=bound_plugins)
|
||||
return retrievers_data
|
||||
|
||||
async def retrieve_knowledge(
|
||||
self,
|
||||
plugin_author: str,
|
||||
plugin_name: str,
|
||||
retriever_name: str,
|
||||
instance_id: str,
|
||||
retrieval_context: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Retrieve knowledge using a KnowledgeEngine instance."""
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Retrieve knowledge using a KnowledgeRetriever instance."""
|
||||
if not self.is_enable_plugin:
|
||||
return {'results': []}
|
||||
return []
|
||||
|
||||
return await self.handler.retrieve_knowledge(plugin_author, plugin_name, retriever_name, retrieval_context)
|
||||
return await self.handler.retrieve_knowledge(
|
||||
plugin_author, plugin_name, retriever_name, instance_id, retrieval_context
|
||||
)
|
||||
|
||||
def dispose(self):
|
||||
# No need to consider the shutdown on Windows
|
||||
@@ -481,84 +500,41 @@ class PluginRuntimeConnector:
|
||||
self.heartbeat_task.cancel()
|
||||
self.heartbeat_task = None
|
||||
|
||||
@staticmethod
|
||||
def _parse_plugin_id(plugin_id: str) -> tuple[str, str]:
|
||||
"""Parse a plugin ID string into (author, name).
|
||||
async def sync_polymorphic_component_instances(self) -> dict[str, Any]:
|
||||
"""Sync polymorphic component instances with runtime.
|
||||
|
||||
Args:
|
||||
plugin_id: Plugin ID in 'author/name' format.
|
||||
|
||||
Returns:
|
||||
Tuple of (plugin_author, plugin_name).
|
||||
|
||||
Raises:
|
||||
ValueError: If plugin_id is not in the expected 'author/name' format.
|
||||
This collects all external knowledge bases from database and sends to runtime
|
||||
to ensure instance integrity across restarts.
|
||||
"""
|
||||
if '/' not in plugin_id:
|
||||
raise ValueError(
|
||||
f"Invalid plugin_id format: '{plugin_id}'. Expected 'author/name' format (e.g. 'langbot/rag-engine')."
|
||||
if not self.is_enable_plugin:
|
||||
return {}
|
||||
|
||||
# ===== external knowledge bases =====
|
||||
|
||||
external_kbs = await self.ap.external_kb_service.get_external_knowledge_bases()
|
||||
|
||||
# Build required_instances list
|
||||
required_instances = []
|
||||
for kb in external_kbs:
|
||||
required_instances.append(
|
||||
{
|
||||
'instance_id': kb['uuid'],
|
||||
'plugin_author': kb['plugin_author'],
|
||||
'plugin_name': kb['plugin_name'],
|
||||
'component_kind': 'KnowledgeRetriever',
|
||||
'component_name': kb['retriever_name'],
|
||||
'config': kb['retriever_config'],
|
||||
}
|
||||
)
|
||||
return plugin_id.split('/', 1)
|
||||
|
||||
async def call_rag_ingest(self, plugin_id: str, context_data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Call plugin to ingest document.
|
||||
self.ap.logger.info(f'Syncing {len(required_instances)} polymorphic component instances to runtime')
|
||||
|
||||
Args:
|
||||
plugin_id: Target plugin ID (author/name).
|
||||
context_data: IngestionContext data.
|
||||
"""
|
||||
plugin_author, plugin_name = self._parse_plugin_id(plugin_id)
|
||||
return await self.handler.rag_ingest_document(plugin_author, plugin_name, context_data)
|
||||
# Send to runtime
|
||||
sync_result = await self.handler.sync_polymorphic_component_instances(required_instances)
|
||||
|
||||
async def call_rag_delete_document(self, plugin_id: str, document_id: str, kb_id: str) -> bool:
|
||||
plugin_author, plugin_name = self._parse_plugin_id(plugin_id)
|
||||
return await self.handler.rag_delete_document(plugin_author, plugin_name, document_id, kb_id)
|
||||
self.ap.logger.info(
|
||||
f'Sync complete: {len(sync_result.get("success_instances", []))} succeeded, '
|
||||
f'{len(sync_result.get("failed_instances", []))} failed'
|
||||
)
|
||||
|
||||
async def get_rag_creation_schema(self, plugin_id: str) -> dict[str, Any]:
|
||||
plugin_author, plugin_name = self._parse_plugin_id(plugin_id)
|
||||
return await self.handler.get_rag_creation_schema(plugin_author, plugin_name)
|
||||
|
||||
async def get_rag_retrieval_schema(self, plugin_id: str) -> dict[str, Any]:
|
||||
plugin_author, plugin_name = self._parse_plugin_id(plugin_id)
|
||||
return await self.handler.get_rag_retrieval_schema(plugin_author, plugin_name)
|
||||
|
||||
async def rag_on_kb_create(self, plugin_id: str, kb_id: str, config: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Notify plugin about KB creation."""
|
||||
plugin_author, plugin_name = self._parse_plugin_id(plugin_id)
|
||||
return await self.handler.rag_on_kb_create(plugin_author, plugin_name, kb_id, config)
|
||||
|
||||
async def rag_on_kb_delete(self, plugin_id: str, kb_id: str) -> dict[str, Any]:
|
||||
"""Notify plugin about KB deletion."""
|
||||
plugin_author, plugin_name = self._parse_plugin_id(plugin_id)
|
||||
return await self.handler.rag_on_kb_delete(plugin_author, plugin_name, kb_id)
|
||||
|
||||
async def call_rag_retrieve(self, plugin_id: str, retrieval_context: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Call plugin to retrieve knowledge.
|
||||
|
||||
Args:
|
||||
plugin_id: Target plugin ID (author/name).
|
||||
retrieval_context: RetrievalContext data.
|
||||
"""
|
||||
plugin_author, plugin_name = self._parse_plugin_id(plugin_id)
|
||||
return await self.handler.retrieve_knowledge(plugin_author, plugin_name, '', retrieval_context)
|
||||
|
||||
async def list_knowledge_engines(self) -> list[dict[str, Any]]:
|
||||
"""List all available Knowledge Engines from plugins.
|
||||
|
||||
Returns a list of Knowledge Engines with their capabilities and configuration schemas.
|
||||
"""
|
||||
if not self.is_enable_plugin:
|
||||
return []
|
||||
|
||||
return await self.handler.list_knowledge_engines()
|
||||
|
||||
async def list_parsers(self) -> list[dict[str, Any]]:
|
||||
"""List all available parsers from plugins."""
|
||||
if not self.is_enable_plugin:
|
||||
return []
|
||||
return await self.handler.list_parsers()
|
||||
|
||||
async def call_parser(self, plugin_id: str, context_data: dict[str, Any], file_bytes: bytes) -> dict[str, Any]:
|
||||
"""Call plugin to parse a document."""
|
||||
plugin_author, plugin_name = self._parse_plugin_id(plugin_id)
|
||||
return await self.handler.parse_document(plugin_author, plugin_name, context_data, file_bytes)
|
||||
return sync_result
|
||||
|
||||
@@ -26,20 +26,6 @@ from ..core import app
|
||||
from ..utils import constants
|
||||
|
||||
|
||||
def _make_rag_error_response(error: Exception, error_type: str, **extra_context) -> handler.ActionResponse:
|
||||
"""Create a clean error response for RAG operations.
|
||||
|
||||
Args:
|
||||
error: The caught exception.
|
||||
error_type: A category string like 'EmbeddingError', 'VectorStoreError'.
|
||||
**extra_context: Additional context fields for the error message.
|
||||
"""
|
||||
context_parts = [f'{k}={v}' for k, v in extra_context.items()]
|
||||
context_str = f' [{", ".join(context_parts)}]' if context_parts else ''
|
||||
message = f'[{error_type}/{type(error).__name__}]{context_str} {str(error)}'
|
||||
return handler.ActionResponse.error(message=message)
|
||||
|
||||
|
||||
class RuntimeConnectionHandler(handler.Handler):
|
||||
"""Runtime connection handler"""
|
||||
|
||||
@@ -293,7 +279,6 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
target_id = data['target_id']
|
||||
message_chain = data['message_chain']
|
||||
|
||||
# Use custom deserializer that properly handles Forward messages
|
||||
message_chain_obj = platform_message.MessageChain.model_validate(message_chain)
|
||||
|
||||
bot = await self.ap.platform_mgr.get_bot_by_uuid(bot_uuid)
|
||||
@@ -337,14 +322,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
)
|
||||
|
||||
messages_obj = [provider_message.Message.model_validate(message) for message in messages]
|
||||
|
||||
# The func field is excluded during model_dump() in plugin side (marked as exclude=True),
|
||||
# but it's a required field for LLMTool validation. We need to provide a placeholder
|
||||
# function when reconstructing the LLMTool objects from serialized data.
|
||||
async def _placeholder_func(**kwargs):
|
||||
pass
|
||||
|
||||
funcs_obj = [resource_tool.LLMTool.model_validate({**func, 'func': _placeholder_func}) for func in funcs]
|
||||
funcs_obj = [resource_tool.LLMTool.model_validate(func) for func in funcs]
|
||||
|
||||
result = await llm_model.provider.invoke_llm(
|
||||
query=None,
|
||||
@@ -460,7 +438,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
},
|
||||
)
|
||||
|
||||
@self.action(PluginToRuntimeAction.GET_CONFIG_FILE)
|
||||
@self.action(RuntimeToLangBotAction.GET_CONFIG_FILE)
|
||||
async def get_config_file(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""Get a config file by file key"""
|
||||
file_key = data['file_key']
|
||||
@@ -479,223 +457,6 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
message=f'Failed to load config file {file_key}: {e}',
|
||||
)
|
||||
|
||||
# ================= RAG Capability Handlers =================
|
||||
|
||||
@self.action(PluginToRuntimeAction.INVOKE_EMBEDDING)
|
||||
async def invoke_embedding(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
embedding_model_uuid = data['embedding_model_uuid']
|
||||
texts = data['texts']
|
||||
|
||||
embedding_model = await self.ap.model_mgr.get_embedding_model_by_uuid(embedding_model_uuid)
|
||||
if embedding_model is None:
|
||||
return handler.ActionResponse.error(
|
||||
message=f'Embedding model with embedding_model_uuid {embedding_model_uuid} not found',
|
||||
)
|
||||
|
||||
try:
|
||||
vectors = await embedding_model.provider.invoke_embedding(embedding_model, texts)
|
||||
return handler.ActionResponse.success(data={'vectors': vectors})
|
||||
except Exception as e:
|
||||
return _make_rag_error_response(e, 'EmbeddingError', embedding_model_uuid=embedding_model_uuid)
|
||||
|
||||
@self.action(PluginToRuntimeAction.VECTOR_UPSERT)
|
||||
async def vector_upsert(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
collection_id = data['collection_id']
|
||||
vectors = data['vectors']
|
||||
ids = data['ids']
|
||||
metadata = data.get('metadata')
|
||||
documents = data.get('documents')
|
||||
if len(vectors) != len(ids):
|
||||
return handler.ActionResponse.error(message='vectors and ids must have same length')
|
||||
if metadata and len(metadata) != len(vectors):
|
||||
return handler.ActionResponse.error(message='metadata must match vectors length')
|
||||
if documents and len(documents) != len(vectors):
|
||||
return handler.ActionResponse.error(message='documents must match vectors length')
|
||||
try:
|
||||
await self.ap.rag_runtime_service.vector_upsert(
|
||||
collection_id,
|
||||
vectors,
|
||||
ids,
|
||||
metadata,
|
||||
documents,
|
||||
)
|
||||
return handler.ActionResponse.success(data={})
|
||||
except Exception as e:
|
||||
return _make_rag_error_response(e, 'VectorStoreError', collection_id=collection_id)
|
||||
|
||||
@self.action(PluginToRuntimeAction.VECTOR_SEARCH)
|
||||
async def vector_search(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
collection_id = data['collection_id']
|
||||
query_vector = data['query_vector']
|
||||
top_k = data['top_k']
|
||||
filters = data.get('filters')
|
||||
search_type = data.get('search_type', 'vector')
|
||||
query_text = data.get('query_text', '')
|
||||
try:
|
||||
results = await self.ap.rag_runtime_service.vector_search(
|
||||
collection_id,
|
||||
query_vector,
|
||||
top_k,
|
||||
filters,
|
||||
search_type,
|
||||
query_text,
|
||||
)
|
||||
return handler.ActionResponse.success(data={'results': results})
|
||||
except Exception as e:
|
||||
return _make_rag_error_response(e, 'VectorStoreError', collection_id=collection_id)
|
||||
|
||||
@self.action(PluginToRuntimeAction.VECTOR_DELETE)
|
||||
async def vector_delete(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
collection_id = data['collection_id']
|
||||
file_ids = data.get('file_ids')
|
||||
filters = data.get('filters')
|
||||
try:
|
||||
count = await self.ap.rag_runtime_service.vector_delete(collection_id, file_ids, filters)
|
||||
return handler.ActionResponse.success(data={'count': count})
|
||||
except Exception as e:
|
||||
return _make_rag_error_response(e, 'VectorStoreError', collection_id=collection_id)
|
||||
|
||||
@self.action(PluginToRuntimeAction.GET_KNOWLEDEGE_FILE_STREAM)
|
||||
async def get_knowledge_file_stream(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
storage_path = data['storage_path']
|
||||
try:
|
||||
content_bytes = await self.ap.rag_runtime_service.get_file_stream(storage_path)
|
||||
file_key = await self.send_file(content_bytes, '')
|
||||
return handler.ActionResponse.success(data={'file_key': file_key})
|
||||
except Exception as e:
|
||||
return _make_rag_error_response(e, 'FileServiceError', storage_path=storage_path)
|
||||
|
||||
@self.action(PluginToRuntimeAction.LIST_PARSERS)
|
||||
async def list_parsers(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""Plugin requests host to list available parser plugins."""
|
||||
mime_type = data.get('mime_type')
|
||||
try:
|
||||
parsers = await self.ap.knowledge_service.list_parsers(mime_type)
|
||||
return handler.ActionResponse.success(data={'parsers': parsers})
|
||||
except Exception as e:
|
||||
return _make_rag_error_response(e, 'ParserDiscoveryError', mime_type=mime_type)
|
||||
|
||||
@self.action(PluginToRuntimeAction.INVOKE_PARSER)
|
||||
async def invoke_parser(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""Plugin requests host to invoke a parser plugin."""
|
||||
plugin_author = data['plugin_author']
|
||||
plugin_name = data['plugin_name']
|
||||
storage_path = data['storage_path']
|
||||
mime_type = data.get('mime_type', 'application/octet-stream')
|
||||
filename = data.get('filename', '')
|
||||
metadata = data.get('metadata', {})
|
||||
try:
|
||||
# Read file from storage
|
||||
file_bytes = await self.ap.rag_runtime_service.get_file_stream(storage_path)
|
||||
context_data = {
|
||||
'mime_type': mime_type,
|
||||
'filename': filename,
|
||||
'metadata': metadata,
|
||||
}
|
||||
result = await self.ap.plugin_connector.call_parser(
|
||||
f'{plugin_author}/{plugin_name}', context_data, file_bytes
|
||||
)
|
||||
return handler.ActionResponse.success(data=result)
|
||||
except Exception as e:
|
||||
return _make_rag_error_response(e, 'ParserError')
|
||||
|
||||
# ================= Knowledge Base Query APIs =================
|
||||
|
||||
@self.action(PluginToRuntimeAction.LIST_PIPELINE_KNOWLEDGE_BASES)
|
||||
async def list_pipeline_knowledge_bases(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""List knowledge bases configured for the current query's pipeline."""
|
||||
query_id = data['query_id']
|
||||
|
||||
if query_id not in self.ap.query_pool.cached_queries:
|
||||
return handler.ActionResponse.error(
|
||||
message=f'Query with query_id {query_id} not found',
|
||||
)
|
||||
|
||||
query = self.ap.query_pool.cached_queries[query_id]
|
||||
|
||||
kb_uuids = []
|
||||
if query.pipeline_config:
|
||||
local_agent_config = query.pipeline_config.get('ai', {}).get('local-agent', {})
|
||||
kb_uuids = local_agent_config.get('knowledge-bases', [])
|
||||
# Backward compatibility
|
||||
if not kb_uuids:
|
||||
old_kb_uuid = local_agent_config.get('knowledge-base', '')
|
||||
if old_kb_uuid and old_kb_uuid != '__none__':
|
||||
kb_uuids = [old_kb_uuid]
|
||||
|
||||
knowledge_bases = []
|
||||
for kb_uuid in kb_uuids:
|
||||
kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
|
||||
if kb:
|
||||
knowledge_bases.append(
|
||||
{
|
||||
'uuid': kb.get_uuid(),
|
||||
'name': kb.get_name(),
|
||||
'description': kb.knowledge_base_entity.description or '',
|
||||
}
|
||||
)
|
||||
|
||||
return handler.ActionResponse.success(data={'knowledge_bases': knowledge_bases})
|
||||
|
||||
@self.action(PluginToRuntimeAction.RETRIEVE_KNOWLEDGE_BASE)
|
||||
async def retrieve_knowledge_base(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""Retrieve documents from a knowledge base within the pipeline's scope."""
|
||||
query_id = data['query_id']
|
||||
kb_id = data['kb_id']
|
||||
query_text = data['query_text']
|
||||
top_k = data.get('top_k', 5)
|
||||
filters = data.get('filters', {})
|
||||
|
||||
if query_id not in self.ap.query_pool.cached_queries:
|
||||
return handler.ActionResponse.error(
|
||||
message=f'Query with query_id {query_id} not found',
|
||||
)
|
||||
|
||||
query = self.ap.query_pool.cached_queries[query_id]
|
||||
|
||||
# Validate kb_id is in pipeline's allowed list
|
||||
allowed_kb_uuids = []
|
||||
if query.pipeline_config:
|
||||
local_agent_config = query.pipeline_config.get('ai', {}).get('local-agent', {})
|
||||
allowed_kb_uuids = local_agent_config.get('knowledge-bases', [])
|
||||
if not allowed_kb_uuids:
|
||||
old_kb_uuid = local_agent_config.get('knowledge-base', '')
|
||||
if old_kb_uuid and old_kb_uuid != '__none__':
|
||||
allowed_kb_uuids = [old_kb_uuid]
|
||||
|
||||
if kb_id not in allowed_kb_uuids:
|
||||
return handler.ActionResponse.error(
|
||||
message=f'Knowledge base {kb_id} is not configured for this pipeline',
|
||||
)
|
||||
|
||||
kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_id)
|
||||
if not kb:
|
||||
return handler.ActionResponse.error(
|
||||
message=f'Knowledge base {kb_id} not found',
|
||||
)
|
||||
|
||||
try:
|
||||
entries = await kb.retrieve(
|
||||
query_text,
|
||||
settings={
|
||||
'top_k': top_k,
|
||||
'filters': filters,
|
||||
},
|
||||
)
|
||||
results = [entry.model_dump(mode='json') for entry in entries]
|
||||
return handler.ActionResponse.success(data={'results': results})
|
||||
except Exception as e:
|
||||
return _make_rag_error_response(e, 'RetrievalError', kb_id=kb_id)
|
||||
|
||||
@self.action(CommonAction.PING)
|
||||
async def ping(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""Ping"""
|
||||
return handler.ActionResponse.success(
|
||||
data={
|
||||
'pong': 'pong',
|
||||
},
|
||||
)
|
||||
|
||||
async def ping(self) -> dict[str, Any]:
|
||||
"""Ping the runtime"""
|
||||
return await self.call_action(
|
||||
@@ -955,13 +716,26 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
async for ret in gen:
|
||||
yield ret
|
||||
|
||||
# KnowledgeRetriever methods
|
||||
async def list_knowledge_retrievers(self, include_plugins: list[str] | None = None) -> list[dict[str, Any]]:
|
||||
"""List knowledge retrievers"""
|
||||
result = await self.call_action(
|
||||
LangBotToRuntimeAction.LIST_KNOWLEDGE_RETRIEVERS,
|
||||
{
|
||||
'include_plugins': include_plugins,
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
return result['retrievers']
|
||||
|
||||
async def retrieve_knowledge(
|
||||
self,
|
||||
plugin_author: str,
|
||||
plugin_name: str,
|
||||
retriever_name: str,
|
||||
instance_id: str,
|
||||
retrieval_context: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Retrieve knowledge"""
|
||||
result = await self.call_action(
|
||||
LangBotToRuntimeAction.RETRIEVE_KNOWLEDGE,
|
||||
@@ -969,10 +743,22 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
'plugin_author': plugin_author,
|
||||
'plugin_name': plugin_name,
|
||||
'retriever_name': retriever_name,
|
||||
'instance_id': instance_id,
|
||||
'retrieval_context': retrieval_context,
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
return result['retrieval_results']
|
||||
|
||||
async def sync_polymorphic_component_instances(self, required_instances: list[dict[str, Any]]) -> dict[str, Any]:
|
||||
"""Sync polymorphic component instances with runtime"""
|
||||
result = await self.call_action(
|
||||
LangBotToRuntimeAction.SYNC_POLYMORPHIC_COMPONENT_INSTANCES,
|
||||
{
|
||||
'required_instances': required_instances,
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
return result
|
||||
|
||||
async def get_debug_info(self) -> dict[str, Any]:
|
||||
@@ -983,91 +769,3 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
timeout=10,
|
||||
)
|
||||
return result
|
||||
|
||||
# ================= RAG Capability Callers (LangBot -> Runtime) =================
|
||||
|
||||
async def rag_ingest_document(
|
||||
self, plugin_author: str, plugin_name: str, context_data: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Send INGEST_DOCUMENT action to runtime."""
|
||||
result = await self.call_action(
|
||||
LangBotToRuntimeAction.RAG_INGEST_DOCUMENT,
|
||||
{'plugin_author': plugin_author, 'plugin_name': plugin_name, 'context': context_data},
|
||||
timeout=1200, # Ingestion can be slow for large documents
|
||||
)
|
||||
return result
|
||||
|
||||
async def rag_delete_document(self, plugin_author: str, plugin_name: str, document_id: str, kb_id: str) -> bool:
|
||||
result = await self.call_action(
|
||||
LangBotToRuntimeAction.RAG_DELETE_DOCUMENT,
|
||||
{'plugin_author': plugin_author, 'plugin_name': plugin_name, 'document_id': document_id, 'kb_id': kb_id},
|
||||
timeout=30,
|
||||
)
|
||||
return result.get('success', False)
|
||||
|
||||
async def rag_on_kb_create(
|
||||
self, plugin_author: str, plugin_name: str, kb_id: str, config: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Notify plugin about KB creation."""
|
||||
result = await self.call_action(
|
||||
LangBotToRuntimeAction.RAG_ON_KB_CREATE,
|
||||
{'plugin_author': plugin_author, 'plugin_name': plugin_name, 'kb_id': kb_id, 'config': config},
|
||||
timeout=30,
|
||||
)
|
||||
return result
|
||||
|
||||
async def rag_on_kb_delete(self, plugin_author: str, plugin_name: str, kb_id: str) -> dict[str, Any]:
|
||||
"""Notify plugin about KB deletion."""
|
||||
result = await self.call_action(
|
||||
LangBotToRuntimeAction.RAG_ON_KB_DELETE,
|
||||
{'plugin_author': plugin_author, 'plugin_name': plugin_name, 'kb_id': kb_id},
|
||||
timeout=30,
|
||||
)
|
||||
return result
|
||||
|
||||
async def get_rag_creation_schema(self, plugin_author: str, plugin_name: str) -> dict[str, Any]:
|
||||
return await self.call_action(
|
||||
LangBotToRuntimeAction.GET_RAG_CREATION_SETTINGS_SCHEMA,
|
||||
{'plugin_author': plugin_author, 'plugin_name': plugin_name},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
async def get_rag_retrieval_schema(self, plugin_author: str, plugin_name: str) -> dict[str, Any]:
|
||||
return await self.call_action(
|
||||
LangBotToRuntimeAction.GET_RAG_RETRIEVAL_SETTINGS_SCHEMA,
|
||||
{'plugin_author': plugin_author, 'plugin_name': plugin_name},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
async def list_knowledge_engines(self) -> list[dict[str, Any]]:
|
||||
"""List all available Knowledge Engines from plugins."""
|
||||
result = await self.call_action(LangBotToRuntimeAction.LIST_KNOWLEDGE_ENGINES, {}, timeout=60)
|
||||
return result.get('engines', [])
|
||||
|
||||
# ================= Parser Capability Callers (LangBot -> Runtime) =================
|
||||
|
||||
async def list_parsers(self) -> list[dict[str, Any]]:
|
||||
"""List all available parsers from plugins."""
|
||||
result = await self.call_action(LangBotToRuntimeAction.LIST_PARSERS, {}, timeout=60)
|
||||
return result.get('parsers', [])
|
||||
|
||||
async def parse_document(
|
||||
self, plugin_author: str, plugin_name: str, context_data: dict[str, Any], file_bytes: bytes
|
||||
) -> dict[str, Any]:
|
||||
"""Send PARSE_DOCUMENT action to runtime.
|
||||
|
||||
Sends file content via chunked FILE_CHUNK transfer, then invokes
|
||||
the PARSE_DOCUMENT action with a file_key reference.
|
||||
"""
|
||||
# Send file to runtime via chunked transfer
|
||||
file_key = await self.send_file(file_bytes, '')
|
||||
|
||||
# Include file_key in context_data for the runtime to read
|
||||
context_data['file_key'] = file_key
|
||||
|
||||
result = await self.call_action(
|
||||
LangBotToRuntimeAction.PARSE_DOCUMENT,
|
||||
{'plugin_author': plugin_author, 'plugin_name': plugin_name, 'context': context_data},
|
||||
timeout=300,
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -72,28 +72,6 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
content = f'<think>\n{thinking_content}\n</think>\n{content}'.strip()
|
||||
return content, thinking_content
|
||||
|
||||
def _extract_dify_text_output(self, value: typing.Any) -> str:
|
||||
"""Extract text content from Dify output payload."""
|
||||
if value is None:
|
||||
return ''
|
||||
if isinstance(value, dict):
|
||||
content = value.get('content')
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
if isinstance(value, str):
|
||||
text = value.strip()
|
||||
if not text:
|
||||
return ''
|
||||
try:
|
||||
parsed = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
return value
|
||||
if isinstance(parsed, dict) and isinstance(parsed.get('content'), str):
|
||||
return parsed['content']
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
async def _preprocess_user_message(self, query: pipeline_query.Query) -> tuple[str, list[dict]]:
|
||||
"""预处理用户消息,提取纯文本,并将图片/文件上传到 Dify 服务
|
||||
|
||||
@@ -214,8 +192,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
if mode == 'workflow':
|
||||
if chunk['event'] == 'node_finished':
|
||||
if chunk['data']['node_type'] == 'answer':
|
||||
answer = self._extract_dify_text_output(chunk['data']['outputs'].get('answer'))
|
||||
content, _ = self._process_thinking_content(answer)
|
||||
content, _ = self._process_thinking_content(chunk['data']['outputs']['answer'])
|
||||
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
@@ -428,7 +405,6 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
for f in upload_files
|
||||
]
|
||||
|
||||
mode = 'basic'
|
||||
basic_mode_pending_chunk = ''
|
||||
|
||||
inputs = {}
|
||||
@@ -441,7 +417,6 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
is_final = False
|
||||
think_start = False
|
||||
think_end = False
|
||||
yielded_final = False
|
||||
|
||||
remove_think = self.pipeline_config['output'].get('misc', '').get('remove-think')
|
||||
|
||||
@@ -455,12 +430,11 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
):
|
||||
self.ap.logger.debug('dify-chat-chunk: ' + str(chunk))
|
||||
|
||||
if chunk['event'] == 'workflow_started':
|
||||
mode = 'workflow'
|
||||
elif chunk['event'] in ('node_started', 'node_finished', 'workflow_finished'):
|
||||
# Some Dify deployments may omit workflow_started in streamed chunks.
|
||||
mode = 'workflow'
|
||||
|
||||
# if chunk['event'] == 'workflow_started':
|
||||
# mode = 'workflow'
|
||||
# if mode == 'workflow':
|
||||
# elif mode == 'basic':
|
||||
# 因为都只是返回的 message也没有工具调用什么的,暂时不分类
|
||||
if chunk['event'] == 'message':
|
||||
message_idx += 1
|
||||
if remove_think:
|
||||
@@ -483,30 +457,14 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
|
||||
if chunk['event'] == 'message_end':
|
||||
is_final = True
|
||||
elif chunk['event'] == 'workflow_finished':
|
||||
is_final = True
|
||||
if chunk['data'].get('error'):
|
||||
raise errors.DifyAPIError(chunk['data']['error'])
|
||||
|
||||
if mode == 'workflow' and chunk['event'] == 'node_finished':
|
||||
if chunk['data'].get('node_type') == 'answer':
|
||||
answer = self._extract_dify_text_output(chunk['data'].get('outputs', {}).get('answer'))
|
||||
if answer:
|
||||
basic_mode_pending_chunk = answer
|
||||
|
||||
if (
|
||||
not yielded_final
|
||||
and (is_final or message_idx % 8 == 0)
|
||||
and (basic_mode_pending_chunk != '' or is_final)
|
||||
):
|
||||
if is_final or message_idx % 8 == 0:
|
||||
# content, _ = self._process_thinking_content(basic_mode_pending_chunk)
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=basic_mode_pending_chunk,
|
||||
is_final=is_final,
|
||||
)
|
||||
if is_final:
|
||||
yielded_final = True
|
||||
|
||||
if chunk is None:
|
||||
raise errors.DifyAPIError('Dify API 没有返回任何响应,请检查网络连接和API配置')
|
||||
|
||||
@@ -4,7 +4,6 @@ import json
|
||||
import copy
|
||||
import typing
|
||||
from .. import runner
|
||||
from ..modelmgr import requester as modelmgr_requester
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
import langbot_plugin.api.entities.builtin.rag.context as rag_context
|
||||
@@ -27,109 +26,19 @@ Respond in the same language as the user's input.
|
||||
|
||||
@runner.runner_class('local-agent')
|
||||
class LocalAgentRunner(runner.RequestRunner):
|
||||
"""Local agent request runner"""
|
||||
"""本地Agent请求运行器"""
|
||||
|
||||
async def _get_model_candidates(
|
||||
self,
|
||||
query: pipeline_query.Query,
|
||||
) -> list[modelmgr_requester.RuntimeLLMModel]:
|
||||
"""Build ordered list of models to try: primary model + fallback models."""
|
||||
candidates = []
|
||||
class ToolCallTracker:
|
||||
"""工具调用追踪器"""
|
||||
|
||||
# Primary model
|
||||
if query.use_llm_model_uuid:
|
||||
try:
|
||||
primary = await self.ap.model_mgr.get_model_by_uuid(query.use_llm_model_uuid)
|
||||
candidates.append(primary)
|
||||
except ValueError:
|
||||
self.ap.logger.warning(f'Primary model {query.use_llm_model_uuid} not found')
|
||||
|
||||
# Fallback models
|
||||
fallback_uuids = (query.variables or {}).get('_fallback_model_uuids', [])
|
||||
for fb_uuid in fallback_uuids:
|
||||
try:
|
||||
fb_model = await self.ap.model_mgr.get_model_by_uuid(fb_uuid)
|
||||
candidates.append(fb_model)
|
||||
except ValueError:
|
||||
self.ap.logger.warning(f'Fallback model {fb_uuid} not found, skipping')
|
||||
|
||||
return candidates
|
||||
|
||||
async def _invoke_with_fallback(
|
||||
self,
|
||||
query: pipeline_query.Query,
|
||||
candidates: list[modelmgr_requester.RuntimeLLMModel],
|
||||
messages: list,
|
||||
funcs: list,
|
||||
remove_think: bool,
|
||||
) -> tuple[provider_message.Message, modelmgr_requester.RuntimeLLMModel]:
|
||||
"""Try non-streaming invocation with sequential fallback. Returns (message, model_used)."""
|
||||
last_error = None
|
||||
for model in candidates:
|
||||
try:
|
||||
msg = await model.provider.invoke_llm(
|
||||
query,
|
||||
model,
|
||||
messages,
|
||||
funcs if model.model_entity.abilities.__contains__('func_call') else [],
|
||||
extra_args=model.model_entity.extra_args,
|
||||
remove_think=remove_think,
|
||||
)
|
||||
return msg, model
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
self.ap.logger.warning(f'Model {model.model_entity.name} failed: {e}, trying next fallback...')
|
||||
raise last_error or RuntimeError('No model candidates available')
|
||||
|
||||
async def _invoke_stream_with_fallback(
|
||||
self,
|
||||
query: pipeline_query.Query,
|
||||
candidates: list[modelmgr_requester.RuntimeLLMModel],
|
||||
messages: list,
|
||||
funcs: list,
|
||||
remove_think: bool,
|
||||
) -> tuple[typing.AsyncGenerator, modelmgr_requester.RuntimeLLMModel]:
|
||||
"""Try streaming invocation with sequential fallback. Returns (stream_generator, model_used).
|
||||
|
||||
Fallback is only possible before any chunks have been yielded to the client.
|
||||
Once streaming starts, the model is committed.
|
||||
"""
|
||||
last_error = None
|
||||
for model in candidates:
|
||||
try:
|
||||
stream = model.provider.invoke_llm_stream(
|
||||
query,
|
||||
model,
|
||||
messages,
|
||||
funcs if model.model_entity.abilities.__contains__('func_call') else [],
|
||||
extra_args=model.model_entity.extra_args,
|
||||
remove_think=remove_think,
|
||||
)
|
||||
# Attempt to get the first chunk to verify the stream works
|
||||
first_chunk = await stream.__anext__()
|
||||
|
||||
async def _chain_stream(first, rest):
|
||||
yield first
|
||||
async for chunk in rest:
|
||||
yield chunk
|
||||
|
||||
return _chain_stream(first_chunk, stream), model
|
||||
except StopAsyncIteration:
|
||||
# Empty stream — treat as success (model returned nothing)
|
||||
async def _empty_stream():
|
||||
return
|
||||
yield # make it a generator
|
||||
|
||||
return _empty_stream(), model
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
self.ap.logger.warning(f'Model {model.model_entity.name} stream failed: {e}, trying next fallback...')
|
||||
raise last_error or RuntimeError('No model candidates available')
|
||||
def __init__(self):
|
||||
self.active_calls: dict[str, dict] = {}
|
||||
self.completed_calls: list[provider_message.ToolCall] = []
|
||||
|
||||
async def run(
|
||||
self, query: pipeline_query.Query
|
||||
) -> typing.AsyncGenerator[provider_message.Message | provider_message.MessageChunk, None]:
|
||||
"""Run request"""
|
||||
"""运行请求"""
|
||||
pending_tool_calls = []
|
||||
|
||||
# Get knowledge bases list (new field)
|
||||
@@ -165,14 +74,15 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
self.ap.logger.warning(f'Knowledge base {kb_uuid} not found, skipping')
|
||||
continue
|
||||
|
||||
result = await kb.retrieve(
|
||||
user_message_text,
|
||||
settings={
|
||||
'bot_uuid': query.bot_uuid or '',
|
||||
'sender_id': str(query.sender_id),
|
||||
'session_name': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
},
|
||||
)
|
||||
# Get top_k based on KB type
|
||||
if kb.get_type() == 'internal':
|
||||
top_k = kb.knowledge_base_entity.top_k
|
||||
elif kb.get_type() == 'external':
|
||||
top_k = 5 # external kb's top_k is managed by plugin config
|
||||
else:
|
||||
top_k = 5 # default fallback
|
||||
|
||||
result = await kb.retrieve(user_message_text, top_k)
|
||||
|
||||
if result:
|
||||
all_results.extend(result)
|
||||
@@ -187,9 +97,9 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
if content.type == 'text' and content.text is not None:
|
||||
texts.append(f'[{idx}] {content.text}')
|
||||
idx += 1
|
||||
rag_context_text = '\n\n'.join(texts)
|
||||
rag_context = '\n\n'.join(texts)
|
||||
final_user_message_text = rag_combined_prompt_template.format(
|
||||
rag_context=rag_context_text, user_message=user_message_text
|
||||
rag_context=rag_context, user_message=user_message_text
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -211,51 +121,51 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
|
||||
remove_think = query.pipeline_config['output'].get('misc', '').get('remove-think')
|
||||
|
||||
# Build ordered candidate list (primary + fallbacks)
|
||||
candidates = await self._get_model_candidates(query)
|
||||
if not candidates:
|
||||
raise RuntimeError('No LLM model configured for local-agent runner')
|
||||
use_llm_model = await self.ap.model_mgr.get_model_by_uuid(query.use_llm_model_uuid)
|
||||
|
||||
self.ap.logger.debug(
|
||||
f'localagent req: query={query.query_id} req_messages={req_messages} '
|
||||
f'candidates={[m.model_entity.name for m in candidates]}'
|
||||
f'localagent req: query={query.query_id} req_messages={req_messages} use_llm_model={query.use_llm_model_uuid}'
|
||||
)
|
||||
|
||||
if not is_stream:
|
||||
# Non-streaming: invoke with fallback
|
||||
msg, use_llm_model = await self._invoke_with_fallback(
|
||||
# 非流式输出,直接请求
|
||||
|
||||
msg = await use_llm_model.provider.invoke_llm(
|
||||
query,
|
||||
candidates,
|
||||
use_llm_model,
|
||||
req_messages,
|
||||
query.use_funcs,
|
||||
remove_think,
|
||||
extra_args=use_llm_model.model_entity.extra_args,
|
||||
remove_think=remove_think,
|
||||
)
|
||||
yield msg
|
||||
final_msg = msg
|
||||
else:
|
||||
# Streaming: invoke with fallback
|
||||
# 流式输出,需要处理工具调用
|
||||
tool_calls_map: dict[str, provider_message.ToolCall] = {}
|
||||
msg_idx = 0
|
||||
accumulated_content = ''
|
||||
accumulated_content = '' # 从开始累积的所有内容
|
||||
last_role = 'assistant'
|
||||
msg_sequence = 1
|
||||
|
||||
stream_src, use_llm_model = await self._invoke_stream_with_fallback(
|
||||
async for msg in use_llm_model.provider.invoke_llm_stream(
|
||||
query,
|
||||
candidates,
|
||||
use_llm_model,
|
||||
req_messages,
|
||||
query.use_funcs,
|
||||
remove_think,
|
||||
)
|
||||
async for msg in stream_src:
|
||||
extra_args=use_llm_model.model_entity.extra_args,
|
||||
remove_think=remove_think,
|
||||
):
|
||||
msg_idx = msg_idx + 1
|
||||
|
||||
# 记录角色
|
||||
if msg.role:
|
||||
last_role = msg.role
|
||||
|
||||
# 累积内容
|
||||
if msg.content:
|
||||
accumulated_content += msg.content
|
||||
|
||||
# 处理工具调用
|
||||
if msg.tool_calls:
|
||||
for tool_call in msg.tool_calls:
|
||||
if tool_call.id not in tool_calls_map:
|
||||
@@ -267,18 +177,21 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
),
|
||||
)
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
# 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖
|
||||
tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments
|
||||
|
||||
# continue
|
||||
# 每8个chunk或最后一个chunk时,输出所有累积的内容
|
||||
if msg_idx % 8 == 0 or msg.is_final:
|
||||
msg_sequence += 1
|
||||
yield provider_message.MessageChunk(
|
||||
role=last_role,
|
||||
content=accumulated_content,
|
||||
content=accumulated_content, # 输出所有累积内容
|
||||
tool_calls=list(tool_calls_map.values()) if (tool_calls_map and msg.is_final) else None,
|
||||
is_final=msg.is_final,
|
||||
msg_sequence=msg_sequence,
|
||||
)
|
||||
|
||||
# 创建最终消息用于后续处理
|
||||
final_msg = provider_message.MessageChunk(
|
||||
role=last_role,
|
||||
content=accumulated_content,
|
||||
@@ -293,8 +206,7 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
|
||||
req_messages.append(final_msg)
|
||||
|
||||
# Once a model succeeds, commit to it for the tool call loop
|
||||
# (no fallback mid-conversation — different models may interpret tool results differently)
|
||||
# 持续请求,只要还有待处理的工具调用就继续处理调用
|
||||
while pending_tool_calls:
|
||||
for tool_call in pending_tool_calls:
|
||||
try:
|
||||
@@ -335,6 +247,7 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
|
||||
req_messages.append(msg)
|
||||
except Exception as e:
|
||||
# 工具调用出错,添加一个报错信息到 req_messages
|
||||
err_msg = provider_message.Message(role='tool', content=f'err: {e}', tool_call_id=tool_call.id)
|
||||
|
||||
yield err_msg
|
||||
@@ -342,38 +255,39 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
req_messages.append(err_msg)
|
||||
|
||||
self.ap.logger.debug(
|
||||
f'localagent req: query={query.query_id} req_messages={req_messages} '
|
||||
f'use_llm_model={use_llm_model.model_entity.name}'
|
||||
f'localagent req: query={query.query_id} req_messages={req_messages} use_llm_model={query.use_llm_model_uuid}'
|
||||
)
|
||||
|
||||
if is_stream:
|
||||
tool_calls_map = {}
|
||||
msg_idx = 0
|
||||
accumulated_content = ''
|
||||
accumulated_content = '' # 从开始累积的所有内容
|
||||
last_role = 'assistant'
|
||||
msg_sequence = first_end_sequence
|
||||
|
||||
tool_stream_src = use_llm_model.provider.invoke_llm_stream(
|
||||
async for msg in use_llm_model.provider.invoke_llm_stream(
|
||||
query,
|
||||
use_llm_model,
|
||||
req_messages,
|
||||
query.use_funcs if use_llm_model.model_entity.abilities.__contains__('func_call') else [],
|
||||
query.use_funcs,
|
||||
extra_args=use_llm_model.model_entity.extra_args,
|
||||
remove_think=remove_think,
|
||||
)
|
||||
async for msg in tool_stream_src:
|
||||
):
|
||||
msg_idx += 1
|
||||
|
||||
# 记录角色
|
||||
if msg.role:
|
||||
last_role = msg.role
|
||||
|
||||
# Prepend first-round content on first chunk of tool-call round
|
||||
# 第一次请求工具调用时的内容
|
||||
if msg_idx == 1:
|
||||
accumulated_content = first_content if first_content is not None else accumulated_content
|
||||
|
||||
# 累积内容
|
||||
if msg.content:
|
||||
accumulated_content += msg.content
|
||||
|
||||
# 处理工具调用
|
||||
if msg.tool_calls:
|
||||
for tool_call in msg.tool_calls:
|
||||
if tool_call.id not in tool_calls_map:
|
||||
@@ -385,13 +299,15 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
),
|
||||
)
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
# 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖
|
||||
tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments
|
||||
|
||||
# 每8个chunk或最后一个chunk时,输出所有累积的内容
|
||||
if msg_idx % 8 == 0 or msg.is_final:
|
||||
msg_sequence += 1
|
||||
yield provider_message.MessageChunk(
|
||||
role=last_role,
|
||||
content=accumulated_content,
|
||||
content=accumulated_content, # 输出所有累积内容
|
||||
tool_calls=list(tool_calls_map.values()) if (tool_calls_map and msg.is_final) else None,
|
||||
is_final=msg.is_final,
|
||||
msg_sequence=msg_sequence,
|
||||
@@ -404,12 +320,12 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
msg_sequence=msg_sequence,
|
||||
)
|
||||
else:
|
||||
# Non-streaming: use committed model directly (no fallback in tool loop)
|
||||
# 处理完所有调用,再次请求
|
||||
msg = await use_llm_model.provider.invoke_llm(
|
||||
query,
|
||||
use_llm_model,
|
||||
req_messages,
|
||||
query.use_funcs if use_llm_model.model_entity.abilities.__contains__('func_call') else [],
|
||||
query.use_funcs,
|
||||
extra_args=use_llm_model.model_entity.extra_args,
|
||||
remove_think=remove_think,
|
||||
)
|
||||
|
||||
@@ -5,8 +5,6 @@ import json
|
||||
import uuid
|
||||
import aiohttp
|
||||
|
||||
from langbot.pkg.utils import httpclient
|
||||
|
||||
from .. import runner
|
||||
from ...core import app
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
@@ -219,50 +217,50 @@ class N8nServiceAPIRunner(runner.RequestRunner):
|
||||
self.ap.logger.debug('no auth')
|
||||
|
||||
# 调用webhook
|
||||
session = httpclient.get_session()
|
||||
if is_stream:
|
||||
# 流式请求
|
||||
async with session.post(
|
||||
self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
self.ap.logger.error(f'n8n webhook call failed: {response.status}, {error_text}')
|
||||
raise Exception(f'n8n webhook call failed: {response.status}, {error_text}')
|
||||
|
||||
# 处理流式响应
|
||||
async for chunk in self._process_stream_response(response):
|
||||
yield chunk
|
||||
else:
|
||||
async with session.post(
|
||||
self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout
|
||||
) as response:
|
||||
try:
|
||||
async for chunk in self._process_stream_response(response):
|
||||
output_content = chunk.content if chunk.is_final else ''
|
||||
except:
|
||||
# 非流式请求(保持原有逻辑)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if is_stream:
|
||||
# 流式请求
|
||||
async with session.post(
|
||||
self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
self.ap.logger.error(f'n8n webhook call failed: {response.status}, {error_text}')
|
||||
raise Exception(f'n8n webhook call failed: {response.status}, {error_text}')
|
||||
|
||||
# 解析响应
|
||||
response_data = await response.json()
|
||||
self.ap.logger.debug(f'n8n webhook response: {response_data}')
|
||||
# 处理流式响应
|
||||
async for chunk in self._process_stream_response(response):
|
||||
yield chunk
|
||||
else:
|
||||
async with session.post(
|
||||
self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout
|
||||
) as response:
|
||||
try:
|
||||
async for chunk in self._process_stream_response(response):
|
||||
output_content = chunk.content if chunk.is_final else ''
|
||||
except:
|
||||
# 非流式请求(保持原有逻辑)
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
self.ap.logger.error(f'n8n webhook call failed: {response.status}, {error_text}')
|
||||
raise Exception(f'n8n webhook call failed: {response.status}, {error_text}')
|
||||
|
||||
# 从响应中提取输出
|
||||
if self.output_key in response_data:
|
||||
output_content = response_data[self.output_key]
|
||||
else:
|
||||
# 如果没有指定的输出键,则使用整个响应
|
||||
output_content = json.dumps(response_data, ensure_ascii=False)
|
||||
# 解析响应
|
||||
response_data = await response.json()
|
||||
self.ap.logger.debug(f'n8n webhook response: {response_data}')
|
||||
|
||||
# 返回消息
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=output_content,
|
||||
)
|
||||
# 从响应中提取输出
|
||||
if self.output_key in response_data:
|
||||
output_content = response_data[self.output_key]
|
||||
else:
|
||||
# 如果没有指定的输出键,则使用整个响应
|
||||
output_content = json.dumps(response_data, ensure_ascii=False)
|
||||
|
||||
# 返回消息
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=output_content,
|
||||
)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'n8n webhook call exception: {str(e)}')
|
||||
raise N8nAPIError(f'n8n webhook call exception: {str(e)}')
|
||||
|
||||
@@ -22,12 +22,12 @@ class KnowledgeBaseInterface(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def retrieve(self, query: str, settings: dict | None = None) -> list[rag_context.RetrievalResultEntry]:
|
||||
async def retrieve(self, query: str, top_k: int) -> list[rag_context.RetrievalResultEntry]:
|
||||
"""Retrieve relevant documents from the knowledge base
|
||||
|
||||
Args:
|
||||
query: The query string
|
||||
settings: Optional per-request retrieval settings overrides
|
||||
top_k: Number of top results to return
|
||||
|
||||
Returns:
|
||||
List of retrieve result entries
|
||||
@@ -45,8 +45,8 @@ class KnowledgeBaseInterface(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_knowledge_engine_plugin_id(self) -> str:
|
||||
"""Get the Knowledge Engine plugin ID"""
|
||||
def get_type(self) -> str:
|
||||
"""Get the type of knowledge base (internal/external)"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
|
||||
85
src/langbot/pkg/rag/knowledge/external.py
Normal file
85
src/langbot/pkg/rag/knowledge/external.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""External knowledge base implementation"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langbot.pkg.core import app
|
||||
from langbot.pkg.entity.persistence import rag as persistence_rag
|
||||
from langbot_plugin.api.entities.builtin.rag import context as rag_context
|
||||
from .base import KnowledgeBaseInterface
|
||||
|
||||
|
||||
class ExternalKnowledgeBase(KnowledgeBaseInterface):
|
||||
"""External knowledge base that queries via HTTP API or plugin retriever"""
|
||||
|
||||
external_kb_entity: persistence_rag.ExternalKnowledgeBase
|
||||
|
||||
# Plugin retriever instance ID
|
||||
retriever_instance_id: str | None
|
||||
|
||||
def __init__(self, ap: app.Application, external_kb_entity: persistence_rag.ExternalKnowledgeBase):
|
||||
super().__init__(ap)
|
||||
self.external_kb_entity = external_kb_entity
|
||||
self.retriever_instance_id = None
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the external knowledge base"""
|
||||
# Use KB UUID as instance ID
|
||||
# Instance creation is now handled by the unified sync mechanism
|
||||
# when LangBot connects to runtime
|
||||
self.retriever_instance_id = self.external_kb_entity.uuid
|
||||
|
||||
self.ap.logger.info(
|
||||
f'Initialized external KB {self.external_kb_entity.uuid}, instance will be created by sync mechanism'
|
||||
)
|
||||
|
||||
async def retrieve(self, query: str, top_k: int = 5) -> list[rag_context.RetrievalResultEntry]:
|
||||
"""Retrieve documents from external knowledge base via plugin retriever"""
|
||||
if not self.retriever_instance_id:
|
||||
self.ap.logger.error(f'No retriever instance for KB {self.external_kb_entity.uuid}')
|
||||
return []
|
||||
|
||||
try:
|
||||
results = await self.ap.plugin_connector.retrieve_knowledge(
|
||||
self.external_kb_entity.plugin_author,
|
||||
self.external_kb_entity.plugin_name,
|
||||
self.external_kb_entity.retriever_name,
|
||||
self.retriever_instance_id,
|
||||
{'query': query},
|
||||
)
|
||||
|
||||
# Convert plugin results to RetrievalResultEntry
|
||||
retrieval_entries = []
|
||||
for result in results:
|
||||
retrieval_entries.append(rag_context.RetrievalResultEntry(**result))
|
||||
|
||||
return retrieval_entries
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Plugin retriever error: {e}')
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return []
|
||||
|
||||
def get_uuid(self) -> str:
|
||||
"""Get the UUID of the external knowledge base"""
|
||||
return self.external_kb_entity.uuid
|
||||
|
||||
def get_name(self) -> str:
|
||||
"""Get the name of the external knowledge base"""
|
||||
return self.external_kb_entity.name
|
||||
|
||||
def get_type(self) -> str:
|
||||
"""Get the type of knowledge base"""
|
||||
return 'external'
|
||||
|
||||
async def dispose(self):
|
||||
"""Clean up resources"""
|
||||
# Trigger sync to immediately delete the instance from plugin process
|
||||
# This ensures instance is cleaned up without waiting for next LangBot restart
|
||||
try:
|
||||
await self.ap.plugin_connector.sync_polymorphic_component_instances()
|
||||
self.ap.logger.info(
|
||||
f'Disposed external KB {self.external_kb_entity.uuid}, triggered sync to delete instance'
|
||||
)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Failed to sync after disposing KB: {e}')
|
||||
@@ -1,19 +1,18 @@
|
||||
from __future__ import annotations
|
||||
import mimetypes
|
||||
import os.path
|
||||
import traceback
|
||||
import uuid
|
||||
import zipfile
|
||||
import io
|
||||
from typing import Any
|
||||
from .services import parser, chunker
|
||||
from langbot.pkg.core import app
|
||||
from langbot.pkg.rag.knowledge.services.embedder import Embedder
|
||||
from langbot.pkg.rag.knowledge.services.retriever import Retriever
|
||||
import sqlalchemy
|
||||
|
||||
|
||||
from langbot.pkg.entity.persistence import rag as persistence_rag
|
||||
from langbot.pkg.core import taskmgr
|
||||
from langbot_plugin.api.entities.builtin.rag import context as rag_context
|
||||
from .base import KnowledgeBaseInterface
|
||||
from .external import ExternalKnowledgeBase
|
||||
|
||||
|
||||
class RuntimeKnowledgeBase(KnowledgeBaseInterface):
|
||||
@@ -21,16 +20,28 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface):
|
||||
|
||||
knowledge_base_entity: persistence_rag.KnowledgeBase
|
||||
|
||||
parser: parser.FileParser
|
||||
|
||||
chunker: chunker.Chunker
|
||||
|
||||
embedder: Embedder
|
||||
|
||||
retriever: Retriever
|
||||
|
||||
def __init__(self, ap: app.Application, knowledge_base_entity: persistence_rag.KnowledgeBase):
|
||||
super().__init__(ap)
|
||||
self.knowledge_base_entity = knowledge_base_entity
|
||||
self.parser = parser.FileParser(ap=self.ap)
|
||||
self.chunker = chunker.Chunker(ap=self.ap)
|
||||
self.embedder = Embedder(ap=self.ap)
|
||||
self.retriever = Retriever(ap=self.ap)
|
||||
# 传递kb_id给retriever
|
||||
self.retriever.kb_id = knowledge_base_entity.uuid
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def _store_file_task(
|
||||
self, file: persistence_rag.File, task_context: taskmgr.TaskContext, parser_plugin_id: str | None = None
|
||||
):
|
||||
async def _store_file_task(self, file: persistence_rag.File, task_context: taskmgr.TaskContext):
|
||||
try:
|
||||
# set file status to processing
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
@@ -39,45 +50,30 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface):
|
||||
.values(status='processing')
|
||||
)
|
||||
|
||||
task_context.set_current_action('Processing file')
|
||||
task_context.set_current_action('Parsing file')
|
||||
# parse file
|
||||
text = await self.parser.parse(file.file_name, file.extension)
|
||||
if not text:
|
||||
raise Exception(f'No text extracted from file {file.file_name}')
|
||||
|
||||
# Get file size from storage
|
||||
file_size = await self.ap.storage_mgr.storage_provider.size(file.file_name)
|
||||
task_context.set_current_action('Chunking file')
|
||||
# chunk file
|
||||
chunks_texts = await self.chunker.chunk(text)
|
||||
if not chunks_texts:
|
||||
raise Exception(f'No chunks extracted from file {file.file_name}')
|
||||
|
||||
# Detect MIME type from extension
|
||||
mime_type, _ = mimetypes.guess_type(file.file_name)
|
||||
if mime_type is None:
|
||||
mime_type = 'application/octet-stream'
|
||||
task_context.set_current_action('Embedding chunks')
|
||||
|
||||
# If a parser plugin is specified, call it before ingestion
|
||||
parsed_content = None
|
||||
if parser_plugin_id:
|
||||
task_context.set_current_action('Parsing file')
|
||||
file_bytes = await self.ap.storage_mgr.storage_provider.load(file.file_name)
|
||||
parse_context = {
|
||||
'mime_type': mime_type,
|
||||
'filename': file.file_name,
|
||||
'metadata': {},
|
||||
}
|
||||
parsed_content = await self.ap.plugin_connector.call_parser(parser_plugin_id, parse_context, file_bytes)
|
||||
|
||||
# Call plugin to ingest document
|
||||
result = await self._ingest_document(
|
||||
{
|
||||
'document_id': file.uuid,
|
||||
'filename': file.file_name,
|
||||
'extension': file.extension,
|
||||
'file_size': file_size,
|
||||
'mime_type': mime_type,
|
||||
},
|
||||
file.file_name, # storage path
|
||||
parsed_content=parsed_content,
|
||||
embedding_model = await self.ap.model_mgr.get_embedding_model_by_uuid(
|
||||
self.knowledge_base_entity.embedding_model_uuid
|
||||
)
|
||||
# embed chunks
|
||||
await self.embedder.embed_and_store(
|
||||
kb_id=self.knowledge_base_entity.uuid,
|
||||
file_id=file.uuid,
|
||||
chunks=chunks_texts,
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
|
||||
# Check plugin result status
|
||||
if result.get('status') == 'failed':
|
||||
error_msg = result.get('error_message', 'Plugin ingestion returned failed status')
|
||||
raise Exception(error_msg)
|
||||
|
||||
# set file status to completed
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
@@ -101,17 +97,16 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface):
|
||||
# delete file from storage
|
||||
await self.ap.storage_mgr.storage_provider.delete(file.file_name)
|
||||
|
||||
async def store_file(self, file_id: str, parser_plugin_id: str | None = None) -> str:
|
||||
async def store_file(self, file_id: str) -> str:
|
||||
# pre checking
|
||||
if not await self.ap.storage_mgr.storage_provider.exists(file_id):
|
||||
raise Exception(f'File {file_id} not found')
|
||||
|
||||
file_name = file_id
|
||||
_, ext = os.path.splitext(file_name)
|
||||
extension = ext.lstrip('.').lower() if ext else ''
|
||||
extension = file_name.split('.')[-1].lower()
|
||||
|
||||
if extension == 'zip':
|
||||
return await self._store_zip_file(file_id, parser_plugin_id=parser_plugin_id)
|
||||
return await self._store_zip_file(file_id)
|
||||
|
||||
file_uuid = str(uuid.uuid4())
|
||||
kb_id = self.knowledge_base_entity.uuid
|
||||
@@ -131,7 +126,7 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface):
|
||||
# run background task asynchronously
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
wrapper = self.ap.task_mgr.create_user_task(
|
||||
self._store_file_task(file_obj, task_context=ctx, parser_plugin_id=parser_plugin_id),
|
||||
self._store_file_task(file_obj, task_context=ctx),
|
||||
kind='knowledge-operation',
|
||||
name=f'knowledge-store-file-{file_id}',
|
||||
label=f'Store file {file_id}',
|
||||
@@ -139,7 +134,7 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface):
|
||||
)
|
||||
return wrapper.id
|
||||
|
||||
async def _store_zip_file(self, zip_file_id: str, parser_plugin_id: str | None = None) -> str:
|
||||
async def _store_zip_file(self, zip_file_id: str) -> str:
|
||||
"""Handle ZIP file by extracting each document and storing them separately."""
|
||||
self.ap.logger.info(f'Processing ZIP file: {zip_file_id}')
|
||||
|
||||
@@ -155,8 +150,7 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface):
|
||||
if file_info.is_dir() or file_info.filename.startswith('.'):
|
||||
continue
|
||||
|
||||
_, file_ext = os.path.splitext(file_info.filename)
|
||||
file_extension = file_ext.lstrip('.').lower()
|
||||
file_extension = file_info.filename.split('.')[-1].lower()
|
||||
if file_extension not in supported_extensions:
|
||||
self.ap.logger.debug(f'Skipping unsupported file in ZIP: {file_info.filename}')
|
||||
continue
|
||||
@@ -165,18 +159,18 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface):
|
||||
file_content = zip_ref.read(file_info.filename)
|
||||
|
||||
base_name = file_info.filename.replace('/', '_').replace('\\', '_')
|
||||
file_stem, file_ext = os.path.splitext(base_name)
|
||||
extension = file_ext.lstrip('.')
|
||||
extension = base_name.split('.')[-1]
|
||||
file_name = base_name.split('.')[0]
|
||||
|
||||
if file_stem.startswith('__MACOSX'):
|
||||
if file_name.startswith('__MACOSX'):
|
||||
continue
|
||||
|
||||
extracted_file_id = file_stem + '_' + str(uuid.uuid4())[:8] + '.' + extension
|
||||
extracted_file_id = file_name + '_' + str(uuid.uuid4())[:8] + '.' + extension
|
||||
# save file to storage
|
||||
|
||||
await self.ap.storage_mgr.storage_provider.save(extracted_file_id, file_content)
|
||||
|
||||
task_id = await self.store_file(extracted_file_id, parser_plugin_id=parser_plugin_id)
|
||||
task_id = await self.store_file(extracted_file_id)
|
||||
stored_file_tasks.append(task_id)
|
||||
|
||||
self.ap.logger.info(
|
||||
@@ -195,28 +189,21 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface):
|
||||
|
||||
return stored_file_tasks[0] if stored_file_tasks else ''
|
||||
|
||||
async def retrieve(self, query: str, settings: dict | None = None) -> list[rag_context.RetrievalResultEntry]:
|
||||
# Merge stored retrieval_settings with per-request overrides
|
||||
stored = self.knowledge_base_entity.retrieval_settings or {}
|
||||
merged = {**stored, **(settings or {})}
|
||||
if 'top_k' not in merged:
|
||||
merged['top_k'] = 5 # fallback default
|
||||
|
||||
response = await self._retrieve(query, merged)
|
||||
|
||||
results_data = response.get('results', [])
|
||||
entries = []
|
||||
for r in results_data:
|
||||
if isinstance(r, dict):
|
||||
entries.append(rag_context.RetrievalResultEntry(**r))
|
||||
elif isinstance(r, rag_context.RetrievalResultEntry):
|
||||
entries.append(r)
|
||||
return entries
|
||||
async def retrieve(self, query: str, top_k: int) -> list[rag_context.RetrievalResultEntry]:
|
||||
embedding_model = await self.ap.model_mgr.get_embedding_model_by_uuid(
|
||||
self.knowledge_base_entity.embedding_model_uuid
|
||||
)
|
||||
return await self.retriever.retrieve(self.knowledge_base_entity.uuid, query, embedding_model, top_k)
|
||||
|
||||
async def delete_file(self, file_id: str):
|
||||
await self._delete_document(file_id)
|
||||
# delete vector
|
||||
await self.ap.vector_db_mgr.vector_db.delete_by_file_id(self.knowledge_base_entity.uuid, file_id)
|
||||
|
||||
# delete chunk
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_rag.Chunk).where(persistence_rag.Chunk.file_id == file_id)
|
||||
)
|
||||
|
||||
# Also cleanup DB record
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_rag.File).where(persistence_rag.File.uuid == file_id)
|
||||
)
|
||||
@@ -229,295 +216,32 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface):
|
||||
"""Get the name of the knowledge base"""
|
||||
return self.knowledge_base_entity.name
|
||||
|
||||
def get_knowledge_engine_plugin_id(self) -> str:
|
||||
"""Get the Knowledge Engine plugin ID"""
|
||||
return self.knowledge_base_entity.knowledge_engine_plugin_id or ''
|
||||
def get_type(self) -> str:
|
||||
"""Get the type of knowledge base"""
|
||||
return 'internal'
|
||||
|
||||
async def dispose(self):
|
||||
"""Dispose the knowledge base, notifying the plugin to cleanup."""
|
||||
await self._on_kb_delete()
|
||||
|
||||
# ========== Plugin Communication Methods ==========
|
||||
|
||||
async def _on_kb_create(self) -> None:
|
||||
"""Notify plugin about KB creation."""
|
||||
plugin_id = self.knowledge_base_entity.knowledge_engine_plugin_id
|
||||
if not plugin_id:
|
||||
return
|
||||
|
||||
try:
|
||||
config = self.knowledge_base_entity.creation_settings or {}
|
||||
self.ap.logger.info(
|
||||
f'Calling RAG plugin {plugin_id}: on_knowledge_base_create(kb_id={self.knowledge_base_entity.uuid})'
|
||||
)
|
||||
await self.ap.plugin_connector.rag_on_kb_create(plugin_id, self.knowledge_base_entity.uuid, config)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Failed to notify plugin {plugin_id} on KB create: {e}')
|
||||
raise
|
||||
|
||||
async def _on_kb_delete(self) -> None:
|
||||
"""Notify plugin about KB deletion."""
|
||||
plugin_id = self.knowledge_base_entity.knowledge_engine_plugin_id
|
||||
if not plugin_id:
|
||||
return
|
||||
|
||||
try:
|
||||
self.ap.logger.info(
|
||||
f'Calling RAG plugin {plugin_id}: on_knowledge_base_delete(kb_id={self.knowledge_base_entity.uuid})'
|
||||
)
|
||||
await self.ap.plugin_connector.rag_on_kb_delete(plugin_id, self.knowledge_base_entity.uuid)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Failed to notify plugin {plugin_id} on KB delete: {e}')
|
||||
|
||||
async def _ingest_document(
|
||||
self,
|
||||
file_metadata: dict[str, Any],
|
||||
storage_path: str,
|
||||
parsed_content: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Call plugin to ingest document."""
|
||||
kb = self.knowledge_base_entity
|
||||
plugin_id = kb.knowledge_engine_plugin_id
|
||||
if not plugin_id:
|
||||
self.ap.logger.error(f'No RAG plugin ID configured for KB {kb.uuid}. Ingestion failed.')
|
||||
raise ValueError('RAG Plugin ID required')
|
||||
|
||||
self.ap.logger.info(f'Calling RAG plugin {plugin_id}: ingest(doc={file_metadata.get("filename")})')
|
||||
|
||||
# Inject knowledge_base_id into file metadata as required by SDK schema
|
||||
file_metadata['knowledge_base_id'] = kb.uuid
|
||||
|
||||
context_data = {
|
||||
'file_object': {
|
||||
'metadata': file_metadata,
|
||||
'storage_path': storage_path,
|
||||
},
|
||||
'knowledge_base_id': kb.uuid,
|
||||
'collection_id': kb.collection_id or kb.uuid,
|
||||
'creation_settings': kb.creation_settings or {},
|
||||
'parsed_content': parsed_content,
|
||||
}
|
||||
|
||||
try:
|
||||
result = await self.ap.plugin_connector.call_rag_ingest(plugin_id, context_data)
|
||||
return result
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Plugin ingestion failed: {e}')
|
||||
raise
|
||||
|
||||
async def _retrieve(
|
||||
self,
|
||||
query: str,
|
||||
settings: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Call plugin to retrieve documents.
|
||||
|
||||
Raises:
|
||||
ValueError: If no RAG plugin is configured for this KB.
|
||||
Exception: If the plugin retrieval call fails.
|
||||
"""
|
||||
kb = self.knowledge_base_entity
|
||||
plugin_id = kb.knowledge_engine_plugin_id
|
||||
if not plugin_id:
|
||||
raise ValueError(f'No RAG plugin ID configured for KB {kb.uuid}. Retrieval failed.')
|
||||
|
||||
# Session context (e.g. session_name) stays in retrieval_settings
|
||||
# for plugins that need it. Do NOT move them into filters, as filters
|
||||
# are passed directly to vector_search by some plugins (e.g. LangRAG)
|
||||
# and would cause empty results when the metadata field doesn't exist.
|
||||
filters = settings.pop('filters', {})
|
||||
|
||||
retrieval_context = {
|
||||
'query': query,
|
||||
'knowledge_base_id': kb.uuid,
|
||||
'collection_id': kb.collection_id or kb.uuid,
|
||||
'retrieval_settings': settings,
|
||||
'creation_settings': kb.creation_settings or {},
|
||||
'filters': filters,
|
||||
}
|
||||
|
||||
result = await self.ap.plugin_connector.call_rag_retrieve(
|
||||
plugin_id,
|
||||
retrieval_context,
|
||||
)
|
||||
return result
|
||||
|
||||
async def _delete_document(self, document_id: str) -> bool:
|
||||
"""Call plugin to delete document."""
|
||||
kb = self.knowledge_base_entity
|
||||
plugin_id = kb.knowledge_engine_plugin_id
|
||||
if not plugin_id:
|
||||
return False
|
||||
|
||||
self.ap.logger.info(f'Calling RAG plugin {plugin_id}: delete_document(doc_id={document_id})')
|
||||
|
||||
try:
|
||||
return await self.ap.plugin_connector.call_rag_delete_document(plugin_id, document_id, kb.uuid)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Plugin document deletion failed: {e}')
|
||||
return False
|
||||
await self.ap.vector_db_mgr.vector_db.delete_collection(self.knowledge_base_entity.uuid)
|
||||
|
||||
|
||||
class RAGManager:
|
||||
ap: app.Application
|
||||
|
||||
knowledge_bases: dict[str, KnowledgeBaseInterface]
|
||||
knowledge_bases: list[KnowledgeBaseInterface]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.knowledge_bases = {}
|
||||
self.knowledge_bases = []
|
||||
|
||||
async def initialize(self):
|
||||
await self.load_knowledge_bases_from_db()
|
||||
|
||||
async def get_all_knowledge_base_details(self) -> list[dict]:
|
||||
"""Get all knowledge bases with enriched Knowledge Engine details."""
|
||||
# 1. Get raw KBs from DB
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_rag.KnowledgeBase))
|
||||
knowledge_bases = result.all()
|
||||
|
||||
# 2. Get all available Knowledge Engines for enrichment
|
||||
engine_map = {}
|
||||
if self.ap.plugin_connector.is_enable_plugin:
|
||||
try:
|
||||
engines = await self.ap.plugin_connector.list_knowledge_engines()
|
||||
engine_map = {e['plugin_id']: e for e in engines}
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(f'Failed to list Knowledge Engines: {e}')
|
||||
|
||||
# 3. Serialize and enrich
|
||||
kb_list = []
|
||||
for kb in knowledge_bases:
|
||||
kb_dict = self.ap.persistence_mgr.serialize_model(persistence_rag.KnowledgeBase, kb)
|
||||
self._enrich_kb_dict(kb_dict, engine_map)
|
||||
kb_list.append(kb_dict)
|
||||
|
||||
return kb_list
|
||||
|
||||
async def get_knowledge_base_details(self, kb_uuid: str) -> dict | None:
|
||||
"""Get specific knowledge base with enriched Knowledge Engine details."""
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_rag.KnowledgeBase).where(persistence_rag.KnowledgeBase.uuid == kb_uuid)
|
||||
)
|
||||
kb = result.first()
|
||||
if not kb:
|
||||
return None
|
||||
|
||||
kb_dict = self.ap.persistence_mgr.serialize_model(persistence_rag.KnowledgeBase, kb)
|
||||
|
||||
# Fetch engines
|
||||
engine_map = {}
|
||||
if self.ap.plugin_connector.is_enable_plugin:
|
||||
try:
|
||||
engines = await self.ap.plugin_connector.list_knowledge_engines()
|
||||
engine_map = {e['plugin_id']: e for e in engines}
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(f'Failed to list Knowledge Engines: {e}')
|
||||
|
||||
self._enrich_kb_dict(kb_dict, engine_map)
|
||||
return kb_dict
|
||||
|
||||
@staticmethod
|
||||
def _to_i18n_name(name) -> dict:
|
||||
"""Ensure name is always an I18nObject-compatible dict.
|
||||
|
||||
If *name* is already a dict (with ``en_US`` / ``zh_Hans`` keys) it is
|
||||
returned as-is. A plain string is wrapped into an I18nObject so the
|
||||
frontend ``extractI18nObject`` helper never receives an unexpected type.
|
||||
"""
|
||||
if isinstance(name, dict):
|
||||
return name
|
||||
return {'en_US': str(name), 'zh_Hans': str(name)}
|
||||
|
||||
def _enrich_kb_dict(self, kb_dict: dict, engine_map: dict) -> None:
|
||||
"""Helper to inject engine info into KB dict."""
|
||||
plugin_id = kb_dict.get('knowledge_engine_plugin_id')
|
||||
|
||||
# Default fallback structure — name must be I18nObject for frontend compatibility
|
||||
fallback_name = self._to_i18n_name(plugin_id or 'Internal (Legacy)')
|
||||
fallback_info = {
|
||||
'plugin_id': plugin_id,
|
||||
'name': fallback_name,
|
||||
'capabilities': [],
|
||||
}
|
||||
|
||||
if not plugin_id:
|
||||
kb_dict['knowledge_engine'] = fallback_info
|
||||
return
|
||||
|
||||
engine_info = engine_map.get(plugin_id)
|
||||
if engine_info:
|
||||
kb_dict['knowledge_engine'] = {
|
||||
'plugin_id': plugin_id,
|
||||
'name': self._to_i18n_name(engine_info.get('name', plugin_id)),
|
||||
'capabilities': engine_info.get('capabilities', []),
|
||||
}
|
||||
else:
|
||||
kb_dict['knowledge_engine'] = fallback_info
|
||||
|
||||
async def create_knowledge_base(
|
||||
self,
|
||||
name: str,
|
||||
knowledge_engine_plugin_id: str,
|
||||
creation_settings: dict,
|
||||
retrieval_settings: dict | None = None,
|
||||
description: str = '',
|
||||
) -> persistence_rag.KnowledgeBase:
|
||||
"""Create a new knowledge base using a RAG plugin."""
|
||||
# Validate that the Knowledge Engine plugin exists
|
||||
if self.ap.plugin_connector.is_enable_plugin:
|
||||
try:
|
||||
engines = await self.ap.plugin_connector.list_knowledge_engines()
|
||||
engine_ids = [e.get('plugin_id') for e in engines]
|
||||
if knowledge_engine_plugin_id not in engine_ids:
|
||||
raise ValueError(f'Knowledge Engine plugin {knowledge_engine_plugin_id} not found')
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(f'Failed to validate Knowledge Engine plugin existence: {e}')
|
||||
|
||||
kb_uuid = str(uuid.uuid4())
|
||||
# Use UUID as collection ID by default for isolation
|
||||
collection_id = kb_uuid
|
||||
|
||||
kb_data = {
|
||||
'uuid': kb_uuid,
|
||||
'name': name,
|
||||
'description': description,
|
||||
'knowledge_engine_plugin_id': knowledge_engine_plugin_id,
|
||||
'collection_id': collection_id,
|
||||
'creation_settings': creation_settings,
|
||||
'retrieval_settings': retrieval_settings or {},
|
||||
}
|
||||
|
||||
# Create Entity
|
||||
kb = persistence_rag.KnowledgeBase(**kb_data)
|
||||
|
||||
# Persist
|
||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.KnowledgeBase).values(kb_data))
|
||||
|
||||
# Load into Runtime
|
||||
runtime_kb = await self.load_knowledge_base(kb)
|
||||
|
||||
# Notify Plugin — rollback DB record and runtime entry on failure
|
||||
try:
|
||||
await runtime_kb._on_kb_create()
|
||||
except Exception:
|
||||
self.knowledge_bases.pop(kb_uuid, None)
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_rag.KnowledgeBase).where(persistence_rag.KnowledgeBase.uuid == kb_uuid)
|
||||
)
|
||||
raise
|
||||
|
||||
self.ap.logger.info(f'Created new Knowledge Base {name} ({kb_uuid}) using plugin {knowledge_engine_plugin_id}')
|
||||
return kb
|
||||
|
||||
async def load_knowledge_bases_from_db(self):
|
||||
self.ap.logger.info('Loading knowledge bases from db...')
|
||||
|
||||
self.knowledge_bases = {}
|
||||
self.knowledge_bases = []
|
||||
|
||||
# Load knowledge bases
|
||||
# Load internal knowledge bases
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_rag.KnowledgeBase))
|
||||
knowledge_bases = result.all()
|
||||
|
||||
@@ -529,37 +253,86 @@ class RAGManager:
|
||||
f'Error loading knowledge base {knowledge_base.uuid}: {e}\n{traceback.format_exc()}'
|
||||
)
|
||||
|
||||
# Load external knowledge bases
|
||||
external_result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_rag.ExternalKnowledgeBase)
|
||||
)
|
||||
external_kbs = external_result.all()
|
||||
|
||||
for external_kb in external_kbs:
|
||||
try:
|
||||
# Don't trigger sync during batch loading - will sync once after LangBot connects to runtime
|
||||
await self.load_external_knowledge_base(external_kb, trigger_sync=False)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(
|
||||
f'Error loading external knowledge base {external_kb.uuid}: {e}\n{traceback.format_exc()}'
|
||||
)
|
||||
|
||||
async def load_knowledge_base(
|
||||
self,
|
||||
knowledge_base_entity: persistence_rag.KnowledgeBase | sqlalchemy.Row | dict,
|
||||
) -> RuntimeKnowledgeBase:
|
||||
if isinstance(knowledge_base_entity, sqlalchemy.Row):
|
||||
# Safe access to _mapping for SQLAlchemy 1.4+
|
||||
knowledge_base_entity = persistence_rag.KnowledgeBase(**knowledge_base_entity._mapping)
|
||||
elif isinstance(knowledge_base_entity, dict):
|
||||
# Filter out non-database fields (like knowledge_engine which is computed)
|
||||
filtered_dict = {
|
||||
k: v for k, v in knowledge_base_entity.items() if k in persistence_rag.KnowledgeBase.ALL_DB_FIELDS
|
||||
}
|
||||
knowledge_base_entity = persistence_rag.KnowledgeBase(**filtered_dict)
|
||||
knowledge_base_entity = persistence_rag.KnowledgeBase(**knowledge_base_entity)
|
||||
|
||||
runtime_knowledge_base = RuntimeKnowledgeBase(ap=self.ap, knowledge_base_entity=knowledge_base_entity)
|
||||
|
||||
await runtime_knowledge_base.initialize()
|
||||
|
||||
self.knowledge_bases[runtime_knowledge_base.get_uuid()] = runtime_knowledge_base
|
||||
self.knowledge_bases.append(runtime_knowledge_base)
|
||||
|
||||
return runtime_knowledge_base
|
||||
|
||||
async def load_external_knowledge_base(
|
||||
self,
|
||||
external_kb_entity: persistence_rag.ExternalKnowledgeBase | sqlalchemy.Row | dict,
|
||||
trigger_sync: bool = True,
|
||||
) -> ExternalKnowledgeBase:
|
||||
"""Load external knowledge base into runtime
|
||||
|
||||
Args:
|
||||
external_kb_entity: External KB entity to load
|
||||
trigger_sync: Whether to trigger sync after loading (default True for manual creation, False for batch loading)
|
||||
"""
|
||||
if isinstance(external_kb_entity, sqlalchemy.Row):
|
||||
external_kb_entity = persistence_rag.ExternalKnowledgeBase(**external_kb_entity._mapping)
|
||||
elif isinstance(external_kb_entity, dict):
|
||||
external_kb_entity = persistence_rag.ExternalKnowledgeBase(**external_kb_entity)
|
||||
|
||||
external_kb = ExternalKnowledgeBase(ap=self.ap, external_kb_entity=external_kb_entity)
|
||||
|
||||
await external_kb.initialize()
|
||||
|
||||
self.knowledge_bases.append(external_kb)
|
||||
|
||||
# Trigger sync to create the instance immediately (for manual creation)
|
||||
# Skip sync during batch loading from DB to avoid multiple sync calls
|
||||
if trigger_sync:
|
||||
try:
|
||||
await self.ap.plugin_connector.sync_polymorphic_component_instances()
|
||||
self.ap.logger.info(f'Triggered sync after loading external KB {external_kb_entity.uuid}')
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Failed to sync after loading external KB: {e}')
|
||||
|
||||
return external_kb
|
||||
|
||||
async def get_knowledge_base_by_uuid(self, kb_uuid: str) -> KnowledgeBaseInterface | None:
|
||||
return self.knowledge_bases.get(kb_uuid)
|
||||
for kb in self.knowledge_bases:
|
||||
if kb.get_uuid() == kb_uuid:
|
||||
return kb
|
||||
return None
|
||||
|
||||
async def remove_knowledge_base_from_runtime(self, kb_uuid: str):
|
||||
self.knowledge_bases.pop(kb_uuid, None)
|
||||
for kb in self.knowledge_bases:
|
||||
if kb.get_uuid() == kb_uuid:
|
||||
self.knowledge_bases.remove(kb)
|
||||
return
|
||||
|
||||
async def delete_knowledge_base(self, kb_uuid: str):
|
||||
kb = self.knowledge_bases.pop(kb_uuid, None)
|
||||
if kb is not None:
|
||||
await kb.dispose()
|
||||
else:
|
||||
self.ap.logger.warning(f'Knowledge base {kb_uuid} not found in runtime, skipping plugin notification')
|
||||
for kb in self.knowledge_bases:
|
||||
if kb.get_uuid() == kb_uuid:
|
||||
await kb.dispose()
|
||||
self.knowledge_bases.remove(kb)
|
||||
return
|
||||
|
||||
0
src/langbot/pkg/rag/knowledge/services/__init__.py
Normal file
0
src/langbot/pkg/rag/knowledge/services/__init__.py
Normal file
15
src/langbot/pkg/rag/knowledge/services/base_service.py
Normal file
15
src/langbot/pkg/rag/knowledge/services/base_service.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# 封装异步操作
|
||||
import asyncio
|
||||
|
||||
|
||||
class BaseService:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
async def _run_sync(self, func, *args, **kwargs):
|
||||
"""
|
||||
在单独的线程中运行同步函数。
|
||||
如果第一个参数是 session,则在 to_thread 中获取新的 session。
|
||||
"""
|
||||
|
||||
return await asyncio.to_thread(func, *args, **kwargs)
|
||||
49
src/langbot/pkg/rag/knowledge/services/chunker.py
Normal file
49
src/langbot/pkg/rag/knowledge/services/chunker.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import List
|
||||
from langbot.pkg.rag.knowledge.services import base_service
|
||||
from langbot.pkg.core import app
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
|
||||
|
||||
class Chunker(base_service.BaseService):
|
||||
"""
|
||||
A class for splitting long texts into smaller, overlapping chunks.
|
||||
"""
|
||||
|
||||
def __init__(self, ap: app.Application, chunk_size: int = 500, chunk_overlap: int = 50):
|
||||
self.ap = ap
|
||||
self.chunk_size = chunk_size
|
||||
self.chunk_overlap = chunk_overlap
|
||||
if self.chunk_overlap >= self.chunk_size:
|
||||
self.ap.logger.warning(
|
||||
'Chunk overlap is greater than or equal to chunk size. This may lead to empty or malformed chunks.'
|
||||
)
|
||||
|
||||
def _split_text_sync(self, text: str) -> List[str]:
|
||||
"""
|
||||
Synchronously splits a long text into chunks with specified overlap.
|
||||
This is a CPU-bound operation, intended to be run in a separate thread.
|
||||
"""
|
||||
if not text:
|
||||
return []
|
||||
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=self.chunk_size,
|
||||
chunk_overlap=self.chunk_overlap,
|
||||
length_function=len,
|
||||
is_separator_regex=False,
|
||||
)
|
||||
return text_splitter.split_text(text)
|
||||
|
||||
async def chunk(self, text: str) -> List[str]:
|
||||
"""
|
||||
Asynchronously chunks a given text into smaller pieces.
|
||||
"""
|
||||
self.ap.logger.info(f'Chunking text (length: {len(text)})...')
|
||||
# Run the synchronous splitting logic in a separate thread
|
||||
chunks = await self._run_sync(self._split_text_sync, text)
|
||||
self.ap.logger.info(f'Text chunked into {len(chunks)} pieces.')
|
||||
self.ap.logger.debug(f'Chunks: {json.dumps(chunks, indent=4, ensure_ascii=False)}')
|
||||
return chunks
|
||||
55
src/langbot/pkg/rag/knowledge/services/embedder.py
Normal file
55
src/langbot/pkg/rag/knowledge/services/embedder.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from __future__ import annotations
|
||||
import uuid
|
||||
from typing import List
|
||||
from langbot.pkg.rag.knowledge.services.base_service import BaseService
|
||||
from langbot.pkg.entity.persistence import rag as persistence_rag
|
||||
from langbot.pkg.core import app
|
||||
from langbot.pkg.provider.modelmgr.requester import RuntimeEmbeddingModel
|
||||
import sqlalchemy
|
||||
|
||||
|
||||
class Embedder(BaseService):
|
||||
def __init__(self, ap: app.Application) -> None:
|
||||
super().__init__()
|
||||
self.ap = ap
|
||||
|
||||
async def embed_and_store(
|
||||
self, kb_id: str, file_id: str, chunks: List[str], embedding_model: RuntimeEmbeddingModel
|
||||
) -> list[persistence_rag.Chunk]:
|
||||
# save chunk to db
|
||||
chunk_entities: list[persistence_rag.Chunk] = []
|
||||
chunk_ids: list[str] = []
|
||||
|
||||
for chunk_text in chunks:
|
||||
chunk_uuid = str(uuid.uuid4())
|
||||
chunk_ids.append(chunk_uuid)
|
||||
chunk_entity = persistence_rag.Chunk(uuid=chunk_uuid, file_id=file_id, text=chunk_text)
|
||||
chunk_entities.append(chunk_entity)
|
||||
|
||||
chunk_dicts = [
|
||||
self.ap.persistence_mgr.serialize_model(persistence_rag.Chunk, chunk) for chunk in chunk_entities
|
||||
]
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.Chunk).values(chunk_dicts))
|
||||
|
||||
# get embeddings (batch size limit: 64 for OpenAI)
|
||||
MAX_BATCH_SIZE = 64
|
||||
embeddings_list: list[list[float]] = []
|
||||
|
||||
for i in range(0, len(chunks), MAX_BATCH_SIZE):
|
||||
batch = chunks[i : i + MAX_BATCH_SIZE]
|
||||
batch_embeddings = await embedding_model.provider.invoke_embedding(
|
||||
model=embedding_model,
|
||||
input_text=batch,
|
||||
extra_args={}, # TODO: add extra args
|
||||
knowledge_base_id=kb_id,
|
||||
call_type='embedding',
|
||||
)
|
||||
embeddings_list.extend(batch_embeddings)
|
||||
|
||||
# save embeddings to vdb
|
||||
await self.ap.vector_db_mgr.vector_db.add_embeddings(kb_id, chunk_ids, embeddings_list, chunk_dicts)
|
||||
|
||||
self.ap.logger.info(f'Successfully saved {len(chunk_entities)} embeddings to Knowledge Base.')
|
||||
|
||||
return chunk_entities
|
||||
291
src/langbot/pkg/rag/knowledge/services/parser.py
Normal file
291
src/langbot/pkg/rag/knowledge/services/parser.py
Normal file
@@ -0,0 +1,291 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import PyPDF2
|
||||
import io
|
||||
from docx import Document
|
||||
import chardet
|
||||
from typing import Union, Callable, Any
|
||||
import markdown
|
||||
from bs4 import BeautifulSoup
|
||||
import re
|
||||
import asyncio # Import asyncio for async operations
|
||||
from langbot.pkg.core import app
|
||||
|
||||
|
||||
class FileParser:
|
||||
"""
|
||||
A robust file parser class to extract text content from various document formats.
|
||||
It supports TXT, PDF, DOCX, XLSX, CSV, Markdown, HTML, and EPUB files.
|
||||
All core file reading operations are designed to be run synchronously in a thread pool
|
||||
to avoid blocking the asyncio event loop.
|
||||
"""
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def _run_sync(self, sync_func: Callable, *args: Any, **kwargs: Any) -> Any:
|
||||
"""
|
||||
Runs a synchronous function in a separate thread to prevent blocking the event loop.
|
||||
This is a general utility method for wrapping blocking I/O operations.
|
||||
"""
|
||||
try:
|
||||
return await asyncio.to_thread(sync_func, *args, **kwargs)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Error running synchronous function {sync_func.__name__}: {e}')
|
||||
raise
|
||||
|
||||
async def parse(self, file_name: str, extension: str) -> Union[str, None]:
|
||||
"""
|
||||
Parses the file based on its extension and returns the extracted text content.
|
||||
This is the main asynchronous entry point for parsing.
|
||||
|
||||
Args:
|
||||
file_name (str): The name of the file to be parsed, get from ap.storage_mgr
|
||||
|
||||
Returns:
|
||||
Union[str, None]: The extracted text content as a single string, or None if parsing fails.
|
||||
"""
|
||||
|
||||
file_extension = extension.lower()
|
||||
parser_method = getattr(self, f'_parse_{file_extension}', None)
|
||||
|
||||
if parser_method is None:
|
||||
self.ap.logger.error(f'Unsupported file format: {file_extension} for file {file_name}')
|
||||
return None
|
||||
|
||||
try:
|
||||
# Pass file_path to the specific parser methods
|
||||
return await parser_method(file_name)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Failed to parse {file_extension} file {file_name}: {e}')
|
||||
return None
|
||||
|
||||
# --- Helper for reading files with encoding detection ---
|
||||
async def _read_file_content(self, file_name: str) -> Union[str, bytes]:
|
||||
"""
|
||||
Reads a file with automatic encoding detection, ensuring the synchronous
|
||||
file read operation runs in a separate thread.
|
||||
"""
|
||||
|
||||
# def _read_sync():
|
||||
# with open(file_path, 'rb') as file:
|
||||
# raw_data = file.read()
|
||||
# detected = chardet.detect(raw_data)
|
||||
# encoding = detected['encoding'] or 'utf-8'
|
||||
|
||||
# if mode == 'r':
|
||||
# return raw_data.decode(encoding, errors='ignore')
|
||||
# return raw_data # For binary mode
|
||||
|
||||
# return await self._run_sync(_read_sync)
|
||||
file_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)
|
||||
|
||||
detected = chardet.detect(file_bytes)
|
||||
encoding = detected['encoding'] or 'utf-8'
|
||||
|
||||
return file_bytes.decode(encoding, errors='ignore')
|
||||
|
||||
# --- Specific Parser Methods ---
|
||||
|
||||
async def _parse_txt(self, file_name: str) -> str:
|
||||
"""Parses a TXT file and returns its content."""
|
||||
self.ap.logger.info(f'Parsing TXT file: {file_name}')
|
||||
return await self._read_file_content(file_name)
|
||||
|
||||
async def _parse_pdf(self, file_name: str) -> str:
|
||||
"""Parses a PDF file and returns its text content."""
|
||||
self.ap.logger.info(f'Parsing PDF file: {file_name}')
|
||||
|
||||
# def _parse_pdf_sync():
|
||||
# text_content = []
|
||||
# with open(file_name, 'rb') as file:
|
||||
# pdf_reader = PyPDF2.PdfReader(file)
|
||||
# for page in pdf_reader.pages:
|
||||
# text = page.extract_text()
|
||||
# if text:
|
||||
# text_content.append(text)
|
||||
# return '\n'.join(text_content)
|
||||
|
||||
# return await self._run_sync(_parse_pdf_sync)
|
||||
|
||||
pdf_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)
|
||||
|
||||
def _parse_pdf_sync():
|
||||
pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
|
||||
text_content = []
|
||||
for page in pdf_reader.pages:
|
||||
text = page.extract_text()
|
||||
if text:
|
||||
text_content.append(text)
|
||||
return '\n'.join(text_content)
|
||||
|
||||
return await self._run_sync(_parse_pdf_sync)
|
||||
|
||||
async def _parse_docx(self, file_name: str) -> str:
|
||||
"""Parses a DOCX file and returns its text content."""
|
||||
self.ap.logger.info(f'Parsing DOCX file: {file_name}')
|
||||
|
||||
docx_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)
|
||||
|
||||
def _parse_docx_sync():
|
||||
doc = Document(io.BytesIO(docx_bytes))
|
||||
text_content = [paragraph.text for paragraph in doc.paragraphs if paragraph.text.strip()]
|
||||
return '\n'.join(text_content)
|
||||
|
||||
return await self._run_sync(_parse_docx_sync)
|
||||
|
||||
async def _parse_doc(self, file_name: str) -> str:
|
||||
"""Handles .doc files, explicitly stating lack of direct support."""
|
||||
self.ap.logger.warning(f'Direct .doc parsing is not supported for {file_name}. Please convert to .docx first.')
|
||||
raise NotImplementedError('Direct .doc parsing not supported. Please convert to .docx first.')
|
||||
|
||||
# async def _parse_xlsx(self, file_name: str) -> str:
|
||||
# """Parses an XLSX file, returning text from all sheets."""
|
||||
# self.ap.logger.info(f'Parsing XLSX file: {file_name}')
|
||||
|
||||
# xlsx_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)
|
||||
|
||||
# def _parse_xlsx_sync():
|
||||
# excel_file = pd.ExcelFile(io.BytesIO(xlsx_bytes))
|
||||
# all_sheet_content = []
|
||||
# for sheet_name in excel_file.sheet_names:
|
||||
# df = pd.read_excel(io.BytesIO(xlsx_bytes), sheet_name=sheet_name)
|
||||
# sheet_text = f'--- Sheet: {sheet_name} ---\n{df.to_string(index=False)}\n'
|
||||
# all_sheet_content.append(sheet_text)
|
||||
# return '\n'.join(all_sheet_content)
|
||||
|
||||
# return await self._run_sync(_parse_xlsx_sync)
|
||||
|
||||
# async def _parse_csv(self, file_name: str) -> str:
|
||||
# """Parses a CSV file and returns its content as a string."""
|
||||
# self.ap.logger.info(f'Parsing CSV file: {file_name}')
|
||||
|
||||
# csv_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)
|
||||
|
||||
# def _parse_csv_sync():
|
||||
# # pd.read_csv can often detect encoding, but explicit detection is safer
|
||||
# # raw_data = self._read_file_content(
|
||||
# # file_name, mode='rb'
|
||||
# # ) # Note: this will need to be await outside this sync function
|
||||
# # _ = raw_data
|
||||
# # For simplicity, we'll let pandas handle encoding internally after a raw read.
|
||||
# # A more robust solution might pass encoding directly to pd.read_csv after detection.
|
||||
# detected = chardet.detect(io.BytesIO(csv_bytes))
|
||||
# encoding = detected['encoding'] or 'utf-8'
|
||||
# df = pd.read_csv(io.BytesIO(csv_bytes), encoding=encoding)
|
||||
# return df.to_string(index=False)
|
||||
|
||||
# return await self._run_sync(_parse_csv_sync)
|
||||
|
||||
async def _parse_md(self, file_name: str) -> str:
|
||||
"""Parses a Markdown file, converting it to structured plain text."""
|
||||
self.ap.logger.info(f'Parsing Markdown file: {file_name}')
|
||||
|
||||
md_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)
|
||||
|
||||
def _parse_markdown_sync():
|
||||
md_content = io.BytesIO(md_bytes).read().decode('utf-8', errors='ignore')
|
||||
html_content = markdown.markdown(
|
||||
md_content, extensions=['extra', 'codehilite', 'tables', 'toc', 'fenced_code']
|
||||
)
|
||||
soup = BeautifulSoup(html_content, 'html.parser')
|
||||
text_parts = []
|
||||
for element in soup.children:
|
||||
if element.name in ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']:
|
||||
level = int(element.name[1])
|
||||
text_parts.append('#' * level + ' ' + element.get_text().strip())
|
||||
elif element.name == 'p':
|
||||
text = element.get_text().strip()
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
elif element.name in ['ul', 'ol']:
|
||||
for li in element.find_all('li'):
|
||||
text_parts.append(f'* {li.get_text().strip()}')
|
||||
elif element.name == 'pre':
|
||||
code_block = element.get_text().strip()
|
||||
if code_block:
|
||||
text_parts.append(f'```\n{code_block}\n```')
|
||||
elif element.name == 'table':
|
||||
table_str = self._extract_table_to_markdown_sync(element) # Call sync helper
|
||||
if table_str:
|
||||
text_parts.append(table_str)
|
||||
elif element.name:
|
||||
text = element.get_text(separator=' ', strip=True)
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
cleaned_text = re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_parts))
|
||||
return cleaned_text.strip()
|
||||
|
||||
return await self._run_sync(_parse_markdown_sync)
|
||||
|
||||
async def _parse_html(self, file_name: str) -> str:
|
||||
"""Parses an HTML file, extracting structured plain text."""
|
||||
self.ap.logger.info(f'Parsing HTML file: {file_name}')
|
||||
|
||||
html_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)
|
||||
|
||||
def _parse_html_sync():
|
||||
html_content = io.BytesIO(html_bytes).read().decode('utf-8', errors='ignore')
|
||||
soup = BeautifulSoup(html_content, 'html.parser')
|
||||
for script_or_style in soup(['script', 'style']):
|
||||
script_or_style.decompose()
|
||||
text_parts = []
|
||||
for element in soup.body.children if soup.body else soup.children:
|
||||
if element.name in ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']:
|
||||
level = int(element.name[1])
|
||||
text_parts.append('#' * level + ' ' + element.get_text().strip())
|
||||
elif element.name == 'p':
|
||||
text = element.get_text().strip()
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
elif element.name in ['ul', 'ol']:
|
||||
for li in element.find_all('li'):
|
||||
text = li.get_text().strip()
|
||||
if text:
|
||||
text_parts.append(f'* {text}')
|
||||
elif element.name == 'table':
|
||||
table_str = self._extract_table_to_markdown_sync(element) # Call sync helper
|
||||
if table_str:
|
||||
text_parts.append(table_str)
|
||||
elif element.name:
|
||||
text = element.get_text(separator=' ', strip=True)
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
cleaned_text = re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_parts))
|
||||
return cleaned_text.strip()
|
||||
|
||||
return await self._run_sync(_parse_html_sync)
|
||||
|
||||
def _add_toc_items_sync(self, toc_list: list, text_content: list, level: int):
|
||||
"""Recursively adds TOC items to text_content (synchronous helper)."""
|
||||
indent = ' ' * level
|
||||
for item in toc_list:
|
||||
if isinstance(item, tuple):
|
||||
chapter, subchapters = item
|
||||
text_content.append(f'{indent}- {chapter.title}')
|
||||
self._add_toc_items_sync(subchapters, text_content, level + 1)
|
||||
else:
|
||||
text_content.append(f'{indent}- {item.title}')
|
||||
|
||||
def _extract_table_to_markdown_sync(self, table_element: BeautifulSoup) -> str:
|
||||
"""Helper to convert a BeautifulSoup table element into a Markdown table string (synchronous)."""
|
||||
headers = [th.get_text().strip() for th in table_element.find_all('th')]
|
||||
rows = []
|
||||
for tr in table_element.find_all('tr'):
|
||||
cells = [td.get_text().strip() for td in tr.find_all('td')]
|
||||
if cells:
|
||||
rows.append(cells)
|
||||
|
||||
if not headers and not rows:
|
||||
return ''
|
||||
|
||||
table_lines = []
|
||||
if headers:
|
||||
table_lines.append(' | '.join(headers))
|
||||
table_lines.append(' | '.join(['---'] * len(headers)))
|
||||
|
||||
for row_cells in rows:
|
||||
padded_cells = row_cells + [''] * (len(headers) - len(row_cells)) if headers else row_cells
|
||||
table_lines.append(' | '.join(padded_cells))
|
||||
|
||||
return '\n'.join(table_lines)
|
||||
53
src/langbot/pkg/rag/knowledge/services/retriever.py
Normal file
53
src/langbot/pkg/rag/knowledge/services/retriever.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from . import base_service
|
||||
from ....core import app
|
||||
from ....provider.modelmgr.requester import RuntimeEmbeddingModel
|
||||
from langbot_plugin.api.entities.builtin.rag import context as rag_context
|
||||
from langbot_plugin.api.entities.builtin.provider.message import ContentElement
|
||||
|
||||
|
||||
class Retriever(base_service.BaseService):
|
||||
def __init__(self, ap: app.Application):
|
||||
super().__init__()
|
||||
self.ap = ap
|
||||
|
||||
async def retrieve(
|
||||
self, kb_id: str, query: str, embedding_model: RuntimeEmbeddingModel, k: int = 5
|
||||
) -> list[rag_context.RetrievalResultEntry]:
|
||||
self.ap.logger.info(
|
||||
f"Retrieving for query: '{query[:10]}' with k={k} using {embedding_model.model_entity.uuid}"
|
||||
)
|
||||
|
||||
query_embedding: list[float] = await embedding_model.provider.invoke_embedding(
|
||||
model=embedding_model,
|
||||
input_text=[query],
|
||||
extra_args={}, # TODO: add extra args
|
||||
knowledge_base_id=kb_id,
|
||||
query_text=query,
|
||||
call_type='retrieve',
|
||||
)
|
||||
|
||||
vector_results = await self.ap.vector_db_mgr.vector_db.search(kb_id, query_embedding[0], k)
|
||||
|
||||
# 'ids' shape mirrors the Chroma-style response contract for compatibility
|
||||
matched_vector_ids = vector_results.get('ids', [[]])[0]
|
||||
distances = vector_results.get('distances', [[]])[0]
|
||||
vector_metadatas = vector_results.get('metadatas', [[]])[0]
|
||||
|
||||
if not matched_vector_ids:
|
||||
self.ap.logger.info('No relevant chunks found in vector database.')
|
||||
return []
|
||||
|
||||
result: list[rag_context.RetrievalResultEntry] = []
|
||||
|
||||
for i, id in enumerate(matched_vector_ids):
|
||||
entry = rag_context.RetrievalResultEntry(
|
||||
id=id,
|
||||
content=[ContentElement.from_text(vector_metadatas[i].get('text', ''))],
|
||||
metadata=vector_metadatas[i],
|
||||
distance=distances[i],
|
||||
)
|
||||
result.append(entry)
|
||||
|
||||
return result
|
||||
@@ -1 +0,0 @@
|
||||
from .runtime import RAGRuntimeService as RAGRuntimeService
|
||||
@@ -1,89 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import posixpath
|
||||
from typing import Any
|
||||
from langbot.pkg.core import app
|
||||
|
||||
|
||||
class RAGRuntimeService:
|
||||
"""Service to handle RAG-related requests from plugins (Runtime).
|
||||
|
||||
This service acts as the bridge between plugin RPC requests and
|
||||
LangBot's infrastructure (embedding models, vector databases, file storage).
|
||||
"""
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def vector_upsert(
|
||||
self,
|
||||
collection_id: str,
|
||||
vectors: list[list[float]],
|
||||
ids: list[str],
|
||||
metadata: list[dict[str, Any]] | None = None,
|
||||
documents: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Handle VECTOR_UPSERT action."""
|
||||
metadatas = metadata if metadata else [{} for _ in vectors]
|
||||
await self.ap.vector_db_mgr.upsert(
|
||||
collection_name=collection_id,
|
||||
vectors=vectors,
|
||||
ids=ids,
|
||||
metadata=metadatas,
|
||||
documents=documents,
|
||||
)
|
||||
|
||||
async def vector_search(
|
||||
self,
|
||||
collection_id: str,
|
||||
query_vector: list[float],
|
||||
top_k: int,
|
||||
filters: dict[str, Any] | None = None,
|
||||
search_type: str = 'vector',
|
||||
query_text: str = '',
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Handle VECTOR_SEARCH action."""
|
||||
return await self.ap.vector_db_mgr.search(
|
||||
collection_name=collection_id,
|
||||
query_vector=query_vector,
|
||||
limit=top_k,
|
||||
filter=filters,
|
||||
search_type=search_type,
|
||||
query_text=query_text,
|
||||
)
|
||||
|
||||
async def vector_delete(
|
||||
self, collection_id: str, file_ids: list[str] | None = None, filters: dict[str, Any] | None = None
|
||||
) -> int:
|
||||
"""Handle VECTOR_DELETE action.
|
||||
|
||||
Deletes vectors associated with the given file IDs from the collection.
|
||||
Each file_id corresponds to a document whose vectors will be removed.
|
||||
|
||||
Args:
|
||||
collection_id: The collection to delete from.
|
||||
file_ids: File IDs whose associated vectors should be deleted.
|
||||
Each file_id maps to a set of vectors stored with that file_id
|
||||
in their metadata.
|
||||
filters: Filter-based deletion (not yet supported, will raise).
|
||||
"""
|
||||
count = 0
|
||||
if file_ids:
|
||||
await self.ap.vector_db_mgr.delete_by_file_id(collection_name=collection_id, file_ids=file_ids)
|
||||
count = len(file_ids)
|
||||
elif filters:
|
||||
count = await self.ap.vector_db_mgr.delete_by_filter(collection_name=collection_id, filter=filters)
|
||||
return count
|
||||
|
||||
async def get_file_stream(self, storage_path: str) -> bytes:
|
||||
"""Handle GET_KNOWLEDEGE_FILE_STREAM action.
|
||||
|
||||
Uses the storage manager abstraction to load file content,
|
||||
regardless of the underlying storage provider.
|
||||
"""
|
||||
# Validate storage_path to prevent path traversal
|
||||
normalized = posixpath.normpath(storage_path)
|
||||
if normalized.startswith('/') or '..' in normalized.split('/'):
|
||||
raise ValueError('Invalid storage path')
|
||||
content_bytes = await self.ap.storage_mgr.storage_provider.load(normalized)
|
||||
return content_bytes if content_bytes else b''
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
|
||||
from ..core import app
|
||||
from . import provider
|
||||
from .providers import localstorage
|
||||
from .providers import localstorage, s3storage
|
||||
|
||||
|
||||
class StorageMgr:
|
||||
@@ -21,8 +21,6 @@ class StorageMgr:
|
||||
storage_type = storage_config.get('use', 'local')
|
||||
|
||||
if storage_type == 's3':
|
||||
from .providers import s3storage
|
||||
|
||||
self.storage_provider = s3storage.S3StorageProvider(self.ap)
|
||||
self.ap.logger.info('Initialized S3 storage backend.')
|
||||
else:
|
||||
|
||||
@@ -43,13 +43,6 @@ class StorageProvider(abc.ABC):
|
||||
):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def size(
|
||||
self,
|
||||
key: str,
|
||||
) -> int:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_dir_recursive(
|
||||
self,
|
||||
|
||||
@@ -47,12 +47,6 @@ class LocalStorageProvider(provider.StorageProvider):
|
||||
):
|
||||
os.remove(os.path.join(LOCAL_STORAGE_PATH, f'{key}'))
|
||||
|
||||
async def size(
|
||||
self,
|
||||
key: str,
|
||||
) -> int:
|
||||
return os.path.getsize(os.path.join(LOCAL_STORAGE_PATH, f'{key}'))
|
||||
|
||||
async def delete_dir_recursive(
|
||||
self,
|
||||
dir_path: str,
|
||||
|
||||
@@ -117,21 +117,6 @@ class S3StorageProvider(provider.StorageProvider):
|
||||
self.ap.logger.error(f'Failed to delete from S3: {e}')
|
||||
raise
|
||||
|
||||
async def size(
|
||||
self,
|
||||
key: str,
|
||||
) -> int:
|
||||
"""Get object size from S3 without downloading it"""
|
||||
try:
|
||||
response = self.s3_client.head_object(
|
||||
Bucket=self.bucket_name,
|
||||
Key=key,
|
||||
)
|
||||
return response['ContentLength']
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Failed to get size from S3: {e}')
|
||||
raise
|
||||
|
||||
async def delete_dir_recursive(
|
||||
self,
|
||||
dir_path: str,
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""Survey module for in-product surveys triggered by events."""
|
||||
@@ -1,148 +0,0 @@
|
||||
"""Survey manager: tracks events, communicates with Space to fetch/submit surveys."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import typing
|
||||
import httpx
|
||||
import sqlalchemy
|
||||
|
||||
from ..core import app as core_app
|
||||
from ..entity.persistence.metadata import Metadata
|
||||
from ..utils import constants
|
||||
|
||||
SURVEY_TRIGGERED_KEY = 'survey_triggered_events'
|
||||
|
||||
|
||||
class SurveyManager:
|
||||
"""Manages survey lifecycle: event tracking, pending survey fetch, submission."""
|
||||
|
||||
def __init__(self, ap: core_app.Application):
|
||||
self.ap = ap
|
||||
self._triggered_events: set[str] = set()
|
||||
self._pending_survey: typing.Optional[dict] = None
|
||||
self._space_url: str = ''
|
||||
|
||||
async def initialize(self):
|
||||
space_config = self.ap.instance_config.data.get('space', {})
|
||||
self._space_url = space_config.get('url', '').rstrip('/')
|
||||
await self._load_triggered_events()
|
||||
|
||||
async def _load_triggered_events(self):
|
||||
"""Load previously triggered events from metadata table."""
|
||||
try:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(Metadata).where(Metadata.key == SURVEY_TRIGGERED_KEY)
|
||||
)
|
||||
row = result.first()
|
||||
if row:
|
||||
self._triggered_events = set(json.loads(row[0].value))
|
||||
except Exception:
|
||||
self._triggered_events = set()
|
||||
|
||||
async def _save_triggered_events(self):
|
||||
"""Persist triggered events to metadata table."""
|
||||
try:
|
||||
value = json.dumps(list(self._triggered_events))
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(Metadata).where(Metadata.key == SURVEY_TRIGGERED_KEY)
|
||||
)
|
||||
if result.first():
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(Metadata).where(Metadata.key == SURVEY_TRIGGERED_KEY).values(value=value)
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(Metadata).values(key=SURVEY_TRIGGERED_KEY, value=value)
|
||||
)
|
||||
except Exception as e:
|
||||
self.ap.logger.debug(f'Failed to save survey triggered events: {e}')
|
||||
|
||||
def _is_space_configured(self) -> bool:
|
||||
space_config = self.ap.instance_config.data.get('space', {})
|
||||
if space_config.get('disable_telemetry', False):
|
||||
return False
|
||||
return bool(self._space_url)
|
||||
|
||||
async def trigger_event(self, event: str):
|
||||
"""Called when an event occurs. Checks Space for a pending survey."""
|
||||
if event in self._triggered_events:
|
||||
return
|
||||
if not self._is_space_configured():
|
||||
return
|
||||
|
||||
self._triggered_events.add(event)
|
||||
await self._save_triggered_events()
|
||||
|
||||
# Check for pending survey asynchronously
|
||||
asyncio.create_task(self._fetch_pending_survey(event))
|
||||
|
||||
async def _fetch_pending_survey(self, event: str):
|
||||
"""Fetch pending survey from Space for this event."""
|
||||
try:
|
||||
url = f'{self._space_url}/api/v1/survey/pending'
|
||||
payload = {
|
||||
'instance_id': constants.instance_id,
|
||||
'event': event,
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(10)) as client:
|
||||
resp = await client.post(url, json=payload)
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
if data.get('code') == 0 and data.get('data', {}).get('survey'):
|
||||
self._pending_survey = data['data']['survey']
|
||||
self.ap.logger.info(f'Survey pending: {self._pending_survey.get("survey_id")}')
|
||||
except Exception as e:
|
||||
self.ap.logger.debug(f'Failed to fetch pending survey: {e}')
|
||||
|
||||
def get_pending_survey(self) -> typing.Optional[dict]:
|
||||
"""Return the current pending survey (if any) for the frontend to display."""
|
||||
return self._pending_survey
|
||||
|
||||
def clear_pending_survey(self):
|
||||
"""Clear the pending survey (after user responds or dismisses)."""
|
||||
self._pending_survey = None
|
||||
|
||||
async def submit_response(self, survey_id: str, answers: dict, completed: bool = True) -> bool:
|
||||
"""Submit a survey response to Space."""
|
||||
if not self._is_space_configured():
|
||||
return False
|
||||
try:
|
||||
url = f'{self._space_url}/api/v1/survey/respond'
|
||||
payload = {
|
||||
'survey_id': survey_id,
|
||||
'instance_id': constants.instance_id,
|
||||
'answers': answers,
|
||||
'metadata': {
|
||||
'version': constants.semantic_version,
|
||||
},
|
||||
'completed': completed,
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(10)) as client:
|
||||
resp = await client.post(url, json=payload)
|
||||
if resp.status_code == 200:
|
||||
self.clear_pending_survey()
|
||||
return True
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(f'Failed to submit survey response: {e}')
|
||||
return False
|
||||
|
||||
async def dismiss_survey(self, survey_id: str) -> bool:
|
||||
"""Dismiss a survey."""
|
||||
if not self._is_space_configured():
|
||||
return False
|
||||
try:
|
||||
url = f'{self._space_url}/api/v1/survey/dismiss'
|
||||
payload = {
|
||||
'survey_id': survey_id,
|
||||
'instance_id': constants.instance_id,
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(10)) as client:
|
||||
resp = await client.post(url, json=payload)
|
||||
if resp.status_code == 200:
|
||||
self.clear_pending_survey()
|
||||
return True
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(f'Failed to dismiss survey: {e}')
|
||||
return False
|
||||
@@ -60,7 +60,7 @@ class TelemetryManager:
|
||||
except Exception:
|
||||
sanitized['query_id'] = str(sanitized.get('query_id', ''))
|
||||
|
||||
for sfield in ('adapter', 'runner', 'runner_category', 'model_name', 'version', 'error', 'timestamp'):
|
||||
for sfield in ('adapter', 'runner', 'model_name', 'version', 'error', 'timestamp'):
|
||||
v = sanitized.get(sfield)
|
||||
sanitized[sfield] = '' if v is None else str(v)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import langbot
|
||||
|
||||
semantic_version = f'v{langbot.__version__}'
|
||||
|
||||
required_database_version = 24
|
||||
required_database_version = 18
|
||||
"""Tag the version of the database schema, used to check if the database needs to be migrated"""
|
||||
|
||||
debug_mode = False
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
"""Shared aiohttp.ClientSession to avoid repeated SSL context creation.
|
||||
|
||||
Each call to `aiohttp.ClientSession()` creates a new `TCPConnector` which in turn
|
||||
creates a new `ssl.SSLContext` and loads all system root certificates. This is
|
||||
extremely expensive in both CPU and memory (~270MB total allocations observed via
|
||||
memray profiling).
|
||||
|
||||
This module provides a shared session pool so that all HTTP client code in LangBot
|
||||
reuses the same underlying SSL context and connection pool.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import aiohttp
|
||||
|
||||
_sessions: dict[str, aiohttp.ClientSession] = {}
|
||||
|
||||
|
||||
def get_session(*, trust_env: bool = False) -> aiohttp.ClientSession:
|
||||
"""Get or create a shared aiohttp.ClientSession.
|
||||
|
||||
Args:
|
||||
trust_env: Whether to trust environment variables for proxy settings.
|
||||
|
||||
Returns:
|
||||
A shared aiohttp.ClientSession instance.
|
||||
"""
|
||||
key = f'trust_env={trust_env}'
|
||||
|
||||
session = _sessions.get(key)
|
||||
if session is None or session.closed:
|
||||
session = aiohttp.ClientSession(trust_env=trust_env)
|
||||
_sessions[key] = session
|
||||
|
||||
return session
|
||||
|
||||
|
||||
async def close_all():
|
||||
"""Close all shared sessions. Call on application shutdown."""
|
||||
for session in _sessions.values():
|
||||
if not session.closed:
|
||||
await session.close()
|
||||
_sessions.clear()
|
||||
@@ -5,8 +5,6 @@ from urllib.parse import urlparse, parse_qs
|
||||
import ssl
|
||||
|
||||
import aiohttp
|
||||
|
||||
from langbot.pkg.utils import httpclient
|
||||
import PIL.Image
|
||||
import httpx
|
||||
|
||||
@@ -49,54 +47,53 @@ async def get_gewechat_image_base64(
|
||||
)
|
||||
|
||||
try:
|
||||
session = httpclient.get_session()
|
||||
# 获取图片下载链接
|
||||
try:
|
||||
async with session.post(
|
||||
f'{gewechat_url}/v2/api/message/downloadImage',
|
||||
headers=headers,
|
||||
json={'appId': app_id, 'type': image_type, 'xml': xml_content},
|
||||
timeout=timeout,
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
# print(response)
|
||||
raise Exception(f'获取gewechat图片下载失败: {await response.text()}')
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
# 获取图片下载链接
|
||||
try:
|
||||
async with session.post(
|
||||
f'{gewechat_url}/v2/api/message/downloadImage',
|
||||
headers=headers,
|
||||
json={'appId': app_id, 'type': image_type, 'xml': xml_content},
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
# print(response)
|
||||
raise Exception(f'获取gewechat图片下载失败: {await response.text()}')
|
||||
|
||||
resp_data = await response.json()
|
||||
if resp_data.get('ret') != 200:
|
||||
raise Exception(f'获取gewechat图片下载链接失败: {resp_data}')
|
||||
resp_data = await response.json()
|
||||
if resp_data.get('ret') != 200:
|
||||
raise Exception(f'获取gewechat图片下载链接失败: {resp_data}')
|
||||
|
||||
file_url = resp_data['data']['fileUrl']
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception('获取图片下载链接超时')
|
||||
except aiohttp.ClientError as e:
|
||||
raise Exception(f'获取图片下载链接网络错误: {str(e)}')
|
||||
file_url = resp_data['data']['fileUrl']
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception('获取图片下载链接超时')
|
||||
except aiohttp.ClientError as e:
|
||||
raise Exception(f'获取图片下载链接网络错误: {str(e)}')
|
||||
|
||||
# 解析原始URL并替换端口
|
||||
base_url = gewechat_file_url
|
||||
download_url = f'{base_url}/download/{file_url}'
|
||||
# 解析原始URL并替换端口
|
||||
base_url = gewechat_file_url
|
||||
download_url = f'{base_url}/download/{file_url}'
|
||||
|
||||
# 下载图片
|
||||
try:
|
||||
async with session.get(download_url) as img_response:
|
||||
if img_response.status != 200:
|
||||
raise Exception(f'下载图片失败: {await img_response.text()}, URL: {download_url}')
|
||||
# 下载图片
|
||||
try:
|
||||
async with session.get(download_url) as img_response:
|
||||
if img_response.status != 200:
|
||||
raise Exception(f'下载图片失败: {await img_response.text()}, URL: {download_url}')
|
||||
|
||||
image_data = await img_response.read()
|
||||
image_data = await img_response.read()
|
||||
|
||||
content_type = img_response.headers.get('Content-Type', '')
|
||||
if content_type:
|
||||
image_format = content_type.split('/')[-1]
|
||||
else:
|
||||
image_format = file_url.split('.')[-1]
|
||||
content_type = img_response.headers.get('Content-Type', '')
|
||||
if content_type:
|
||||
image_format = content_type.split('/')[-1]
|
||||
else:
|
||||
image_format = file_url.split('.')[-1]
|
||||
|
||||
base64_str = base64.b64encode(image_data).decode('utf-8')
|
||||
base64_str = base64.b64encode(image_data).decode('utf-8')
|
||||
|
||||
return base64_str, image_format
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception(f'下载图片超时, URL: {download_url}')
|
||||
except aiohttp.ClientError as e:
|
||||
raise Exception(f'下载图片网络错误: {str(e)}, URL: {download_url}')
|
||||
return base64_str, image_format
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception(f'下载图片超时, URL: {download_url}')
|
||||
except aiohttp.ClientError as e:
|
||||
raise Exception(f'下载图片网络错误: {str(e)}, URL: {download_url}')
|
||||
except Exception as e:
|
||||
raise Exception(f'获取图片失败: {str(e)}') from e
|
||||
|
||||
@@ -107,24 +104,24 @@ async def get_wecom_image_base64(pic_url: str) -> tuple[str, str]:
|
||||
:param pic_url: 企业微信图片URL
|
||||
:return: (base64_str, image_format)
|
||||
"""
|
||||
session = httpclient.get_session()
|
||||
async with session.get(pic_url) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(f'Failed to download image: {response.status}')
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(pic_url) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(f'Failed to download image: {response.status}')
|
||||
|
||||
# 读取图片数据
|
||||
image_data = await response.read()
|
||||
# 读取图片数据
|
||||
image_data = await response.read()
|
||||
|
||||
# 获取图片格式
|
||||
content_type = response.headers.get('Content-Type', '')
|
||||
image_format = content_type.split('/')[-1] # 例如 'image/jpeg' -> 'jpeg'
|
||||
# 获取图片格式
|
||||
content_type = response.headers.get('Content-Type', '')
|
||||
image_format = content_type.split('/')[-1] # 例如 'image/jpeg' -> 'jpeg'
|
||||
|
||||
# 转换为 base64
|
||||
import base64
|
||||
# 转换为 base64
|
||||
import base64
|
||||
|
||||
image_base64 = base64.b64encode(image_data).decode('utf-8')
|
||||
image_base64 = base64.b64encode(image_data).decode('utf-8')
|
||||
|
||||
return image_base64, image_format
|
||||
return image_base64, image_format
|
||||
|
||||
|
||||
async def get_qq_official_image_base64(pic_url: str, content_type: str) -> tuple[str, str]:
|
||||
@@ -155,19 +152,21 @@ async def get_qq_image_bytes(image_url: str, query: dict = {}) -> tuple[bytes, s
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_context.check_hostname = False
|
||||
ssl_context.verify_mode = ssl.CERT_NONE
|
||||
session = httpclient.get_session()
|
||||
async with session.get(image_url, params=query, ssl=ssl_context, timeout=aiohttp.ClientTimeout(total=30.0)) as resp:
|
||||
resp.raise_for_status()
|
||||
file_bytes = await resp.read()
|
||||
content_type = resp.headers.get('Content-Type')
|
||||
if not content_type:
|
||||
image_format = 'jpeg'
|
||||
elif not content_type.startswith('image/'):
|
||||
pil_img = PIL.Image.open(io.BytesIO(file_bytes))
|
||||
image_format = pil_img.format.lower()
|
||||
else:
|
||||
image_format = content_type.split('/')[-1]
|
||||
return file_bytes, image_format
|
||||
async with aiohttp.ClientSession(trust_env=False) as session:
|
||||
async with session.get(
|
||||
image_url, params=query, ssl=ssl_context, timeout=aiohttp.ClientTimeout(total=30.0)
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
file_bytes = await resp.read()
|
||||
content_type = resp.headers.get('Content-Type')
|
||||
if not content_type:
|
||||
image_format = 'jpeg'
|
||||
elif not content_type.startswith('image/'):
|
||||
pil_img = PIL.Image.open(io.BytesIO(file_bytes))
|
||||
image_format = pil_img.format.lower()
|
||||
else:
|
||||
image_format = content_type.split('/')[-1]
|
||||
return file_bytes, image_format
|
||||
|
||||
|
||||
async def qq_image_url_to_base64(image_url: str) -> typing.Tuple[str, str]:
|
||||
@@ -205,11 +204,11 @@ async def extract_b64_and_format(image_base64_data: str) -> typing.Tuple[str, st
|
||||
async def get_slack_image_to_base64(pic_url: str, bot_token: str):
|
||||
headers = {'Authorization': f'Bearer {bot_token}'}
|
||||
try:
|
||||
session = httpclient.get_session()
|
||||
async with session.get(pic_url, headers=headers) as resp:
|
||||
mime_type = resp.headers.get('Content-Type', 'application/octet-stream')
|
||||
file_bytes = await resp.read()
|
||||
base64_str = base64.b64encode(file_bytes).decode('utf-8')
|
||||
return f'data:{mime_type};base64,{base64_str}'
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(pic_url, headers=headers) as resp:
|
||||
mime_type = resp.headers.get('Content-Type', 'application/octet-stream')
|
||||
file_bytes = await resp.read()
|
||||
base64_str = base64.b64encode(file_bytes).decode('utf-8')
|
||||
return f'data:{mime_type};base64,{base64_str}'
|
||||
except Exception as e:
|
||||
raise (e)
|
||||
|
||||
@@ -1,105 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
class RunnerCategory:
|
||||
LOCAL = 'local'
|
||||
CLOUD = 'cloud'
|
||||
UNKNOWN = 'unknown'
|
||||
|
||||
|
||||
CLOUD_DOMAINS = [
|
||||
'.n8n.cloud',
|
||||
'.n8n.io',
|
||||
'api.dify.ai',
|
||||
'cloud.dify.ai',
|
||||
'.coze.com',
|
||||
'.coze.cn',
|
||||
'cloud.langflow.ai',
|
||||
'.langflow.org',
|
||||
]
|
||||
|
||||
LOCAL_PATTERNS = [
|
||||
'localhost',
|
||||
'127.0.0.1',
|
||||
'0.0.0.0',
|
||||
'192.168.',
|
||||
'10.',
|
||||
'172.16.',
|
||||
'172.17.',
|
||||
'172.18.',
|
||||
'172.19.',
|
||||
'172.20.',
|
||||
'172.21.',
|
||||
'172.22.',
|
||||
'172.23.',
|
||||
'172.24.',
|
||||
'172.25.',
|
||||
'172.26.',
|
||||
'172.27.',
|
||||
'172.28.',
|
||||
'172.29.',
|
||||
'172.30.',
|
||||
'172.31.',
|
||||
]
|
||||
|
||||
|
||||
def get_runner_category(runner_name: str, runner_url: str) -> str:
|
||||
if not runner_url:
|
||||
return RunnerCategory.UNKNOWN
|
||||
|
||||
try:
|
||||
parsed_url = urlparse(runner_url)
|
||||
host = parsed_url.hostname.lower() if parsed_url.hostname else ''
|
||||
except Exception:
|
||||
return RunnerCategory.UNKNOWN
|
||||
|
||||
for pattern in LOCAL_PATTERNS:
|
||||
if host.startswith(pattern):
|
||||
return RunnerCategory.LOCAL
|
||||
|
||||
for domain in CLOUD_DOMAINS:
|
||||
if host.endswith(domain):
|
||||
return RunnerCategory.CLOUD
|
||||
|
||||
return RunnerCategory.CLOUD
|
||||
|
||||
|
||||
def get_runner_info(runner_name: str, runner_url: str) -> dict:
|
||||
return {
|
||||
'name': runner_name,
|
||||
'url': runner_url,
|
||||
'category': get_runner_category(runner_name, runner_url),
|
||||
}
|
||||
|
||||
|
||||
def is_cloud_runner(runner_name: str, runner_url: str) -> bool:
|
||||
return get_runner_category(runner_name, runner_url) == RunnerCategory.CLOUD
|
||||
|
||||
|
||||
def is_local_runner(runner_name: str, runner_url: str) -> bool:
|
||||
return get_runner_category(runner_name, runner_url) == RunnerCategory.LOCAL
|
||||
|
||||
|
||||
def extract_runner_url(runner_name: str, runner, pipeline_config: dict | None) -> str | None:
|
||||
if not runner or not hasattr(runner, 'pipeline_config'):
|
||||
return None
|
||||
|
||||
ai_config = pipeline_config.get('ai', {}) if pipeline_config else {}
|
||||
|
||||
if runner_name == 'dify-service-api':
|
||||
return ai_config.get('dify-service-api', {}).get('base-url')
|
||||
elif runner_name == 'n8n-service-api':
|
||||
return ai_config.get('n8n-service-api', {}).get('webhook-url')
|
||||
elif runner_name == 'coze-api':
|
||||
return ai_config.get('coze-api', {}).get('api-base')
|
||||
elif runner_name == 'langflow-api':
|
||||
return ai_config.get('langflow-api', {}).get('base-url')
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_runner_category_from_runner(runner_name: str, runner, pipeline_config: dict | None) -> str:
|
||||
runner_url = extract_runner_url(runner_name, runner, pipeline_config)
|
||||
return get_runner_category(runner_name, runner_url)
|
||||
@@ -1,69 +0,0 @@
|
||||
"""Shared utilities for metadata filter handling across VDB backends.
|
||||
|
||||
Canonical filter format (Chroma-style ``where`` syntax):
|
||||
|
||||
{"file_id": "abc"} # implicit $eq
|
||||
{"file_id": {"$eq": "abc"}} # explicit $eq
|
||||
{"created_at": {"$gte": 1700000000}} # comparison
|
||||
{"file_type": {"$in": ["pdf", "docx"]}} # in-list
|
||||
|
||||
Multiple top-level keys are AND-ed. Supported operators:
|
||||
``$eq``, ``$ne``, ``$gt``, ``$gte``, ``$lt``, ``$lte``, ``$in``, ``$nin``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
SUPPORTED_OPS = frozenset({'$eq', '$ne', '$gt', '$gte', '$lt', '$lte', '$in', '$nin'})
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def normalize_filter(
|
||||
raw: dict[str, Any] | None,
|
||||
) -> list[tuple[str, str, Any]]:
|
||||
"""Parse a canonical filter dict into ``[(field, op, value)]`` triples.
|
||||
|
||||
Returns an empty list when *raw* is ``None`` or empty.
|
||||
|
||||
Raises ``ValueError`` on unsupported operators or malformed entries.
|
||||
"""
|
||||
if not raw:
|
||||
return []
|
||||
|
||||
triples: list[tuple[str, str, Any]] = []
|
||||
for field, condition in raw.items():
|
||||
if isinstance(condition, dict):
|
||||
for op, value in condition.items():
|
||||
if op not in SUPPORTED_OPS:
|
||||
raise ValueError(f'Unsupported filter operator: {op}')
|
||||
triples.append((field, op, value))
|
||||
else:
|
||||
# Bare value -> implicit $eq
|
||||
triples.append((field, '$eq', condition))
|
||||
return triples
|
||||
|
||||
|
||||
def strip_unsupported_fields(
|
||||
triples: list[tuple[str, str, Any]],
|
||||
supported_fields: set[str],
|
||||
) -> list[tuple[str, str, Any]]:
|
||||
"""Return only triples whose field is in *supported_fields*.
|
||||
|
||||
Dropped fields are logged at WARNING level so the caller knows they were
|
||||
silently ignored (useful for Milvus / pgvector which only store a fixed
|
||||
schema).
|
||||
"""
|
||||
kept: list[tuple[str, str, Any]] = []
|
||||
for field, op, value in triples:
|
||||
if field in supported_fields:
|
||||
kept.append((field, op, value))
|
||||
else:
|
||||
logger.warning(
|
||||
'Filter field %r is not supported by this backend and will be ignored (supported: %s)',
|
||||
field,
|
||||
', '.join(sorted(supported_fields)),
|
||||
)
|
||||
return kept
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ..core import app
|
||||
from .vdb import VectorDatabase, SearchType
|
||||
from .vdb import VectorDatabase
|
||||
from .vdbs.chroma import ChromaVectorDatabase
|
||||
from .vdbs.qdrant import QdrantVectorDatabase
|
||||
from .vdbs.seekdb import SeekDBVectorDatabase
|
||||
@@ -65,95 +65,3 @@ class VectorDBManager:
|
||||
else:
|
||||
self.vector_db = ChromaVectorDatabase(self.ap)
|
||||
self.ap.logger.warning('No vector database backend configured, defaulting to Chroma.')
|
||||
|
||||
def get_supported_search_types(self) -> list[str]:
|
||||
"""Return the search types supported by the current VDB backend."""
|
||||
if self.vector_db is None:
|
||||
return [SearchType.VECTOR.value]
|
||||
return [st.value for st in self.vector_db.supported_search_types()]
|
||||
|
||||
async def upsert(
|
||||
self,
|
||||
collection_name: str,
|
||||
vectors: list[list[float]],
|
||||
ids: list[str],
|
||||
metadata: list[dict] | None = None,
|
||||
documents: list[str] | None = None,
|
||||
):
|
||||
"""Proxy: Upsert vectors"""
|
||||
await self.vector_db.add_embeddings(
|
||||
collection=collection_name,
|
||||
ids=ids,
|
||||
embeddings_list=vectors,
|
||||
metadatas=metadata or [{} for _ in vectors],
|
||||
documents=documents,
|
||||
)
|
||||
|
||||
async def search(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_vector: list[float],
|
||||
limit: int,
|
||||
filter: dict | None = None,
|
||||
search_type: str = 'vector',
|
||||
query_text: str = '',
|
||||
) -> list[dict]:
|
||||
"""Proxy: Search vectors.
|
||||
|
||||
Returns a list of dicts with keys: 'id', 'distance', 'metadata'.
|
||||
The underlying VectorDatabase.search returns Chroma-style format:
|
||||
{ 'ids': [['id1']], 'distances': [[0.1]], 'metadatas': [[{}]] }
|
||||
"""
|
||||
results = await self.vector_db.search(
|
||||
collection=collection_name,
|
||||
query_embedding=query_vector,
|
||||
k=limit,
|
||||
search_type=search_type,
|
||||
query_text=query_text,
|
||||
filter=filter,
|
||||
)
|
||||
|
||||
if not results or 'ids' not in results or not results['ids']:
|
||||
return []
|
||||
|
||||
# Flatten nested lists (Chroma returns batch-style: list of lists)
|
||||
raw_ids = results['ids']
|
||||
raw_dists = results.get('distances', [])
|
||||
raw_metas = results.get('metadatas', [])
|
||||
|
||||
r_ids = raw_ids[0] if raw_ids and isinstance(raw_ids[0], list) else raw_ids
|
||||
r_dists = raw_dists[0] if raw_dists and isinstance(raw_dists[0], list) else raw_dists
|
||||
r_metas = raw_metas[0] if raw_metas and isinstance(raw_metas[0], list) else raw_metas
|
||||
|
||||
parsed_results = []
|
||||
for i, id_val in enumerate(r_ids):
|
||||
parsed_results.append(
|
||||
{
|
||||
'id': id_val,
|
||||
'distance': r_dists[i] if r_dists and i < len(r_dists) else 0.0,
|
||||
'metadata': r_metas[i] if r_metas and i < len(r_metas) else {},
|
||||
}
|
||||
)
|
||||
|
||||
return parsed_results
|
||||
|
||||
async def delete_by_file_id(self, collection_name: str, file_ids: list[str]):
|
||||
"""Proxy: Delete vectors by file_id (metadata-level identifier).
|
||||
|
||||
This delegates to VectorDatabase.delete_by_file_id which removes
|
||||
all vectors associated with the given file IDs.
|
||||
"""
|
||||
for file_id in file_ids:
|
||||
await self.vector_db.delete_by_file_id(collection_name, file_id)
|
||||
|
||||
async def delete_collection(self, collection_name: str):
|
||||
"""Proxy: Delete an entire collection."""
|
||||
await self.vector_db.delete_collection(collection_name)
|
||||
|
||||
async def delete_by_filter(self, collection_name: str, filter: dict) -> int:
|
||||
"""Proxy: Delete vectors by metadata filter.
|
||||
|
||||
Returns:
|
||||
Number of deleted vectors (best-effort; some backends return 0).
|
||||
"""
|
||||
return await self.vector_db.delete_by_filter(collection_name, filter)
|
||||
|
||||
@@ -1,28 +1,10 @@
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
import enum
|
||||
from typing import Any, Dict
|
||||
import numpy as np
|
||||
|
||||
|
||||
class SearchType(str, enum.Enum):
|
||||
"""Supported search types for vector databases."""
|
||||
|
||||
VECTOR = 'vector'
|
||||
FULL_TEXT = 'full_text'
|
||||
HYBRID = 'hybrid'
|
||||
|
||||
|
||||
class VectorDatabase(abc.ABC):
|
||||
@classmethod
|
||||
def supported_search_types(cls) -> list[SearchType]:
|
||||
"""Return the search types supported by this VDB backend.
|
||||
|
||||
Default: vector search only. Override in subclasses that support
|
||||
full-text or hybrid search.
|
||||
"""
|
||||
return [SearchType.VECTOR]
|
||||
|
||||
@abc.abstractmethod
|
||||
async def add_embeddings(
|
||||
self,
|
||||
@@ -30,47 +12,14 @@ class VectorDatabase(abc.ABC):
|
||||
ids: list[str],
|
||||
embeddings_list: list[list[float]],
|
||||
metadatas: list[dict[str, Any]],
|
||||
documents: list[str] | None = None,
|
||||
documents: list[str],
|
||||
) -> None:
|
||||
"""Add vector data to the specified collection.
|
||||
|
||||
Args:
|
||||
collection: Collection name.
|
||||
ids: Unique IDs for each vector.
|
||||
embeddings_list: List of embedding vectors.
|
||||
metadatas: List of metadata dicts.
|
||||
documents: Optional raw text documents. Required for full-text
|
||||
and hybrid search in backends that support them.
|
||||
"""
|
||||
"""Add vector data to the specified collection."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def search(
|
||||
self,
|
||||
collection: str,
|
||||
query_embedding: np.ndarray,
|
||||
k: int = 5,
|
||||
search_type: str = 'vector',
|
||||
query_text: str = '',
|
||||
filter: dict[str, Any] | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Search for the most similar vectors in the specified collection.
|
||||
|
||||
Args:
|
||||
collection: Collection name.
|
||||
query_embedding: Query vector for similarity search.
|
||||
k: Number of results to return.
|
||||
search_type: One of 'vector', 'full_text', 'hybrid'.
|
||||
query_text: Raw query text, used for full_text and hybrid search.
|
||||
filter: Optional metadata filters using Chroma-style ``where``
|
||||
syntax. Multiple top-level keys are AND-ed. Supported
|
||||
operators: ``$eq``, ``$ne``, ``$gt``, ``$gte``, ``$lt``,
|
||||
``$lte``, ``$in``, ``$nin``. Example::
|
||||
|
||||
{"file_id": "abc"}
|
||||
{"created_at": {"$gte": 1700000000}}
|
||||
{"file_type": {"$in": ["pdf", "docx"]}}
|
||||
"""
|
||||
async def search(self, collection: str, query_embedding: np.ndarray, k: int = 5) -> Dict[str, Any]:
|
||||
"""Search for the most similar vectors in the specified collection."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -78,20 +27,6 @@ class VectorDatabase(abc.ABC):
|
||||
"""Delete vectors from the specified collection by file_id."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_by_filter(self, collection: str, filter: dict[str, Any]) -> int:
|
||||
"""Delete vectors matching the given metadata filter.
|
||||
|
||||
Args:
|
||||
collection: Collection name.
|
||||
filter: Metadata filter dict in canonical format (see ``search``).
|
||||
|
||||
Returns:
|
||||
Number of deleted vectors (best-effort; backends that cannot
|
||||
report an exact count may return 0).
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_or_create_collection(self, collection: str):
|
||||
"""Get or create collection."""
|
||||
|
||||
@@ -2,14 +2,11 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from chromadb import PersistentClient
|
||||
from langbot.pkg.vector.vdb import VectorDatabase, SearchType
|
||||
from langbot.pkg.vector.vdb import VectorDatabase
|
||||
from langbot.pkg.core import app
|
||||
import chromadb
|
||||
import chromadb.errors
|
||||
|
||||
# RRF smoothing constant (standard value from the literature)
|
||||
_RRF_K = 60
|
||||
|
||||
|
||||
class ChromaVectorDatabase(VectorDatabase):
|
||||
def __init__(self, ap: app.Application, base_path: str = './data/chroma'):
|
||||
@@ -17,10 +14,6 @@ class ChromaVectorDatabase(VectorDatabase):
|
||||
self.client = PersistentClient(path=base_path)
|
||||
self._collections = {}
|
||||
|
||||
@classmethod
|
||||
def supported_search_types(cls) -> list[SearchType]:
|
||||
return [SearchType.VECTOR, SearchType.FULL_TEXT, SearchType.HYBRID]
|
||||
|
||||
async def get_or_create_collection(self, collection: str) -> chromadb.Collection:
|
||||
if collection not in self._collections:
|
||||
self._collections[collection] = await asyncio.to_thread(
|
||||
@@ -35,192 +28,27 @@ class ChromaVectorDatabase(VectorDatabase):
|
||||
ids: list[str],
|
||||
embeddings_list: list[list[float]],
|
||||
metadatas: list[dict[str, Any]],
|
||||
documents: list[str] | None = None,
|
||||
) -> None:
|
||||
col = await self.get_or_create_collection(collection)
|
||||
kwargs: dict[str, Any] = dict(embeddings=embeddings_list, ids=ids, metadatas=metadatas)
|
||||
if documents is not None:
|
||||
kwargs['documents'] = documents
|
||||
await asyncio.to_thread(col.upsert, **kwargs)
|
||||
self.ap.logger.info(f"Upserted {len(ids)} embeddings to Chroma collection '{collection}'.")
|
||||
await asyncio.to_thread(col.add, embeddings=embeddings_list, ids=ids, metadatas=metadatas)
|
||||
self.ap.logger.info(f"Added {len(ids)} embeddings to Chroma collection '{collection}'.")
|
||||
|
||||
async def search(
|
||||
self,
|
||||
collection: str,
|
||||
query_embedding: list[float],
|
||||
k: int = 5,
|
||||
search_type: str = 'vector',
|
||||
query_text: str = '',
|
||||
filter: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
async def search(self, collection: str, query_embedding: list[float], k: int = 5) -> dict[str, Any]:
|
||||
col = await self.get_or_create_collection(collection)
|
||||
|
||||
if search_type == SearchType.FULL_TEXT:
|
||||
return await self._full_text_search(col, collection, k, query_text, filter)
|
||||
elif search_type == SearchType.HYBRID:
|
||||
return await self._hybrid_search(col, collection, query_embedding, k, query_text, filter)
|
||||
|
||||
# Default: vector search
|
||||
return await self._vector_search(col, collection, query_embedding, k, filter)
|
||||
|
||||
async def _vector_search(
|
||||
self,
|
||||
col: chromadb.Collection,
|
||||
collection: str,
|
||||
query_embedding: list[float],
|
||||
k: int,
|
||||
filter: dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
query_kwargs: dict[str, Any] = dict(
|
||||
results = await asyncio.to_thread(
|
||||
col.query,
|
||||
query_embeddings=query_embedding,
|
||||
n_results=k,
|
||||
include=['metadatas', 'distances', 'documents'],
|
||||
)
|
||||
if filter:
|
||||
query_kwargs['where'] = filter
|
||||
results = await asyncio.to_thread(col.query, **query_kwargs)
|
||||
self.ap.logger.info(
|
||||
f"Chroma vector search in '{collection}' returned {len(results.get('ids', [[]])[0])} results."
|
||||
)
|
||||
self.ap.logger.info(f"Chroma search in '{collection}' returned {len(results.get('ids', [[]])[0])} results.")
|
||||
return results
|
||||
|
||||
async def _full_text_search(
|
||||
self,
|
||||
col: chromadb.Collection,
|
||||
collection: str,
|
||||
k: int,
|
||||
query_text: str,
|
||||
filter: dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
if not query_text:
|
||||
return {'ids': [[]], 'metadatas': [[]], 'distances': [[]], 'documents': [[]]}
|
||||
|
||||
get_kwargs: dict[str, Any] = dict(
|
||||
where_document={'$contains': query_text},
|
||||
include=['metadatas', 'documents'],
|
||||
limit=k,
|
||||
)
|
||||
if filter:
|
||||
get_kwargs['where'] = filter
|
||||
results = await asyncio.to_thread(col.get, **get_kwargs)
|
||||
|
||||
# col.get returns flat lists; wrap into column-major format.
|
||||
# Distances are all 0.0 because Chroma's local $contains is a boolean
|
||||
# filter with no relevance scoring. Chroma's BM25 sparse embedding
|
||||
# function (ChromaBm25EmbeddingFunction) can generate scored sparse
|
||||
# vectors, but sparse vector *indexing* is only available on Chroma
|
||||
# Cloud, not locally. For ranked results, use hybrid mode or apply a
|
||||
# reranker in a downstream stage.
|
||||
ids = results.get('ids', [])
|
||||
metadatas = results.get('metadatas', []) or [None] * len(ids)
|
||||
documents = results.get('documents', []) or [None] * len(ids)
|
||||
distances = [0.0] * len(ids)
|
||||
|
||||
self.ap.logger.info(f"Chroma full-text search in '{collection}' returned {len(ids)} results.")
|
||||
return {'ids': [ids], 'metadatas': [metadatas], 'distances': [distances], 'documents': [documents]}
|
||||
|
||||
async def _hybrid_search(
|
||||
self,
|
||||
col: chromadb.Collection,
|
||||
collection: str,
|
||||
query_embedding: list[float],
|
||||
k: int,
|
||||
query_text: str,
|
||||
filter: dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
# Fall back to pure vector search when no text is provided
|
||||
if not query_text:
|
||||
return await self._vector_search(col, collection, query_embedding, k, filter)
|
||||
|
||||
# Run vector search and full-text search in parallel
|
||||
vector_task = self._vector_search(col, collection, query_embedding, k, filter)
|
||||
text_task = self._full_text_search(col, collection, k, query_text, filter)
|
||||
vector_results, text_results = await asyncio.gather(vector_task, text_task)
|
||||
|
||||
vector_ids = vector_results.get('ids', [[]])[0]
|
||||
text_ids = text_results.get('ids', [[]])[0]
|
||||
|
||||
if not vector_ids and not text_ids:
|
||||
return {'ids': [[]], 'metadatas': [[]], 'distances': [[]], 'documents': [[]]}
|
||||
|
||||
# RRF fusion
|
||||
fused = self._rrf_fuse([vector_ids, text_ids], k)
|
||||
if not fused:
|
||||
return {'ids': [[]], 'metadatas': [[]], 'distances': [[]], 'documents': [[]]}
|
||||
|
||||
fused_ids = [doc_id for doc_id, _ in fused]
|
||||
|
||||
# Fetch full metadata and documents for fused results
|
||||
fetched = await asyncio.to_thread(col.get, ids=fused_ids, include=['metadatas', 'documents'])
|
||||
|
||||
# col.get returns results in arbitrary order; re-order to match fused ranking
|
||||
fetched_map: dict[str, tuple] = {}
|
||||
for i, fid in enumerate(fetched.get('ids', [])):
|
||||
meta = (fetched.get('metadatas') or [None] * len(fetched['ids']))[i]
|
||||
doc = (fetched.get('documents') or [None] * len(fetched['ids']))[i]
|
||||
fetched_map[fid] = (meta, doc)
|
||||
|
||||
ordered_ids = []
|
||||
ordered_metas = []
|
||||
ordered_docs = []
|
||||
ordered_dists = []
|
||||
|
||||
# Normalize RRF scores to 0~1 distances via min-max scaling.
|
||||
# Raw RRF scores are tiny (e.g. 0.016~0.033 with k=60) so a naive
|
||||
# ``1 - score`` would compress all distances into a narrow 0.96~0.98
|
||||
# band with almost no discriminative power. Min-max normalization
|
||||
# spreads them across the full 0~1 range (0.0 = best match).
|
||||
max_score = fused[0][1]
|
||||
min_score = fused[-1][1]
|
||||
score_range = max_score - min_score
|
||||
|
||||
for doc_id, score in fused:
|
||||
if doc_id in fetched_map:
|
||||
meta, doc = fetched_map[doc_id]
|
||||
ordered_ids.append(doc_id)
|
||||
ordered_metas.append(meta)
|
||||
ordered_docs.append(doc)
|
||||
if score_range > 0:
|
||||
ordered_dists.append(1.0 - (score - min_score) / score_range)
|
||||
else:
|
||||
ordered_dists.append(0.0)
|
||||
|
||||
self.ap.logger.info(
|
||||
f"Chroma hybrid search in '{collection}' returned {len(ordered_ids)} results "
|
||||
f'(vector={len(vector_ids)}, text={len(text_ids)}).'
|
||||
)
|
||||
return {
|
||||
'ids': [ordered_ids],
|
||||
'metadatas': [ordered_metas],
|
||||
'distances': [ordered_dists],
|
||||
'documents': [ordered_docs],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _rrf_fuse(result_lists: list[list[str]], k: int) -> list[tuple[str, float]]:
|
||||
"""Reciprocal Rank Fusion over multiple ranked ID lists.
|
||||
|
||||
Returns a list of (doc_id, rrf_score) sorted by descending score,
|
||||
truncated to *k* entries.
|
||||
"""
|
||||
scores: dict[str, float] = {}
|
||||
for ranked_ids in result_lists:
|
||||
for rank, doc_id in enumerate(ranked_ids):
|
||||
scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (_RRF_K + rank + 1)
|
||||
sorted_results = sorted(scores.items(), key=lambda x: x[1], reverse=True)
|
||||
return sorted_results[:k]
|
||||
|
||||
async def delete_by_file_id(self, collection: str, file_id: str) -> None:
|
||||
col = await self.get_or_create_collection(collection)
|
||||
await asyncio.to_thread(col.delete, where={'file_id': file_id})
|
||||
self.ap.logger.info(f"Deleted embeddings from Chroma collection '{collection}' with file_id: {file_id}")
|
||||
|
||||
async def delete_by_filter(self, collection: str, filter: dict[str, Any]) -> int:
|
||||
col = await self.get_or_create_collection(collection)
|
||||
await asyncio.to_thread(col.delete, where=filter)
|
||||
self.ap.logger.info(f"Deleted embeddings from Chroma collection '{collection}' by filter")
|
||||
return 0 # Chroma delete does not return a count
|
||||
|
||||
async def delete_collection(self, collection: str):
|
||||
if collection in self._collections:
|
||||
del self._collections[collection]
|
||||
|
||||
@@ -4,51 +4,8 @@ from typing import Any, Dict
|
||||
from pymilvus import MilvusClient, DataType, CollectionSchema, FieldSchema
|
||||
from pymilvus.milvus_client.index import IndexParams
|
||||
from langbot.pkg.vector.vdb import VectorDatabase
|
||||
from langbot.pkg.vector.filter_utils import normalize_filter, strip_unsupported_fields
|
||||
from langbot.pkg.core import app
|
||||
|
||||
# Milvus schema only stores these metadata fields; filter on other fields is
|
||||
# silently dropped with a warning.
|
||||
_MILVUS_SUPPORTED_FIELDS = {'text', 'file_id', 'chunk_uuid'}
|
||||
|
||||
|
||||
def _build_milvus_expr(filter_dict: dict[str, Any]) -> str:
|
||||
"""Translate canonical filter dict into a Milvus boolean expression string."""
|
||||
triples = normalize_filter(filter_dict)
|
||||
triples = strip_unsupported_fields(triples, _MILVUS_SUPPORTED_FIELDS)
|
||||
if not triples:
|
||||
return ''
|
||||
|
||||
parts: list[str] = []
|
||||
for field, op, value in triples:
|
||||
if op == '$eq':
|
||||
parts.append(f'{field} == {_milvus_literal(value)}')
|
||||
elif op == '$ne':
|
||||
parts.append(f'{field} != {_milvus_literal(value)}')
|
||||
elif op == '$gt':
|
||||
parts.append(f'{field} > {_milvus_literal(value)}')
|
||||
elif op == '$gte':
|
||||
parts.append(f'{field} >= {_milvus_literal(value)}')
|
||||
elif op == '$lt':
|
||||
parts.append(f'{field} < {_milvus_literal(value)}')
|
||||
elif op == '$lte':
|
||||
parts.append(f'{field} <= {_milvus_literal(value)}')
|
||||
elif op == '$in':
|
||||
items = ', '.join(_milvus_literal(v) for v in value)
|
||||
parts.append(f'{field} in [{items}]')
|
||||
elif op == '$nin':
|
||||
items = ', '.join(_milvus_literal(v) for v in value)
|
||||
parts.append(f'{field} not in [{items}]')
|
||||
return ' and '.join(parts)
|
||||
|
||||
|
||||
def _milvus_literal(value: Any) -> str:
|
||||
"""Format a Python value as a Milvus expression literal."""
|
||||
if isinstance(value, str):
|
||||
escaped = value.replace('\\', '\\\\').replace('"', '\\"')
|
||||
return f'"{escaped}"'
|
||||
return str(value)
|
||||
|
||||
|
||||
class MilvusVectorDatabase(VectorDatabase):
|
||||
"""Milvus vector database implementation"""
|
||||
@@ -198,7 +155,6 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
ids: list[str],
|
||||
embeddings_list: list[list[float]],
|
||||
metadatas: list[dict[str, Any]],
|
||||
documents: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Add vector embeddings to Milvus collection
|
||||
|
||||
@@ -244,15 +200,7 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
|
||||
self.ap.logger.info(f"Added {len(ids)} embeddings to Milvus collection '{collection}'")
|
||||
|
||||
async def search(
|
||||
self,
|
||||
collection: str,
|
||||
query_embedding: list[float],
|
||||
k: int = 5,
|
||||
search_type: str = 'vector',
|
||||
query_text: str = '',
|
||||
filter: dict[str, Any] | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
async def search(self, collection: str, query_embedding: list[float], k: int = 5) -> Dict[str, Any]:
|
||||
"""Search for similar vectors in Milvus collection
|
||||
|
||||
Args:
|
||||
@@ -269,19 +217,14 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
# Perform search
|
||||
search_params = {'metric_type': 'COSINE', 'params': {}}
|
||||
|
||||
search_kwargs: dict[str, Any] = dict(
|
||||
results = await asyncio.to_thread(
|
||||
self.client.search,
|
||||
collection_name=collection,
|
||||
data=[query_embedding],
|
||||
limit=k,
|
||||
search_params=search_params,
|
||||
output_fields=['text', 'file_id', 'chunk_uuid'],
|
||||
)
|
||||
if filter:
|
||||
expr = _build_milvus_expr(filter)
|
||||
if expr:
|
||||
search_kwargs['filter'] = expr
|
||||
|
||||
results = await asyncio.to_thread(self.client.search, **search_kwargs)
|
||||
|
||||
# Convert results to Chroma-compatible format
|
||||
# Milvus returns: [[ {id, distance, entity: {...}} ]]
|
||||
@@ -325,21 +268,6 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
await asyncio.to_thread(self.client.delete, collection_name=collection, filter=f'file_id == "{file_id}"')
|
||||
self.ap.logger.info(f"Deleted embeddings from Milvus collection '{collection}' with file_id: {file_id}")
|
||||
|
||||
async def delete_by_filter(self, collection: str, filter: dict[str, Any]) -> int:
|
||||
collection = self._normalize_collection_name(collection)
|
||||
await self.get_or_create_collection(collection)
|
||||
|
||||
expr = _build_milvus_expr(filter)
|
||||
if not expr:
|
||||
self.ap.logger.warning(
|
||||
f"Milvus delete_by_filter on '{collection}': filter produced empty expression, skipping"
|
||||
)
|
||||
return 0
|
||||
|
||||
await asyncio.to_thread(self.client.delete, collection_name=collection, filter=expr)
|
||||
self.ap.logger.info(f"Deleted embeddings from Milvus collection '{collection}' by filter")
|
||||
return 0 # Milvus delete does not return a count
|
||||
|
||||
async def delete_collection(self, collection: str):
|
||||
"""Delete a Milvus collection
|
||||
|
||||
|
||||
@@ -5,21 +5,10 @@ from sqlalchemy.orm import declarative_base
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from langbot.pkg.vector.vdb import VectorDatabase
|
||||
from langbot.pkg.vector.filter_utils import normalize_filter, strip_unsupported_fields
|
||||
from langbot.pkg.core import app
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
# pgvector schema only stores these metadata fields.
|
||||
_PG_SUPPORTED_FIELDS = {'text', 'file_id', 'chunk_uuid'}
|
||||
|
||||
# Map schema field names to SQLAlchemy columns (resolved lazily from PgVectorEntry).
|
||||
_PG_COLUMN_MAP = {
|
||||
'text': 'text',
|
||||
'file_id': 'file_id',
|
||||
'chunk_uuid': 'chunk_uuid',
|
||||
}
|
||||
|
||||
|
||||
class PgVectorEntry(Base):
|
||||
"""SQLAlchemy model for pgvector entries"""
|
||||
@@ -34,33 +23,6 @@ class PgVectorEntry(Base):
|
||||
chunk_uuid = Column(String)
|
||||
|
||||
|
||||
def _build_pg_conditions(filter_dict: dict[str, Any]) -> list:
|
||||
"""Translate canonical filter dict into a list of SQLAlchemy conditions."""
|
||||
triples = normalize_filter(filter_dict)
|
||||
triples = strip_unsupported_fields(triples, _PG_SUPPORTED_FIELDS)
|
||||
|
||||
conditions = []
|
||||
for field, op, value in triples:
|
||||
col = getattr(PgVectorEntry, _PG_COLUMN_MAP[field])
|
||||
if op == '$eq':
|
||||
conditions.append(col == value)
|
||||
elif op == '$ne':
|
||||
conditions.append(col != value)
|
||||
elif op == '$gt':
|
||||
conditions.append(col > value)
|
||||
elif op == '$gte':
|
||||
conditions.append(col >= value)
|
||||
elif op == '$lt':
|
||||
conditions.append(col < value)
|
||||
elif op == '$lte':
|
||||
conditions.append(col <= value)
|
||||
elif op == '$in':
|
||||
conditions.append(col.in_(value))
|
||||
elif op == '$nin':
|
||||
conditions.append(col.notin_(value))
|
||||
return conditions
|
||||
|
||||
|
||||
class PgVectorDatabase(VectorDatabase):
|
||||
"""PostgreSQL with pgvector extension database implementation"""
|
||||
|
||||
@@ -147,7 +109,6 @@ class PgVectorDatabase(VectorDatabase):
|
||||
ids: list[str],
|
||||
embeddings_list: list[list[float]],
|
||||
metadatas: list[dict[str, Any]],
|
||||
documents: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Add vector embeddings to pgvector
|
||||
|
||||
@@ -181,15 +142,7 @@ class PgVectorDatabase(VectorDatabase):
|
||||
self.ap.logger.error(f'Error adding embeddings to pgvector: {e}')
|
||||
raise
|
||||
|
||||
async def search(
|
||||
self,
|
||||
collection: str,
|
||||
query_embedding: list[float],
|
||||
k: int = 5,
|
||||
search_type: str = 'vector',
|
||||
query_text: str = '',
|
||||
filter: dict[str, Any] | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
async def search(self, collection: str, query_embedding: list[float], k: int = 5) -> Dict[str, Any]:
|
||||
"""Search for similar vectors using cosine distance
|
||||
|
||||
Args:
|
||||
@@ -221,10 +174,6 @@ class PgVectorDatabase(VectorDatabase):
|
||||
.limit(k)
|
||||
)
|
||||
|
||||
if filter:
|
||||
for cond in _build_pg_conditions(filter):
|
||||
stmt = stmt.filter(cond)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
rows = result.fetchall()
|
||||
|
||||
@@ -276,39 +225,6 @@ class PgVectorDatabase(VectorDatabase):
|
||||
self.ap.logger.error(f'Error deleting from pgvector: {e}')
|
||||
raise
|
||||
|
||||
async def delete_by_filter(self, collection: str, filter: dict[str, Any]) -> int:
|
||||
"""Delete vectors matching a metadata filter.
|
||||
|
||||
Args:
|
||||
collection: Collection name
|
||||
filter: Canonical metadata filter dict
|
||||
"""
|
||||
conditions = _build_pg_conditions(filter)
|
||||
if not conditions:
|
||||
self.ap.logger.warning(
|
||||
f"pgvector delete_by_filter on '{collection}': filter produced no conditions, skipping"
|
||||
)
|
||||
return 0
|
||||
|
||||
await self.get_or_create_collection(collection)
|
||||
|
||||
async with self.AsyncSessionLocal() as session:
|
||||
try:
|
||||
from sqlalchemy import delete
|
||||
|
||||
stmt = delete(PgVectorEntry).where(PgVectorEntry.collection == collection)
|
||||
for cond in conditions:
|
||||
stmt = stmt.where(cond)
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
deleted = result.rowcount
|
||||
self.ap.logger.info(f"Deleted {deleted} embeddings from pgvector collection '{collection}' by filter")
|
||||
return deleted
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
self.ap.logger.error(f'Error deleting from pgvector by filter: {e}')
|
||||
raise
|
||||
|
||||
async def delete_collection(self, collection: str):
|
||||
"""Delete all vectors in a collection
|
||||
|
||||
|
||||
@@ -5,37 +5,6 @@ from typing import Any, Dict, List
|
||||
from qdrant_client import AsyncQdrantClient, models
|
||||
from langbot.pkg.core import app
|
||||
from langbot.pkg.vector.vdb import VectorDatabase
|
||||
from langbot.pkg.vector.filter_utils import normalize_filter
|
||||
|
||||
|
||||
def _build_qdrant_filter(filter_dict: dict[str, Any]) -> models.Filter:
|
||||
"""Translate canonical filter dict into a Qdrant ``models.Filter``."""
|
||||
triples = normalize_filter(filter_dict)
|
||||
must: list[models.Condition] = []
|
||||
must_not: list[models.Condition] = []
|
||||
|
||||
for field, op, value in triples:
|
||||
if op == '$eq':
|
||||
must.append(models.FieldCondition(key=field, match=models.MatchValue(value=value)))
|
||||
elif op == '$ne':
|
||||
must_not.append(models.FieldCondition(key=field, match=models.MatchValue(value=value)))
|
||||
elif op == '$in':
|
||||
must.append(models.FieldCondition(key=field, match=models.MatchAny(any=value)))
|
||||
elif op == '$nin':
|
||||
must_not.append(models.FieldCondition(key=field, match=models.MatchAny(any=value)))
|
||||
elif op in ('$gt', '$gte', '$lt', '$lte'):
|
||||
range_kwargs: dict[str, Any] = {}
|
||||
if op == '$gt':
|
||||
range_kwargs['gt'] = value
|
||||
elif op == '$gte':
|
||||
range_kwargs['gte'] = value
|
||||
elif op == '$lt':
|
||||
range_kwargs['lt'] = value
|
||||
elif op == '$lte':
|
||||
range_kwargs['lte'] = value
|
||||
must.append(models.FieldCondition(key=field, range=models.Range(**range_kwargs)))
|
||||
|
||||
return models.Filter(must=must or None, must_not=must_not or None)
|
||||
|
||||
|
||||
class QdrantVectorDatabase(VectorDatabase):
|
||||
@@ -79,7 +48,6 @@ class QdrantVectorDatabase(VectorDatabase):
|
||||
ids: List[str],
|
||||
embeddings_list: List[List[float]],
|
||||
metadatas: List[Dict[str, Any]],
|
||||
documents: List[str] | None = None,
|
||||
) -> None:
|
||||
if not embeddings_list:
|
||||
return
|
||||
@@ -92,29 +60,19 @@ class QdrantVectorDatabase(VectorDatabase):
|
||||
await self.client.upsert(collection_name=collection, points=points)
|
||||
self.ap.logger.info(f"Added {len(ids)} embeddings to Qdrant collection '{collection}'.")
|
||||
|
||||
async def search(
|
||||
self,
|
||||
collection: str,
|
||||
query_embedding: list[float],
|
||||
k: int = 5,
|
||||
search_type: str = 'vector',
|
||||
query_text: str = '',
|
||||
filter: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
async def search(self, collection: str, query_embedding: list[float], k: int = 5) -> dict[str, Any]:
|
||||
exists = await self.client.collection_exists(collection)
|
||||
if not exists:
|
||||
return {'ids': [[]], 'metadatas': [[]], 'distances': [[]]}
|
||||
|
||||
query_kwargs: dict[str, Any] = dict(
|
||||
collection_name=collection,
|
||||
query=query_embedding,
|
||||
limit=k,
|
||||
with_payload=True,
|
||||
)
|
||||
if filter:
|
||||
query_kwargs['query_filter'] = _build_qdrant_filter(filter)
|
||||
|
||||
hits = (await self.client.query_points(**query_kwargs)).points
|
||||
hits = (
|
||||
await self.client.query_points(
|
||||
collection_name=collection,
|
||||
query=query_embedding,
|
||||
limit=k,
|
||||
with_payload=True,
|
||||
)
|
||||
).points
|
||||
ids = [str(hit.id) for hit in hits]
|
||||
metadatas = [hit.payload or {} for hit in hits]
|
||||
# Qdrant's score is similarity; convert to a pseudo-distance for consistency
|
||||
@@ -137,19 +95,6 @@ class QdrantVectorDatabase(VectorDatabase):
|
||||
)
|
||||
self.ap.logger.info(f"Deleted embeddings from Qdrant collection '{collection}' with file_id: {file_id}")
|
||||
|
||||
async def delete_by_filter(self, collection: str, filter: dict[str, Any]) -> int:
|
||||
exists = await self.client.collection_exists(collection)
|
||||
if not exists:
|
||||
return 0
|
||||
|
||||
qdrant_filter = _build_qdrant_filter(filter)
|
||||
await self.client.delete(
|
||||
collection_name=collection,
|
||||
points_selector=qdrant_filter,
|
||||
)
|
||||
self.ap.logger.info(f"Deleted embeddings from Qdrant collection '{collection}' by filter")
|
||||
return 0 # Qdrant delete does not return a count
|
||||
|
||||
async def delete_collection(self, collection: str):
|
||||
try:
|
||||
await self.client.delete_collection(collection)
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any, Dict, List
|
||||
|
||||
|
||||
from langbot.pkg.core import app
|
||||
from langbot.pkg.vector.vdb import VectorDatabase, SearchType
|
||||
from langbot.pkg.vector.vdb import VectorDatabase
|
||||
|
||||
try:
|
||||
import pyseekdb
|
||||
@@ -25,13 +25,9 @@ class SeekDBVectorDatabase(VectorDatabase):
|
||||
SeekDB is an AI-native search database by OceanBase that unifies
|
||||
relational, vector, text, JSON and GIS in a single engine.
|
||||
|
||||
Supports embedded mode, remote server mode, and full-text/hybrid search.
|
||||
Supports both embedded mode and remote server mode.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def supported_search_types(cls) -> list[SearchType]:
|
||||
return [SearchType.VECTOR, SearchType.FULL_TEXT, SearchType.HYBRID]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
if not SEEKDB_AVAILABLE:
|
||||
raise ImportError('pyseekdb is not installed. Install it with: pip install pyseekdb')
|
||||
@@ -93,7 +89,6 @@ class SeekDBVectorDatabase(VectorDatabase):
|
||||
{
|
||||
'\x00': '',
|
||||
'\\': '\\\\',
|
||||
"'": "''", # Standard SQL escaping (OceanBase NO_BACKSLASH_ESCAPES)
|
||||
'"': '\\"',
|
||||
'\n': '\\n',
|
||||
'\r': '\\r',
|
||||
@@ -116,10 +111,8 @@ class SeekDBVectorDatabase(VectorDatabase):
|
||||
|
||||
# Collection doesn't exist, create it
|
||||
if vector_size is None:
|
||||
raise ValueError(
|
||||
f"Cannot create SeekDB collection '{collection}' without knowing the vector dimension. "
|
||||
'Ensure add_embeddings is called before any standalone get_or_create_collection.'
|
||||
)
|
||||
# Default dimension if not specified
|
||||
vector_size = 384
|
||||
|
||||
# Create HNSW configuration
|
||||
config = HNSWConfiguration(dimension=vector_size, distance='cosine')
|
||||
@@ -154,12 +147,7 @@ class SeekDBVectorDatabase(VectorDatabase):
|
||||
return await self._get_or_create_collection_internal(collection)
|
||||
|
||||
async def add_embeddings(
|
||||
self,
|
||||
collection: str,
|
||||
ids: List[str],
|
||||
embeddings_list: List[List[float]],
|
||||
metadatas: List[Dict[str, Any]],
|
||||
documents: List[str] | None = None,
|
||||
self, collection: str, ids: List[str], embeddings_list: List[List[float]], metadatas: List[Dict[str, Any]]
|
||||
) -> None:
|
||||
"""Add vector embeddings to the specified collection.
|
||||
|
||||
@@ -168,7 +156,6 @@ class SeekDBVectorDatabase(VectorDatabase):
|
||||
ids: List of document IDs
|
||||
embeddings_list: List of embedding vectors
|
||||
metadatas: List of metadata dictionaries
|
||||
documents: Optional raw text documents for full-text search support
|
||||
"""
|
||||
if not embeddings_list:
|
||||
return
|
||||
@@ -179,33 +166,17 @@ class SeekDBVectorDatabase(VectorDatabase):
|
||||
|
||||
cleaned_metadatas = [self._clean_metadata(meta) for meta in metadatas]
|
||||
|
||||
kwargs: Dict[str, Any] = dict(ids=ids, embeddings=embeddings_list, metadatas=cleaned_metadatas)
|
||||
if documents is not None:
|
||||
kwargs['documents'] = [doc.translate(self._escape_table) for doc in documents]
|
||||
await asyncio.to_thread(coll.add, **kwargs)
|
||||
await asyncio.to_thread(coll.add, ids=ids, embeddings=embeddings_list, metadatas=cleaned_metadatas)
|
||||
|
||||
self.ap.logger.info(f"Added {len(ids)} embeddings to SeekDB collection '{collection}'")
|
||||
|
||||
async def search(
|
||||
self,
|
||||
collection: str,
|
||||
query_embedding: List[float],
|
||||
k: int = 5,
|
||||
search_type: str = 'vector',
|
||||
query_text: str = '',
|
||||
filter: Dict[str, Any] | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
async def search(self, collection: str, query_embedding: List[float], k: int = 5) -> Dict[str, Any]:
|
||||
"""Search for the most similar vectors in the specified collection.
|
||||
|
||||
SeekDB supports vector, full-text, and hybrid search modes.
|
||||
|
||||
Args:
|
||||
collection: Collection name
|
||||
query_embedding: Query vector (used for vector and hybrid modes)
|
||||
query_embedding: Query vector
|
||||
k: Number of results to return
|
||||
search_type: One of 'vector', 'full_text', 'hybrid'
|
||||
query_text: Raw query text (used for full_text and hybrid modes)
|
||||
filter: Optional metadata filters (Chroma-style ``where`` syntax).
|
||||
|
||||
Returns:
|
||||
Dictionary with 'ids', 'metadatas', 'distances' keys
|
||||
@@ -222,73 +193,11 @@ class SeekDBVectorDatabase(VectorDatabase):
|
||||
else:
|
||||
coll = self._collections[collection]
|
||||
|
||||
# Route by search type.
|
||||
# pyseekdb's query() always requires embeddings, so full-text and
|
||||
# hybrid modes use hybrid_search() which supports text-only queries
|
||||
# and returns the same nested-list format with distances.
|
||||
if search_type == SearchType.FULL_TEXT:
|
||||
if not query_text:
|
||||
return {'ids': [[]], 'metadatas': [[]], 'distances': [[]]}
|
||||
# Perform query
|
||||
# SeekDB's query() returns: {'ids': [[...]], 'metadatas': [[...]], 'distances': [[...]]}
|
||||
results = await asyncio.to_thread(coll.query, query_embeddings=query_embedding, n_results=k)
|
||||
|
||||
query_cfg: Dict[str, Any] = {
|
||||
'where_document': {'$contains': query_text},
|
||||
'n_results': k,
|
||||
}
|
||||
if filter:
|
||||
query_cfg['where'] = filter
|
||||
|
||||
# TODO: pyseekdb hybrid_search with query-only (no knn) returns None
|
||||
# for IDs due to column name mismatch (*/_id vs _id).
|
||||
# See: https://github.com/oceanbase/pyseekdb/issues/171
|
||||
results = await asyncio.to_thread(
|
||||
coll.hybrid_search,
|
||||
query=query_cfg,
|
||||
knn=None,
|
||||
n_results=k,
|
||||
include=['documents', 'metadatas'],
|
||||
)
|
||||
|
||||
elif search_type == SearchType.HYBRID:
|
||||
if not query_text:
|
||||
# Fall back to pure vector search when no text is provided
|
||||
query_kwargs: Dict[str, Any] = {
|
||||
'n_results': k,
|
||||
'query_embeddings': query_embedding,
|
||||
}
|
||||
if filter:
|
||||
query_kwargs['where'] = filter
|
||||
results = await asyncio.to_thread(coll.query, **query_kwargs)
|
||||
else:
|
||||
query_cfg = {
|
||||
'where_document': {'$contains': query_text},
|
||||
'n_results': k,
|
||||
}
|
||||
knn_cfg: Dict[str, Any] = {
|
||||
'query_embeddings': query_embedding,
|
||||
'n_results': k,
|
||||
}
|
||||
if filter:
|
||||
query_cfg['where'] = filter
|
||||
knn_cfg['where'] = filter
|
||||
|
||||
results = await asyncio.to_thread(
|
||||
coll.hybrid_search,
|
||||
query=query_cfg,
|
||||
knn=knn_cfg,
|
||||
rank={'rrf': {}},
|
||||
n_results=k,
|
||||
include=['documents', 'metadatas'],
|
||||
)
|
||||
else:
|
||||
# Default: vector search via query()
|
||||
query_kwargs = {'n_results': k, 'query_embeddings': query_embedding}
|
||||
if filter:
|
||||
query_kwargs['where'] = filter
|
||||
results = await asyncio.to_thread(coll.query, **query_kwargs)
|
||||
|
||||
self.ap.logger.info(
|
||||
f"SeekDB {search_type} search in '{collection}' returned {len(results.get('ids', [[]])[0])} results"
|
||||
)
|
||||
self.ap.logger.info(f"SeekDB search in '{collection}' returned {len(results.get('ids', [[]])[0])} results")
|
||||
|
||||
return results
|
||||
|
||||
@@ -318,28 +227,6 @@ class SeekDBVectorDatabase(VectorDatabase):
|
||||
|
||||
self.ap.logger.info(f"Deleted embeddings from SeekDB collection '{collection}' with file_id: {file_id}")
|
||||
|
||||
async def delete_by_filter(self, collection: str, filter: Dict[str, Any]) -> int:
|
||||
"""Delete vectors from the collection by metadata filter.
|
||||
|
||||
Args:
|
||||
collection: Collection name
|
||||
filter: Chroma-style ``where`` filter dict
|
||||
"""
|
||||
exists = await asyncio.to_thread(self.client.has_collection, collection)
|
||||
if not exists:
|
||||
self.ap.logger.warning(f"SeekDB collection '{collection}' not found for deletion")
|
||||
return 0
|
||||
|
||||
if collection not in self._collections:
|
||||
coll = await asyncio.to_thread(self.client.get_collection, collection, embedding_function=None)
|
||||
self._collections[collection] = coll
|
||||
else:
|
||||
coll = self._collections[collection]
|
||||
|
||||
await asyncio.to_thread(coll.delete, where=filter)
|
||||
self.ap.logger.info(f"Deleted embeddings from SeekDB collection '{collection}' by filter")
|
||||
return 0 # SeekDB delete does not return a count
|
||||
|
||||
async def delete_collection(self, collection: str):
|
||||
"""Delete the entire collection.
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ admins: []
|
||||
api:
|
||||
port: 5300
|
||||
webhook_prefix: 'http://127.0.0.1:5300'
|
||||
extra_webhook_prefix: ''
|
||||
command:
|
||||
enable: true
|
||||
prefix:
|
||||
|
||||
@@ -17,10 +17,6 @@
|
||||
"prefix": [],
|
||||
"regexp": []
|
||||
},
|
||||
"message-aggregation": {
|
||||
"enabled": false,
|
||||
"delay": 1.5
|
||||
},
|
||||
"misc": {
|
||||
"combine-quote-message": true
|
||||
}
|
||||
@@ -41,10 +37,7 @@
|
||||
"runner": "local-agent"
|
||||
},
|
||||
"local-agent": {
|
||||
"model": {
|
||||
"primary": "",
|
||||
"fallbacks": []
|
||||
},
|
||||
"model": "",
|
||||
"max-round": 10,
|
||||
"prompt": [
|
||||
{
|
||||
@@ -98,12 +91,11 @@
|
||||
"max": 0
|
||||
},
|
||||
"misc": {
|
||||
"exception-handling": "show-hint",
|
||||
"failure-hint": "Request failed.",
|
||||
"hide-exception": true,
|
||||
"at-sender": true,
|
||||
"quote-origin": true,
|
||||
"track-function-calls": false,
|
||||
"remove-think": false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -59,11 +59,8 @@ stages:
|
||||
label:
|
||||
en_US: Model
|
||||
zh_Hans: 模型
|
||||
type: model-fallback-selector
|
||||
type: llm-model-selector
|
||||
required: true
|
||||
default:
|
||||
primary: ''
|
||||
fallbacks: []
|
||||
- name: max-round
|
||||
label:
|
||||
en_US: Max Round
|
||||
|
||||
@@ -78,39 +78,13 @@ stages:
|
||||
en_US: Misc
|
||||
zh_Hans: 杂项
|
||||
config:
|
||||
- name: exception-handling
|
||||
- name: hide-exception
|
||||
label:
|
||||
en_US: Exception Handling Strategy
|
||||
zh_Hans: 异常处理策略
|
||||
description:
|
||||
en_US: Controls how error messages are displayed to the user when an AI request fails
|
||||
zh_Hans: 控制 AI 请求失败时向用户展示错误信息的方式
|
||||
type: select
|
||||
en_US: Hide Exception
|
||||
zh_Hans: 不输出异常信息给用户
|
||||
type: boolean
|
||||
required: true
|
||||
default: show-hint
|
||||
options:
|
||||
- name: show-error
|
||||
label:
|
||||
en_US: Show Full Error
|
||||
zh_Hans: 显示完整报错信息
|
||||
- name: show-hint
|
||||
label:
|
||||
en_US: Show Failure Hint
|
||||
zh_Hans: 仅文字提示
|
||||
- name: hide
|
||||
label:
|
||||
en_US: Hide All
|
||||
zh_Hans: 不显示任何异常信息
|
||||
- name: failure-hint
|
||||
label:
|
||||
en_US: Failure Hint Text
|
||||
zh_Hans: 失败提示文本
|
||||
description:
|
||||
en_US: The text to display when a request fails. Only effective when Exception Handling Strategy is set to "Show Failure Hint"
|
||||
zh_Hans: 请求失败时显示的提示文本,仅在异常处理策略设置为"仅文字提示"时生效
|
||||
type: string
|
||||
required: false
|
||||
default: 'Request failed.'
|
||||
default: true
|
||||
- name: at-sender
|
||||
label:
|
||||
en_US: At Sender
|
||||
@@ -145,4 +119,3 @@ stages:
|
||||
type: boolean
|
||||
required: true
|
||||
default: false
|
||||
|
||||
|
||||
@@ -123,34 +123,6 @@ stages:
|
||||
type: array[string]
|
||||
required: true
|
||||
default: []
|
||||
- name: message-aggregation
|
||||
label:
|
||||
en_US: Message Aggregation
|
||||
zh_Hans: 消息聚合
|
||||
description:
|
||||
en_US: When a user sends multiple messages consecutively, wait for a period and merge them into one before processing
|
||||
zh_Hans: 当用户连续发送多条消息时,等待一段时间后合并为一条消息再处理(防抖)
|
||||
config:
|
||||
- name: enabled
|
||||
label:
|
||||
en_US: Enable Message Aggregation
|
||||
zh_Hans: 启用消息聚合
|
||||
description:
|
||||
en_US: If enabled, consecutive messages from the same user will be merged after a delay
|
||||
zh_Hans: 如果启用,同一用户连续发送的消息将在延迟后合并处理
|
||||
type: boolean
|
||||
required: true
|
||||
default: false
|
||||
- name: delay
|
||||
label:
|
||||
en_US: Aggregation Delay (seconds)
|
||||
zh_Hans: 聚合延迟(秒)
|
||||
description:
|
||||
en_US: 'Wait time before merging messages. Range: 1.0-10.0 seconds.'
|
||||
zh_Hans: '合并消息前的等待时间。范围:1.0-10.0 秒。'
|
||||
type: float
|
||||
required: true
|
||||
default: 1.5
|
||||
- name: misc
|
||||
label:
|
||||
en_US: Misc
|
||||
|
||||
@@ -91,15 +91,14 @@ class TestWebhookDisplayPrefix:
|
||||
|
||||
def test_default_webhook_prefix(self):
|
||||
"""Test that the default webhook display prefix is correctly set"""
|
||||
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300', 'extra_webhook_prefix': ''}}
|
||||
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300'}}
|
||||
|
||||
# Should have the default value
|
||||
assert cfg['api']['webhook_prefix'] == 'http://127.0.0.1:5300'
|
||||
assert cfg['api']['extra_webhook_prefix'] == ''
|
||||
|
||||
def test_webhook_prefix_env_override(self):
|
||||
"""Test overriding webhook_prefix via environment variable"""
|
||||
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300', 'extra_webhook_prefix': ''}}
|
||||
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300'}}
|
||||
|
||||
# Set environment variable
|
||||
os.environ['API__WEBHOOK_PREFIX'] = 'https://example.com:8080'
|
||||
@@ -113,7 +112,7 @@ class TestWebhookDisplayPrefix:
|
||||
|
||||
def test_webhook_prefix_with_custom_domain(self):
|
||||
"""Test webhook_prefix with custom domain"""
|
||||
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300', 'extra_webhook_prefix': ''}}
|
||||
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300'}}
|
||||
|
||||
# Set to a custom domain
|
||||
os.environ['API__WEBHOOK_PREFIX'] = 'https://bot.mycompany.com'
|
||||
@@ -127,7 +126,7 @@ class TestWebhookDisplayPrefix:
|
||||
|
||||
def test_webhook_prefix_with_subdirectory(self):
|
||||
"""Test webhook_prefix with subdirectory path"""
|
||||
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300', 'extra_webhook_prefix': ''}}
|
||||
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300'}}
|
||||
|
||||
# Set to a URL with subdirectory
|
||||
os.environ['API__WEBHOOK_PREFIX'] = 'https://example.com/langbot'
|
||||
@@ -139,37 +138,6 @@ class TestWebhookDisplayPrefix:
|
||||
# Cleanup
|
||||
del os.environ['API__WEBHOOK_PREFIX']
|
||||
|
||||
def test_extra_webhook_prefix_default_empty(self):
|
||||
"""Test that extra_webhook_prefix defaults to empty string"""
|
||||
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300', 'extra_webhook_prefix': ''}}
|
||||
|
||||
bot_uuid = 'test-bot-uuid'
|
||||
webhook_prefix = cfg['api'].get('webhook_prefix', 'http://127.0.0.1:5300')
|
||||
extra_webhook_prefix = cfg['api'].get('extra_webhook_prefix', '')
|
||||
webhook_url = f'/bots/{bot_uuid}'
|
||||
|
||||
assert f'{webhook_prefix}{webhook_url}' == 'http://127.0.0.1:5300/bots/test-bot-uuid'
|
||||
# extra should be empty when not configured
|
||||
assert extra_webhook_prefix == ''
|
||||
|
||||
def test_extra_webhook_prefix_env_override(self):
|
||||
"""Test overriding extra_webhook_prefix via environment variable"""
|
||||
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300', 'extra_webhook_prefix': ''}}
|
||||
|
||||
os.environ['API__EXTRA_WEBHOOK_PREFIX'] = 'https://extra.example.com'
|
||||
|
||||
result = _apply_env_overrides_to_config(cfg)
|
||||
|
||||
assert result['api']['extra_webhook_prefix'] == 'https://extra.example.com'
|
||||
|
||||
bot_uuid = 'test-bot-uuid'
|
||||
extra_prefix = result['api']['extra_webhook_prefix']
|
||||
webhook_url = f'/bots/{bot_uuid}'
|
||||
assert f'{extra_prefix}{webhook_url}' == 'https://extra.example.com/bots/test-bot-uuid'
|
||||
|
||||
# Cleanup
|
||||
del os.environ['API__EXTRA_WEBHOOK_PREFIX']
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user