mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-08 23:06:03 +00:00
Compare commits
87 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f4db53b759 | ||
|
|
9f90341dcb | ||
|
|
67b726afb2 | ||
|
|
01852b81d4 | ||
|
|
4d6f109788 | ||
|
|
e1e5e7aedf | ||
|
|
cd53abc440 | ||
|
|
16a15a122a | ||
|
|
6fa653f232 | ||
|
|
c13971d7d6 | ||
|
|
9c659ce8fa | ||
|
|
c9fc64360f | ||
|
|
88a04fdbe8 | ||
|
|
bbe019f0c6 | ||
|
|
865f6ee81b | ||
|
|
bd5ec59b7c | ||
|
|
9c0cc1003d | ||
|
|
ea07d8ad00 | ||
|
|
3ac3fad4bc | ||
|
|
254a13bba3 | ||
|
|
4355f0fa78 | ||
|
|
031737f05d | ||
|
|
9e366fc536 | ||
|
|
8bd6442965 | ||
|
|
1a1eadb282 | ||
|
|
eed72b1c12 | ||
|
|
351350ea03 | ||
|
|
bc3d6ba92f | ||
|
|
345e4baf2a | ||
|
|
6c64dc057f | ||
|
|
eec0a9c9d9 | ||
|
|
6896a55485 | ||
|
|
4b0fad233e | ||
|
|
52eb991a70 | ||
|
|
10c716be0c | ||
|
|
6e77351eda | ||
|
|
20f5ebd9b8 | ||
|
|
d2c75329cf | ||
|
|
7e2fe082f0 | ||
|
|
d451b059fd | ||
|
|
93c52fcd4c | ||
|
|
f1608682e6 | ||
|
|
077e631c13 | ||
|
|
d7df1f05d1 | ||
|
|
8b8cfb76de | ||
|
|
79311ccde3 | ||
|
|
def798bf1f | ||
|
|
5290834b8b | ||
|
|
89064a9d5b | ||
|
|
8c2aef3734 | ||
|
|
3fb9e542b6 | ||
|
|
01844d8687 | ||
|
|
2655425fbe | ||
|
|
bd15b630b0 | ||
|
|
fe5ce68436 | ||
|
|
0541b05966 | ||
|
|
13cb0aa9be | ||
|
|
a048369b38 | ||
|
|
9ae0c263dc | ||
|
|
a4e66f6459 | ||
|
|
2a74a8d6ae | ||
|
|
d31f25c8df | ||
|
|
11c05ea8db | ||
|
|
2b8bd1cc71 | ||
|
|
9148e02679 | ||
|
|
fd15284d91 | ||
|
|
8c7a0ec027 | ||
|
|
a1cef5c9bf | ||
|
|
90438cec36 | ||
|
|
95dd19f4d7 | ||
|
|
c64eb58cf8 | ||
|
|
fbd3d7ae3a | ||
|
|
40c7b0f731 | ||
|
|
cadcf10047 | ||
|
|
3e8f47fd97 | ||
|
|
b11ae55c6e | ||
|
|
2d63d528c6 | ||
|
|
10f253015d | ||
|
|
b34ebf85a6 | ||
|
|
06d3298cde | ||
|
|
614621ab7b | ||
|
|
8600d0a8e7 | ||
|
|
b83e6a53be | ||
|
|
88132dff8a | ||
|
|
2dc5999583 | ||
|
|
73461814c9 | ||
|
|
210e5e50d3 |
@@ -34,8 +34,6 @@
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 什么是 LangBot?
|
|
||||||
|
|
||||||
LangBot 是一个**开源的生产级平台**,用于构建 AI 驱动的即时通信机器人。它将大语言模型(LLM)连接到各种聊天平台,帮助你创建能够对话、执行任务、并集成到现有工作流程中的智能 Agent。
|
LangBot 是一个**开源的生产级平台**,用于构建 AI 驱动的即时通信机器人。它将大语言模型(LLM)连接到各种聊天平台,帮助你创建能够对话、执行任务、并集成到现有工作流程中的智能 Agent。
|
||||||
|
|
||||||
### 核心能力
|
### 核心能力
|
||||||
@@ -43,7 +41,7 @@ LangBot 是一个**开源的生产级平台**,用于构建 AI 驱动的即时
|
|||||||
- **AI 对话与 Agent** — 多轮对话、工具调用、多模态、流式输出。自带 RAG(知识库),深度集成 [Dify](https://dify.ai)、[Coze](https://coze.com)、[n8n](https://n8n.io)、[Langflow](https://langflow.org) 等 LLMOps 平台。
|
- **AI 对话与 Agent** — 多轮对话、工具调用、多模态、流式输出。自带 RAG(知识库),深度集成 [Dify](https://dify.ai)、[Coze](https://coze.com)、[n8n](https://n8n.io)、[Langflow](https://langflow.org) 等 LLMOps 平台。
|
||||||
- **全平台支持** — 一套代码,覆盖 QQ、微信、企业微信、飞书、钉钉、Discord、Telegram、Slack、LINE、KOOK 等平台。
|
- **全平台支持** — 一套代码,覆盖 QQ、微信、企业微信、飞书、钉钉、Discord、Telegram、Slack、LINE、KOOK 等平台。
|
||||||
- **生产就绪** — 访问控制、限速、敏感词过滤、全面监控与异常处理,已被多家企业采用。
|
- **生产就绪** — 访问控制、限速、敏感词过滤、全面监控与异常处理,已被多家企业采用。
|
||||||
- **插件生态** — 数百个插件,事件驱动架构,组件扩展,适配 [MCP 协议](https://modelcontextprotocol.io/)。
|
- **插件生态** — 数百个插件,跨进程的事件驱动架构,组件扩展,适配 [MCP 协议](https://modelcontextprotocol.io/)。
|
||||||
- **Web 管理面板** — 通过浏览器直观地配置、管理和监控机器人,无需手动编辑配置文件。
|
- **Web 管理面板** — 通过浏览器直观地配置、管理和监控机器人,无需手动编辑配置文件。
|
||||||
- **多流水线架构** — 不同机器人用于不同场景,具备全面的监控和异常处理能力。
|
- **多流水线架构** — 不同机器人用于不同场景,具备全面的监控和异常处理能力。
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "langbot"
|
name = "langbot"
|
||||||
version = "4.8.6"
|
version = "4.9.4"
|
||||||
description = "Production-grade platform for building agentic IM bots"
|
description = "Production-grade platform for building agentic IM bots"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license-files = ["LICENSE"]
|
license-files = ["LICENSE"]
|
||||||
@@ -61,16 +61,17 @@ dependencies = [
|
|||||||
"html2text>=2024.2.26",
|
"html2text>=2024.2.26",
|
||||||
"langchain>=0.2.0",
|
"langchain>=0.2.0",
|
||||||
"langchain-text-splitters>=0.0.1",
|
"langchain-text-splitters>=0.0.1",
|
||||||
"chromadb>=0.4.24",
|
"chromadb>=1.0.0,<2.0.0",
|
||||||
"qdrant-client (>=1.15.1,<2.0.0)",
|
"qdrant-client (>=1.15.1,<2.0.0)",
|
||||||
"pyseekdb==1.0.0b7",
|
"pyseekdb==1.1.0.post3",
|
||||||
"langbot-plugin==0.2.7",
|
"langbot-plugin==0.3.5",
|
||||||
"asyncpg>=0.30.0",
|
"asyncpg>=0.30.0",
|
||||||
"line-bot-sdk>=3.19.0",
|
"line-bot-sdk>=3.19.0",
|
||||||
"tboxsdk>=0.0.10",
|
"tboxsdk>=0.0.10",
|
||||||
"boto3>=1.35.0",
|
"boto3>=1.35.0",
|
||||||
"pymilvus>=2.6.4",
|
"pymilvus>=2.6.4",
|
||||||
"pgvector>=0.4.1",
|
"pgvector>=0.4.1",
|
||||||
|
"botocore>=1.42.39",
|
||||||
]
|
]
|
||||||
keywords = [
|
keywords = [
|
||||||
"bot",
|
"bot",
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
"""LangBot - Production-grade platform for building agentic IM bots"""
|
"""LangBot - Production-grade platform for building agentic IM bots"""
|
||||||
|
|
||||||
__version__ = '4.8.6'
|
__version__ = '4.9.4'
|
||||||
|
|||||||
3
src/langbot/libs/openclaw_weixin_api/__init__.py
Normal file
3
src/langbot/libs/openclaw_weixin_api/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .client import OpenClawWeixinClient as OpenClawWeixinClient
|
||||||
|
from .types import ApiError as ApiError
|
||||||
|
from .types import LoginResult as LoginResult
|
||||||
807
src/langbot/libs/openclaw_weixin_api/client.py
Normal file
807
src/langbot/libs/openclaw_weixin_api/client.py
Normal file
@@ -0,0 +1,807 @@
|
|||||||
|
"""Async HTTP client for the OpenClaw WeChat API.
|
||||||
|
|
||||||
|
Implements the iLink Bot API protocol.
|
||||||
|
Reference: https://github.com/epiral/weixin-bot
|
||||||
|
|
||||||
|
Endpoints: getUpdates (long-poll), sendMessage, getUploadUrl, getConfig, sendTyping.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import struct
|
||||||
|
import typing
|
||||||
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from .types import (
|
||||||
|
ApiError,
|
||||||
|
CDNMedia,
|
||||||
|
FileItem,
|
||||||
|
GetConfigResponse,
|
||||||
|
GetUpdatesResponse,
|
||||||
|
GetUploadUrlResponse,
|
||||||
|
ImageItem,
|
||||||
|
LoginResult,
|
||||||
|
MessageItem,
|
||||||
|
QRCodeResponse,
|
||||||
|
QRStatusResponse,
|
||||||
|
RefMessage,
|
||||||
|
TextItem,
|
||||||
|
VideoItem,
|
||||||
|
VoiceItem,
|
||||||
|
WeixinMessage,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Module-level logger. It uses a fixed name instead of __name__.
logger = logging.getLogger('openclaw-weixin-sdk')

# Default iLink API host.
# NOTE(review): not referenced in the visible part of this module; presumably
# callers pass it as ``base_url`` -- confirm.
DEFAULT_BASE_URL = 'https://ilinkai.weixin.qq.com'
# Base URL for CDN media transfers; '/download' and '/upload' are appended to it.
CDN_BASE_URL = 'https://novac2c.cdn.weixin.qq.com/c2c'

# Sent as base_info.channel_version in the JSON body of POST requests.
CHANNEL_VERSION = '1.0.0'

# Timeouts below are passed to aiohttp.ClientTimeout(total=...), i.e. seconds.
# Default for regular POST calls and for fetching the login QR code.
DEFAULT_API_TIMEOUT = 15
# Default for the getupdates long-poll request.
DEFAULT_LONG_POLL_TIMEOUT = 40
# Used by the getconfig and sendtyping calls.
DEFAULT_CONFIG_TIMEOUT = 10
# Used by the QR-code status long-poll request.
DEFAULT_QR_POLL_TIMEOUT = 35

# NOTE(review): presumably the errcode the server returns when the session has
# expired; it is not used in the visible part of this module -- confirm.
SESSION_EXPIRED_ERRCODE = -14

# Default ``bot_type`` query parameter for the QR-code login request.
DEFAULT_BOT_TYPE = '3'

# Maximum text length per message chunk (WeChat limit)
MAX_TEXT_CHUNK_SIZE = 2000
|
||||||
|
|
||||||
|
|
||||||
|
def _random_wechat_uin() -> str:
    """Return a fresh value for the ``X-WECHAT-UIN`` header.

    Four random bytes are read as a big-endian unsigned 32-bit integer, the
    integer is rendered as a decimal string, and that string is
    base64-encoded.
    """
    (value,) = struct.unpack('>I', os.urandom(4))
    decimal_text = str(value)
    encoded = base64.b64encode(decimal_text.encode('utf-8'))
    return encoded.decode('utf-8')
|
||||||
|
|
||||||
|
|
||||||
|
def _build_base_info() -> dict:
    """Return a new ``base_info`` dict carrying the channel version.

    The result is attached to the JSON body of API requests under the
    ``base_info`` key.
    """
    base_info: dict = {}
    base_info['channel_version'] = CHANNEL_VERSION
    return base_info
|
||||||
|
|
||||||
|
|
||||||
|
def _chunk_text(text: str, max_size: int = MAX_TEXT_CHUNK_SIZE) -> list[str]:
    """Split ``text`` into consecutive chunks of at most ``max_size`` characters.

    Length is measured with ``len()``, i.e. in Python characters, not bytes.

    Args:
        text: The text to split.
        max_size: Maximum number of characters per chunk.

    Returns:
        ``[text]`` unchanged when it already fits (including the empty
        string); otherwise the ordered list of chunks, whose concatenation
        equals ``text``.

    Raises:
        ValueError: If ``text`` does not fit and ``max_size`` is not positive.
    """
    if len(text) <= max_size:
        return [text]
    if max_size <= 0:
        # A non-positive chunk size can never consume the text; the previous
        # while-loop implementation spun forever (and grew memory) here.
        raise ValueError(f'max_size must be a positive integer, got {max_size}')
    return [text[start : start + max_size] for start in range(0, len(text), max_size)]
|
||||||
|
|
||||||
|
|
||||||
|
class OpenClawWeixinClient:
    """Async client for the OpenClaw WeChat HTTP JSON API.

    A single ``aiohttp.ClientSession`` is created lazily on first use and
    reused for every request; call :meth:`close` when done with the client.
    """

    def __init__(self, base_url: str, token: str):
        """Store the API base URL (trailing slashes removed) and the bot token.

        Args:
            base_url: Base URL of the API server.
            token: Bot token sent as a Bearer credential; when empty, no
                ``Authorization`` header is sent.
        """
        self.base_url = base_url.rstrip('/')
        self.token = token
        self._session: Optional[aiohttp.ClientSession] = None

    async def _get_session(self) -> aiohttp.ClientSession:
        """Return the shared HTTP session, creating a new one if absent or closed."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession()
        return self._session

    async def close(self):
        """Close the underlying HTTP session if one is open."""
        if self._session and not self._session.closed:
            await self._session.close()

    def _build_headers(self) -> dict[str, str]:
        """Build the headers for a JSON API request.

        A new random ``X-WECHAT-UIN`` value is generated on every call. The
        ``Authorization`` header is added only when a token is set.
        """
        headers = {
            'Content-Type': 'application/json',
            'AuthorizationType': 'ilink_bot_token',
            'X-WECHAT-UIN': _random_wechat_uin(),
        }
        if self.token:
            headers['Authorization'] = f'Bearer {self.token}'
        return headers

    async def _request_json(self, endpoint: str, payload: dict, timeout: float) -> typing.Any:
        """POST ``payload`` (plus ``base_info``) to ``endpoint`` and return the decoded JSON.

        Shared by :meth:`_post` and :meth:`get_updates`. No application-level
        error-code check is performed here.

        Args:
            endpoint: Path relative to ``self.base_url``.
            payload: JSON body; it is copied, never modified.
            timeout: Total request timeout in seconds.

        Raises:
            ApiError: If the HTTP status is not 200.
        """
        # Build a new dict so the caller's payload is not mutated.
        body = {**payload, 'base_info': _build_base_info()}

        session = await self._get_session()
        url = f'{self.base_url}/{endpoint}'
        headers = self._build_headers()

        async with session.post(
            url, json=body, headers=headers, timeout=aiohttp.ClientTimeout(total=timeout)
        ) as resp:
            if resp.status != 200:
                text = await resp.text()
                raise ApiError(
                    f'OpenClaw API error {resp.status}: {text}',
                    status=resp.status,
                )
            # content_type=None: decode the body as JSON whatever Content-Type the server sent.
            return await resp.json(content_type=None)

    async def _post(self, endpoint: str, payload: dict, timeout: float = DEFAULT_API_TIMEOUT) -> dict:
        """Make a POST request and return the JSON response.

        Args:
            endpoint: Path relative to ``self.base_url``.
            payload: JSON body; ``base_info`` is added to a copy of it.
            timeout: Total request timeout in seconds.

        Raises:
            ApiError: On HTTP errors, when the body is not a JSON object, or
                when the response contains a non-zero ``errcode`` / ``ret``.
        """
        data = await self._request_json(endpoint, payload, timeout)

        # An empty body decodes to None and other JSON values are possible too;
        # report them as ApiError instead of failing on ``.get`` below.
        if not isinstance(data, dict):
            raise ApiError(f'OpenClaw API returned a non-object JSON body: {data!r}', status=200)

        # Check for application-level errors in the response body
        errcode = data.get('errcode') or data.get('ret')
        if errcode:
            raise ApiError(
                data.get('errmsg') or f'API errcode {errcode}',
                status=200,
                code=errcode,
                payload=data,
            )

        return data

    async def get_updates(
        self, get_updates_buf: str = '', timeout: float = DEFAULT_LONG_POLL_TIMEOUT
    ) -> GetUpdatesResponse:
        """Long-poll for new messages.

        Note: This method does NOT raise ApiError for errcode responses —
        it returns them in the GetUpdatesResponse so the caller can handle
        session expiry and other errors with full context.

        Args:
            get_updates_buf: Value sent as ``get_updates_buf`` in the request.
            timeout: Total request timeout in seconds.

        Returns:
            The parsed response. On a request timeout, an empty response
            (``ret=0``, no messages) that echoes ``get_updates_buf`` back.

        Raises:
            ApiError: If the HTTP status is not 200.
        """
        try:
            # Uses the raw request helper rather than _post so that the
            # errcode check is skipped and error info reaches the caller.
            data = await self._request_json(
                'ilink/bot/getupdates',
                {'get_updates_buf': get_updates_buf},
                timeout,
            )
        except (asyncio.TimeoutError, aiohttp.ServerTimeoutError):
            return GetUpdatesResponse(ret=0, msgs=[], get_updates_buf=get_updates_buf)
        except ApiError:
            # Must precede the generic handler: an HTTP error whose text
            # mentions "timeout" is still an error, not an empty poll.
            raise
        except Exception as e:
            # Best effort: treat any other timeout-looking failure as an empty poll.
            if 'timeout' in str(e).lower():
                return GetUpdatesResponse(ret=0, msgs=[], get_updates_buf=get_updates_buf)
            raise

        return _parse_get_updates_response(data)

    async def send_message(
        self,
        to_user_id: str,
        item_list: list[MessageItem],
        context_token: str = '',
    ) -> None:
        """Send a message to a user.

        Args:
            to_user_id: Recipient user ID.
            item_list: Message items to send in this message.
            context_token: Optional context token; sent as ``null`` when empty.

        Raises:
            ApiError: If the API call fails.
        """
        items_payload = [_message_item_to_dict(item) for item in item_list]

        payload = {
            'msg': {
                'from_user_id': '',
                'to_user_id': to_user_id,
                # A new random client_id is generated for every message.
                'client_id': f'langbot-{uuid.uuid4().hex[:16]}',
                'message_type': WeixinMessage.TYPE_BOT,
                'message_state': WeixinMessage.STATE_FINISH,
                'item_list': items_payload,
                'context_token': context_token or None,
            }
        }
        await self._post('ilink/bot/sendmessage', payload)

    async def send_text(self, to_user_id: str, text: str, context_token: str = '') -> None:
        """Send a plain text message, automatically chunking if too long.

        Each chunk is sent as its own message, in order.
        """
        for chunk in _chunk_text(text):
            item = MessageItem(type=MessageItem.TEXT, text_item=TextItem(text=chunk))
            await self.send_message(to_user_id, [item], context_token)

    async def get_config(self, ilink_user_id: str, context_token: str = '') -> GetConfigResponse:
        """Get bot config including typing_ticket."""
        data = await self._post(
            'ilink/bot/getconfig',
            {'ilink_user_id': ilink_user_id, 'context_token': context_token or None},
            timeout=DEFAULT_CONFIG_TIMEOUT,
        )
        return GetConfigResponse(
            ret=data.get('ret'),
            errmsg=data.get('errmsg'),
            typing_ticket=data.get('typing_ticket'),
        )

    async def send_typing(self, ilink_user_id: str, typing_ticket: str, status: int = 1) -> None:
        """Send typing indicator. status: 1=typing, 2=cancel."""
        await self._post(
            'ilink/bot/sendtyping',
            {
                'ilink_user_id': ilink_user_id,
                'typing_ticket': typing_ticket,
                'status': status,
            },
            timeout=DEFAULT_CONFIG_TIMEOUT,
        )

    async def stop_typing(self, ilink_user_id: str, typing_ticket: str) -> None:
        """Cancel the typing indicator for a user."""
        await self.send_typing(ilink_user_id, typing_ticket, status=2)

    async def download_media(
        self,
        media: CDNMedia,
    ) -> bytes:
        """Download and decrypt a file from the WeChat CDN.

        Args:
            media: CDNMedia object with encrypt_query_param and aes_key.

        Returns:
            Decrypted file bytes.

        Raises:
            ApiError: If ``media`` lacks a query param or key, the decoded key
                has an unexpected length, or the CDN answers with a non-200
                status.
        """
        from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
        from cryptography.hazmat.primitives.padding import PKCS7

        if not media.encrypt_query_param:
            raise ApiError('CDN media has no encrypt_query_param', status=0)
        if not media.aes_key:
            raise ApiError('CDN media has no aes_key', status=0)

        # Derive 16-byte AES key
        # aes_key is base64-encoded; the decoded content may be:
        #   - raw 16 bytes (direct AES key)
        #   - 32-char hex string (decode hex to get 16 bytes)
        raw = base64.b64decode(media.aes_key)
        if len(raw) == 16:
            aes_key = raw
        elif len(raw) == 32:
            # Hex-encoded 16-byte key
            aes_key = bytes.fromhex(raw.decode('utf-8'))
        else:
            raise ApiError(f'Invalid AES key length: {len(raw)} (expected 16 or 32)', status=0)

        # Download encrypted bytes from CDN
        session = await self._get_session()
        cdn_url = f'{CDN_BASE_URL}/download?encrypted_query_param={quote(media.encrypt_query_param, safe="")}'

        async with session.get(cdn_url, timeout=aiohttp.ClientTimeout(total=120)) as resp:
            if resp.status != 200:
                text = await resp.text()
                raise ApiError(f'CDN download failed: {resp.status} {text}', status=resp.status)
            encrypted = await resp.read()

        # Decrypt AES-128-ECB with PKCS7 padding
        cipher = Cipher(algorithms.AES(aes_key), modes.ECB())
        decryptor = cipher.decryptor()
        padded = decryptor.update(encrypted) + decryptor.finalize()

        unpadder = PKCS7(128).unpadder()
        return unpadder.update(padded) + unpadder.finalize()

    async def upload_media(
        self,
        file_bytes: bytes,
        to_user_id: str,
        media_type: int,
    ) -> CDNMedia:
        """Encrypt and upload media to WeChat CDN.

        Args:
            file_bytes: Raw file bytes to upload.
            to_user_id: Recipient user ID.
            media_type: 1=IMAGE, 2=VIDEO, 3=FILE, 4=VOICE.

        Returns:
            CDNMedia with encrypt_query_param and aes_key for use in sendMessage.

        Raises:
            ApiError: If no upload parameter is returned, the CDN answers
                with a non-200 status, or the CDN response lacks the
                ``x-encrypted-param`` header.
        """
        import hashlib

        from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
        from cryptography.hazmat.primitives.padding import PKCS7

        # 1. Generate random 16-byte AES key
        raw_key = os.urandom(16)
        aes_key_hex = raw_key.hex()  # 32-char hex string

        # 2. Encode key for CDNMedia: base64(hex_string) — same for all media types
        #    Matches official SDK: Buffer.from(aeskey_hex).toString("base64")
        encoded_key = base64.b64encode(aes_key_hex.encode('utf-8')).decode('utf-8')

        # 3. Encrypt file with AES-128-ECB + PKCS7
        padder = PKCS7(128).padder()
        padded = padder.update(file_bytes) + padder.finalize()
        cipher = Cipher(algorithms.AES(raw_key), modes.ECB())
        encryptor = cipher.encryptor()
        encrypted = encryptor.update(padded) + encryptor.finalize()

        # 4. Get upload URL
        raw_md5 = hashlib.md5(file_bytes).hexdigest()
        filekey = os.urandom(16).hex()  # 32-char hex, matches official SDK

        upload_resp = await self.get_upload_url(
            filekey=filekey,
            media_type=media_type,
            to_user_id=to_user_id,
            rawsize=len(file_bytes),
            rawfilemd5=raw_md5,
            filesize=len(encrypted),
            aeskey=aes_key_hex,  # hex string, as expected by the API
        )

        if not upload_resp.upload_param:
            raise ApiError('Failed to get upload URL', status=0)

        # 5. Upload to CDN
        # upload_param is an opaque token from the server — pass it as-is
        session = await self._get_session()
        cdn_url = f'{CDN_BASE_URL}/upload?encrypted_query_param={quote(upload_resp.upload_param, safe="")}&filekey={quote(filekey, safe="")}'
        # The AES key is deliberately left out of the log line: key material
        # must not end up in log files.
        logger.debug(
            'CDN upload: url=%s raw_size=%d encrypted_size=%d md5=%s',
            cdn_url,
            len(file_bytes),
            len(encrypted),
            raw_md5,
        )

        async with session.post(
            cdn_url,
            data=encrypted,
            headers={'Content-Type': 'application/octet-stream'},
            timeout=aiohttp.ClientTimeout(total=120),
        ) as resp:
            if resp.status != 200:
                text = await resp.text()
                logger.error('CDN upload failed: status=%d url=%s body=%s', resp.status, cdn_url, text[:500])
                raise ApiError(f'CDN upload failed: {resp.status} {text}', status=resp.status)
            # The download parameter comes back in a response header.
            download_param = resp.headers.get('x-encrypted-param', '')

        if not download_param:
            raise ApiError('CDN upload succeeded but no x-encrypted-param returned', status=0)

        return CDNMedia(
            encrypt_query_param=download_param,
            aes_key=encoded_key,
            encrypt_type=1,
        )

    async def send_image(
        self,
        to_user_id: str,
        image_bytes: bytes,
        context_token: str = '',
    ) -> None:
        """Upload an image to CDN and send it."""
        media = await self.upload_media(image_bytes, to_user_id, media_type=1)
        item = MessageItem(
            type=MessageItem.IMAGE,
            image_item=ImageItem(
                media=media,
                aeskey=media.aes_key,
            ),
        )
        await self.send_message(to_user_id, [item], context_token)

    async def send_file(
        self,
        to_user_id: str,
        file_bytes: bytes,
        file_name: str,
        context_token: str = '',
    ) -> None:
        """Upload a file to CDN and send it.

        The file item carries the MD5 hex digest and the length (as a string)
        of the unencrypted bytes.
        """
        import hashlib

        media = await self.upload_media(file_bytes, to_user_id, media_type=3)
        item = MessageItem(
            type=MessageItem.FILE,
            file_item=FileItem(
                media=media,
                file_name=file_name,
                md5=hashlib.md5(file_bytes).hexdigest(),
                len=str(len(file_bytes)),
            ),
        )
        await self.send_message(to_user_id, [item], context_token)

    async def send_voice(
        self,
        to_user_id: str,
        voice_bytes: bytes,
        playtime: int = 0,
        context_token: str = '',
    ) -> None:
        """Upload a voice message to CDN and send it."""
        media = await self.upload_media(voice_bytes, to_user_id, media_type=4)
        item = MessageItem(
            type=MessageItem.VOICE,
            voice_item=VoiceItem(
                media=media,
                playtime=playtime,
            ),
        )
        await self.send_message(to_user_id, [item], context_token)

    async def get_upload_url(
        self,
        filekey: str,
        media_type: int,
        to_user_id: str,
        rawsize: int,
        rawfilemd5: str,
        filesize: int,
        thumb_rawsize: Optional[int] = None,
        thumb_rawfilemd5: Optional[str] = None,
        thumb_filesize: Optional[int] = None,
        aeskey: Optional[str] = None,
    ) -> GetUploadUrlResponse:
        """Get a pre-signed CDN upload URL.

        The ``thumb_*`` and ``aeskey`` fields are included in the request
        only when they are not ``None``.

        Raises:
            ApiError: If the API call fails.
        """
        payload: dict = {
            'filekey': filekey,
            'media_type': media_type,
            'to_user_id': to_user_id,
            'rawsize': rawsize,
            'rawfilemd5': rawfilemd5,
            'filesize': filesize,
            'no_need_thumb': True,
        }
        optional_fields = {
            'thumb_rawsize': thumb_rawsize,
            'thumb_rawfilemd5': thumb_rawfilemd5,
            'thumb_filesize': thumb_filesize,
            'aeskey': aeskey,
        }
        payload.update({key: value for key, value in optional_fields.items() if value is not None})

        data = await self._post('ilink/bot/getuploadurl', payload)
        logger.debug('get_upload_url response: %s', data)
        return GetUploadUrlResponse(
            upload_param=data.get('upload_param'),
            thumb_upload_param=data.get('thumb_upload_param'),
        )

    # -----------------------------------------------------------------------
    # QR Code Login
    # -----------------------------------------------------------------------

    async def fetch_qrcode(self, bot_type: str = DEFAULT_BOT_TYPE) -> QRCodeResponse:
        """Fetch a QR code for WeChat login authorization (GET, no auth needed).

        Raises:
            ApiError: If the HTTP status is not 200.
        """
        session = await self._get_session()
        # Escape the query value, consistent with poll_qrcode_status.
        url = f'{self.base_url}/ilink/bot/get_bot_qrcode?bot_type={quote(str(bot_type), safe="")}'

        async with session.get(url, timeout=aiohttp.ClientTimeout(total=DEFAULT_API_TIMEOUT)) as resp:
            if resp.status != 200:
                text = await resp.text()
                raise ApiError(
                    f'Failed to fetch QR code: {resp.status} {text}',
                    status=resp.status,
                )
            data = await resp.json(content_type=None)

        logger.debug(
            'fetch_qrcode response: qrcode=%s, img=%s', data.get('qrcode'), bool(data.get('qrcode_img_content'))
        )

        return QRCodeResponse(
            qrcode=data.get('qrcode'),
            qrcode_img_content=data.get('qrcode_img_content'),
        )

    async def _fetch_qr_image_base64(self, url: str) -> str:
        """Generate a QR code image from the URL and return a data URI string.

        The qrcode_img_content URL points to an HTML page (not a raw image),
        so we generate the QR code locally using the qrcode library.
        No network request is made here despite the method name.
        """
        import qrcode

        qr = qrcode.QRCode(error_correction=qrcode.constants.ERROR_CORRECT_L)
        qr.add_data(url)
        qr.make(fit=True)
        img = qr.make_image(fill_color='black', back_color='white')

        buf = io.BytesIO()
        img.save(buf, format='PNG')
        b64 = base64.b64encode(buf.getvalue()).decode('utf-8')
        return f'data:image/png;base64,{b64}'

    async def poll_qrcode_status(self, qrcode: str) -> QRStatusResponse:
        """Long-poll the QR code scan status (GET with iLink-App-ClientVersion header).

        Returns:
            The parsed status; a request timeout is reported as status ``'wait'``.

        Raises:
            ApiError: If the HTTP status is not 200.
        """
        session = await self._get_session()
        url = f'{self.base_url}/ilink/bot/get_qrcode_status?qrcode={quote(qrcode, safe="")}'
        headers = {'iLink-App-ClientVersion': '1'}

        try:
            async with session.get(
                url, headers=headers, timeout=aiohttp.ClientTimeout(total=DEFAULT_QR_POLL_TIMEOUT)
            ) as resp:
                if resp.status != 200:
                    text = await resp.text()
                    raise ApiError(
                        f'Failed to poll QR status: {resp.status} {text}',
                        status=resp.status,
                    )
                data = await resp.json(content_type=None)
                logger.debug('QR status poll response: %s', data)
        except (asyncio.TimeoutError, aiohttp.ServerTimeoutError):
            return QRStatusResponse(status='wait')

        return QRStatusResponse(
            status=data.get('status'),
            bot_token=data.get('bot_token'),
            ilink_bot_id=data.get('ilink_bot_id'),
            baseurl=data.get('baseurl'),
            ilink_user_id=data.get('ilink_user_id'),
        )

    @staticmethod
    async def _invoke_callback(label: str, callback: typing.Callable[..., typing.Any], *args: typing.Any) -> None:
        """Call a user callback, awaiting its result if it is a coroutine or future.

        Any exception raised by the callback is logged as a warning and
        swallowed so that a faulty callback cannot abort the login flow.
        """
        try:
            outcome = callback(*args)
            if asyncio.iscoroutine(outcome) or asyncio.isfuture(outcome):
                await outcome
        except Exception as e:
            logger.warning('%s callback error: %s', label, e)

    async def login(
        self,
        max_retries: int = 5,
        poll_timeout_ms: int = 480_000,
        on_qrcode: Optional[typing.Callable[[str, str], typing.Any]] = None,
        on_status: Optional[typing.Callable[[str], typing.Any]] = None,
    ) -> LoginResult:
        """Complete QR code login flow with auto-retry on expiry.

        On success, ``self.token`` and ``self.base_url`` are updated as well.

        Args:
            max_retries: Max number of QR codes to fetch before giving up.
            poll_timeout_ms: Timeout per QR code in milliseconds.
            on_qrcode: Callback(qr_image_base64, qr_url) called each time a
                new QR code is fetched. Use this to display the QR code.
            on_status: Callback(status_str) called after every successful
                status poll (``'unknown'`` when the status is missing).
                Both callbacks may be sync or async.

        Returns:
            LoginResult with token, base_url, and account_id.

        Raises:
            ApiError: If the QR code cannot be fetched, or if every QR code
                expired or timed out before being confirmed.
        """
        last_qr_base64: Optional[str] = None

        for attempt in range(max_retries):
            qr_resp = await self.fetch_qrcode()
            if not qr_resp.qrcode or not qr_resp.qrcode_img_content:
                raise ApiError('Failed to get QR code from server', status=0)

            # Convert QR image to base64 and notify caller
            last_qr_base64 = await self._fetch_qr_image_base64(qr_resp.qrcode_img_content)
            if on_qrcode:
                await self._invoke_callback('on_qrcode', on_qrcode, last_qr_base64, qr_resp.qrcode_img_content)

            # Poll until confirmed / expired / timeout
            loop = asyncio.get_running_loop()
            deadline = loop.time() + poll_timeout_ms / 1000.0

            while loop.time() < deadline:
                try:
                    status_resp = await self.poll_qrcode_status(qr_resp.qrcode)
                except Exception as e:
                    # Polling errors are retried until the deadline passes.
                    logger.error('Error polling QR status: %s', e)
                    await asyncio.sleep(2)
                    continue

                if on_status:
                    await self._invoke_callback('on_status', on_status, status_resp.status or 'unknown')

                if status_resp.status == 'confirmed' and status_resp.bot_token:
                    new_base_url = status_resp.baseurl or self.base_url
                    # Update this client instance as well
                    self.token = status_resp.bot_token
                    self.base_url = new_base_url.rstrip('/')
                    return LoginResult(
                        token=status_resp.bot_token,
                        base_url=new_base_url,
                        account_id=status_resp.ilink_bot_id or '',
                        qr_image_base64=last_qr_base64,
                    )

                if status_resp.status == 'expired':
                    break  # retry with a new QR code

                await asyncio.sleep(1)

            # Reached on explicit expiry (break) and on poll timeout alike;
            # both are handled by fetching a new QR code if retries remain.
            remaining = max_retries - attempt - 1
            if remaining > 0:
                logger.info('QR code expired, refreshing... (%d retries left)', remaining)
            else:
                raise ApiError('QR code login failed: max retries exceeded', status=0)

        # Only reachable when max_retries <= 0 (the loop body never ran).
        raise ApiError('QR code login failed', status=0)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Parsing helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_cdn_media(data: Optional[dict]) -> Optional[CDNMedia]:
    """Convert a raw CDN media dict into a ``CDNMedia``.

    Returns ``None`` when ``data`` is ``None`` or empty. Keys missing from
    ``data`` are passed through as ``None``.
    """
    if data:
        field_names = ('encrypt_query_param', 'aes_key', 'encrypt_type')
        values = {name: data.get(name) for name in field_names}
        return CDNMedia(**values)
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_message_item(data: dict) -> MessageItem:
    """Convert one raw item dict into a ``MessageItem``.

    Scalar fields are copied with ``dict.get``, so missing keys become
    ``None``. Each of ``text_item``, ``image_item``, ``voice_item``,
    ``file_item``, ``video_item`` and ``ref_msg`` is parsed only when its key
    holds a truthy value. A message quoted inside ``ref_msg`` is parsed
    recursively.

    Args:
        data: Raw item mapping.

    Returns:
        The populated ``MessageItem``.
    """
    scalar_names = ('type', 'create_time_ms', 'update_time_ms', 'is_completed', 'msg_id')
    parsed = MessageItem(**{name: data.get(name) for name in scalar_names})

    text_raw = data.get('text_item')
    if text_raw:
        parsed.text_item = TextItem(text=text_raw.get('text'))

    image_raw = data.get('image_item')
    if image_raw:
        parsed.image_item = ImageItem(
            media=_parse_cdn_media(image_raw.get('media')),
            thumb_media=_parse_cdn_media(image_raw.get('thumb_media')),
            aeskey=image_raw.get('aeskey'),
            url=image_raw.get('url'),
            mid_size=image_raw.get('mid_size'),
        )

    voice_raw = data.get('voice_item')
    if voice_raw:
        parsed.voice_item = VoiceItem(
            media=_parse_cdn_media(voice_raw.get('media')),
            encode_type=voice_raw.get('encode_type'),
            playtime=voice_raw.get('playtime'),
            text=voice_raw.get('text'),
        )

    file_raw = data.get('file_item')
    if file_raw:
        parsed.file_item = FileItem(
            media=_parse_cdn_media(file_raw.get('media')),
            file_name=file_raw.get('file_name'),
            md5=file_raw.get('md5'),
            len=file_raw.get('len'),
        )

    video_raw = data.get('video_item')
    if video_raw:
        parsed.video_item = VideoItem(
            media=_parse_cdn_media(video_raw.get('media')),
            video_size=video_raw.get('video_size'),
            play_length=video_raw.get('play_length'),
            video_md5=video_raw.get('video_md5'),
            thumb_media=_parse_cdn_media(video_raw.get('thumb_media')),
        )

    ref_raw = data.get('ref_msg')
    if ref_raw:
        quoted_raw = ref_raw.get('message_item')
        parsed.ref_msg = RefMessage(
            title=ref_raw.get('title'),
            # The quoted item has the same shape as a regular item.
            message_item=_parse_message_item(quoted_raw) if quoted_raw else None,
        )

    return parsed
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_weixin_message(data: dict) -> WeixinMessage:
    """Convert one raw message dict into a ``WeixinMessage``.

    Args:
        data: Raw message mapping.

    Returns:
        A ``WeixinMessage`` whose scalar fields are read with ``dict.get``
        (missing keys become ``None``). ``item_list`` is set only when the raw
        ``item_list`` is truthy, with each entry run through
        ``_parse_message_item``.
    """
    scalar_names = (
        'seq',
        'message_id',
        'from_user_id',
        'to_user_id',
        'client_id',
        'create_time_ms',
        'session_id',
        'group_id',
        'message_type',
        'message_state',
        'context_token',
    )
    parsed = WeixinMessage(**{name: data.get(name) for name in scalar_names})

    raw_items = data.get('item_list')
    if raw_items:
        parsed.item_list = [_parse_message_item(raw) for raw in raw_items]
    return parsed
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_get_updates_response(data: dict) -> GetUpdatesResponse:
    """Convert the raw getUpdates response dict into a ``GetUpdatesResponse``.

    Args:
        data: Raw response mapping.

    Returns:
        A ``GetUpdatesResponse`` whose scalar fields are read with
        ``dict.get``. ``msgs`` is replaced with the parsed messages only when
        the raw ``msgs`` value is truthy; otherwise the dataclass default is
        kept.
    """
    scalar_names = ('ret', 'errcode', 'errmsg', 'get_updates_buf', 'longpolling_timeout_ms')
    parsed = GetUpdatesResponse(**{name: data.get(name) for name in scalar_names})

    raw_msgs = data.get('msgs')
    if raw_msgs:
        parsed.msgs = [_parse_weixin_message(raw) for raw in raw_msgs]
    return parsed
|
||||||
|
|
||||||
|
|
||||||
|
def _cdn_media_to_dict(media: Optional[CDNMedia]) -> Optional[dict]:
|
||||||
|
if not media:
|
||||||
|
return None
|
||||||
|
d: dict = {}
|
||||||
|
if media.encrypt_query_param is not None:
|
||||||
|
d['encrypt_query_param'] = media.encrypt_query_param
|
||||||
|
if media.aes_key is not None:
|
||||||
|
d['aes_key'] = media.aes_key
|
||||||
|
if media.encrypt_type is not None:
|
||||||
|
d['encrypt_type'] = media.encrypt_type
|
||||||
|
return d or None
|
||||||
|
|
||||||
|
|
||||||
|
def _message_item_to_dict(item: MessageItem) -> dict:
|
||||||
|
d: dict = {'type': item.type}
|
||||||
|
|
||||||
|
if item.text_item:
|
||||||
|
d['text_item'] = {'text': item.text_item.text}
|
||||||
|
|
||||||
|
if item.image_item:
|
||||||
|
img_d: dict = {}
|
||||||
|
if item.image_item.media:
|
||||||
|
img_d['media'] = _cdn_media_to_dict(item.image_item.media)
|
||||||
|
if item.image_item.mid_size is not None:
|
||||||
|
img_d['mid_size'] = item.image_item.mid_size
|
||||||
|
d['image_item'] = img_d
|
||||||
|
|
||||||
|
if item.voice_item:
|
||||||
|
voice_d: dict = {}
|
||||||
|
if item.voice_item.media:
|
||||||
|
voice_d['media'] = _cdn_media_to_dict(item.voice_item.media)
|
||||||
|
if item.voice_item.playtime is not None:
|
||||||
|
voice_d['playtime'] = item.voice_item.playtime
|
||||||
|
d['voice_item'] = voice_d
|
||||||
|
|
||||||
|
if item.file_item:
|
||||||
|
file_d: dict = {}
|
||||||
|
if item.file_item.media:
|
||||||
|
file_d['media'] = _cdn_media_to_dict(item.file_item.media)
|
||||||
|
if item.file_item.file_name:
|
||||||
|
file_d['file_name'] = item.file_item.file_name
|
||||||
|
if item.file_item.len:
|
||||||
|
file_d['len'] = item.file_item.len
|
||||||
|
d['file_item'] = file_d
|
||||||
|
|
||||||
|
if item.video_item:
|
||||||
|
vid_d: dict = {}
|
||||||
|
if item.video_item.media:
|
||||||
|
vid_d['media'] = _cdn_media_to_dict(item.video_item.media)
|
||||||
|
if item.video_item.video_size is not None:
|
||||||
|
vid_d['video_size'] = item.video_item.video_size
|
||||||
|
d['video_item'] = vid_d
|
||||||
|
|
||||||
|
return d
|
||||||
200
src/langbot/libs/openclaw_weixin_api/types.py
Normal file
200
src/langbot/libs/openclaw_weixin_api/types.py
Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
"""Type definitions for the OpenClaw WeChat API, mirroring the upstream protocol."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
SESSION_EXPIRED_ERRCODE = -14
|
||||||
|
|
||||||
|
|
||||||
|
class ApiError(Exception):
    """Exception carrying structured details of an OpenClaw WeChat API failure.

    The message is passed to ``Exception``. The keyword-only arguments are
    stored unchanged as the attributes ``status`` (default 0), ``code``
    (default ``None``) and ``payload`` (default ``None``).
    """

    def __init__(
        self,
        message: str,
        *,
        status: int = 0,
        code: int | None = None,
        payload: Any = None,
    ):
        super().__init__(message)
        self.status, self.code, self.payload = status, code, payload

    @property
    def is_session_expired(self) -> bool:
        """Whether ``code`` equals ``SESSION_EXPIRED_ERRCODE``."""
        return self.code == SESSION_EXPIRED_ERRCODE
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class CDNMedia:
    """Media reference made of a query parameter, an AES key and an encrypt type.

    Every field is optional and defaults to ``None``.
    """

    encrypt_query_param: Optional[str] = None
    aes_key: Optional[str] = None
    encrypt_type: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TextItem:
    """Text content of a message item; ``text`` defaults to ``None``."""

    text: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ImageItem:
    """Image content of a message item. Every field defaults to ``None``."""

    # Media references for the image and its thumbnail.
    media: Optional[CDNMedia] = None
    thumb_media: Optional[CDNMedia] = None
    aeskey: Optional[str] = None
    url: Optional[str] = None
    # Size and dimension metadata.
    mid_size: Optional[int] = None
    thumb_size: Optional[int] = None
    thumb_height: Optional[int] = None
    thumb_width: Optional[int] = None
    hd_size: Optional[int] = None
    # Raw bytes kept out of repr(); presumably filled in after a download (confirm with callers).
    _downloaded_bytes: Optional[bytes] = field(default=None, repr=False)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class VoiceItem:
    """Voice content of a message item. Every field defaults to ``None``."""

    media: Optional[CDNMedia] = None
    # Audio metadata.
    encode_type: Optional[int] = None
    bits_per_sample: Optional[int] = None
    sample_rate: Optional[int] = None
    playtime: Optional[int] = None
    text: Optional[str] = None
    # Raw bytes kept out of repr(); presumably filled in after a download (confirm with callers).
    _downloaded_bytes: Optional[bytes] = field(default=None, repr=False)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class FileItem:
    """File content of a message item. Every field defaults to ``None``."""

    media: Optional[CDNMedia] = None
    file_name: Optional[str] = None
    md5: Optional[str] = None
    # Declared as a string here, not an int.
    len: Optional[str] = None
    # Raw bytes kept out of repr(); presumably filled in after a download (confirm with callers).
    _downloaded_bytes: Optional[bytes] = field(default=None, repr=False)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class VideoItem:
    """Video content of a message item. Every field defaults to ``None``."""

    media: Optional[CDNMedia] = None
    video_size: Optional[int] = None
    play_length: Optional[int] = None
    video_md5: Optional[str] = None
    # Thumbnail reference and its metadata.
    thumb_media: Optional[CDNMedia] = None
    thumb_size: Optional[int] = None
    thumb_height: Optional[int] = None
    thumb_width: Optional[int] = None
    # Raw bytes kept out of repr(); presumably filled in after a download (confirm with callers).
    _downloaded_bytes: Optional[bytes] = field(default=None, repr=False)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RefMessage:
    """Referenced message: an optional ``MessageItem`` plus an optional title."""

    message_item: Optional[MessageItem] = None
    title: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class MessageItem:
    """A single content item inside a WeixinMessage.

    The un-annotated names below are plain class constants, not dataclass
    fields. Every field defaults to ``None``.
    """

    # Item types
    NONE = 0
    TEXT = 1
    IMAGE = 2
    VOICE = 3
    FILE = 4
    VIDEO = 5

    type: Optional[int] = None
    create_time_ms: Optional[int] = None
    update_time_ms: Optional[int] = None
    is_completed: Optional[bool] = None
    msg_id: Optional[str] = None
    ref_msg: Optional[RefMessage] = None
    # One optional payload slot per content kind.
    text_item: Optional[TextItem] = None
    image_item: Optional[ImageItem] = None
    voice_item: Optional[VoiceItem] = None
    file_item: Optional[FileItem] = None
    video_item: Optional[VideoItem] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class WeixinMessage:
    """Unified message from getUpdates or for sendMessage.

    The un-annotated names below are plain class constants, not dataclass
    fields. Every field defaults to ``None``.
    """

    # Message types
    TYPE_USER = 1
    TYPE_BOT = 2

    # Message states
    STATE_NEW = 0
    STATE_GENERATING = 1
    STATE_FINISH = 2

    seq: Optional[int] = None
    message_id: Optional[int] = None
    from_user_id: Optional[str] = None
    to_user_id: Optional[str] = None
    client_id: Optional[str] = None
    # Timestamps.
    create_time_ms: Optional[int] = None
    update_time_ms: Optional[int] = None
    delete_time_ms: Optional[int] = None
    session_id: Optional[str] = None
    group_id: Optional[str] = None
    # Compare against the TYPE_* / STATE_* constants above.
    message_type: Optional[int] = None
    message_state: Optional[int] = None
    item_list: Optional[list[MessageItem]] = None
    context_token: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class GetUpdatesResponse:
    """Parsed result of a getUpdates call.

    ``msgs`` defaults to a fresh empty list per instance; every other field
    defaults to ``None``.
    """

    ret: Optional[int] = None
    errcode: Optional[int] = None
    errmsg: Optional[str] = None
    msgs: list[WeixinMessage] = field(default_factory=list)
    get_updates_buf: Optional[str] = None
    longpolling_timeout_ms: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class GetConfigResponse:
    """Result fields of a config response; every field defaults to ``None``."""

    ret: Optional[int] = None
    errmsg: Optional[str] = None
    typing_ticket: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class GetUploadUrlResponse:
    """Upload parameters for a file and its thumbnail; both default to ``None``."""

    upload_param: Optional[str] = None
    thumb_upload_param: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class QRCodeResponse:
    """Response from get_bot_qrcode endpoint. Both fields default to ``None``."""

    qrcode: Optional[str] = None
    qrcode_img_content: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class QRStatusResponse:
    """Response from get_qrcode_status endpoint. Every field defaults to ``None``."""

    # One of "wait" | "scaned" | "confirmed" | "expired" (spelling as declared here).
    status: Optional[str] = None
    bot_token: Optional[str] = None
    ilink_bot_id: Optional[str] = None
    baseurl: Optional[str] = None
    ilink_user_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class LoginResult:
    """Result returned by the login flow.

    ``token``, ``base_url`` and ``account_id`` are required;
    ``qr_image_base64`` defaults to ``None``.
    """

    token: str
    base_url: str
    account_id: str
    # Data URI of the last QR code shown.
    qr_image_base64: Optional[str] = None
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
import requests
|
import requests
|
||||||
import aiohttp
|
from langbot.pkg.utils import httpclient
|
||||||
|
|
||||||
|
|
||||||
def post_json(base_url, token, data=None):
|
def post_json(base_url, token, data=None):
|
||||||
@@ -63,16 +63,16 @@ async def async_request(
|
|||||||
"""
|
"""
|
||||||
headers = {'Content-Type': 'application/json'}
|
headers = {'Content-Type': 'application/json'}
|
||||||
url = f'{base_url}?key={token_key}'
|
url = f'{base_url}?key={token_key}'
|
||||||
async with aiohttp.ClientSession() as session:
|
session = httpclient.get_session()
|
||||||
async with session.request(
|
async with session.request(
|
||||||
method=method, url=url, params=params, headers=headers, data=data, json=json
|
method=method, url=url, params=params, headers=headers, data=data, json=json
|
||||||
) as response:
|
) as response:
|
||||||
response.raise_for_status() # 如果状态码不是200,抛出异常
|
response.raise_for_status() # 如果状态码不是200,抛出异常
|
||||||
result = await response.json()
|
result = await response.json()
|
||||||
# print(result)
|
# print(result)
|
||||||
return result
|
return result
|
||||||
# if result.get('Code') == 200:
|
# if result.get('Code') == 200:
|
||||||
#
|
#
|
||||||
# return await result
|
# return await result
|
||||||
# else:
|
# else:
|
||||||
# raise RuntimeError("请求失败",response.text)
|
# raise RuntimeError("请求失败",response.text)
|
||||||
|
|||||||
@@ -199,6 +199,253 @@ class StreamSessionManager:
|
|||||||
self._msg_index.pop(msg_id, None)
|
self._msg_index.pop(msg_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
async def download_encrypted_file(download_url: str, encoding_aes_key: str, logger: EventLogger) -> Optional[str]:
    """Download an AES-encrypted file from WeChat Work and return it as a data URI.

    The body is decrypted with AES-CBC using the decoded key, whose first 16
    bytes double as the IV. The trailing padding is then validated and
    stripped, and the MIME type is guessed from the leading magic bytes.

    Args:
        download_url: The encrypted file download URL.
        encoding_aes_key: The AES key used for decryption (base64-encoded, without trailing '=').
        logger: Logger instance.

    Returns:
        A data URI string (e.g. 'data:image/jpeg;base64,...'), or None when
        the URL is empty, the HTTP status is not 200, the body is empty or not
        block-aligned, or the decrypted padding is invalid.
    """
    if not download_url:
        return None

    async with httpx.AsyncClient() as client:
        response = await client.get(download_url)
        if response.status_code != 200:
            await logger.error(f'failed to get file: {response.text}')
            return None
        encrypted_bytes = response.content

    # CBC only operates on whole 16-byte blocks. Rejecting anything else here
    # avoids an exception from the cipher and an IndexError on an empty body.
    cbc_block_size = 16
    if not encrypted_bytes or len(encrypted_bytes) % cbc_block_size:
        await logger.error(f'failed to decrypt file: unexpected payload length {len(encrypted_bytes)}')
        return None

    aes_key = base64.b64decode(encoding_aes_key + '=')
    iv = aes_key[:16]

    cipher = AES.new(aes_key, AES.MODE_CBC, iv)
    decrypted = cipher.decrypt(encrypted_bytes)

    # The last byte is the padding length. WeCom documents PKCS#7 padding to a
    # 32-byte boundary, so valid values are 1..32. A value of 0 would make
    # `decrypted[:-0]` empty, and an oversized value would silently truncate
    # real data, so both are treated as a failed decryption.
    max_pad_len = 32
    pad_len = decrypted[-1]
    if not 1 <= pad_len <= min(max_pad_len, len(decrypted)):
        await logger.error(f'failed to decrypt file: invalid padding length {pad_len}')
        return None
    decrypted = decrypted[:-pad_len]

    # Magic-byte prefixes for the image formats recognized inline.
    signatures = (
        ((b'\xff\xd8',), 'image/jpeg'),
        ((b'\x89PNG',), 'image/png'),
        ((b'GIF87a', b'GIF89a'), 'image/gif'),
        ((b'BM',), 'image/bmp'),
        ((b'II*\x00', b'MM\x00*'), 'image/tiff'),
    )
    mime_type = 'application/octet-stream'
    for prefixes, candidate in signatures:
        if decrypted.startswith(prefixes):
            mime_type = candidate
            break

    base64_str = base64.b64encode(decrypted).decode('utf-8')
    return f'data:{mime_type};base64,{base64_str}'
|
||||||
|
|
||||||
|
|
||||||
|
async def parse_wecom_bot_message(
    msg_json: dict[str, Any], encoding_aes_key: str, logger: EventLogger
) -> dict[str, Any]:
    """Parse a decrypted WeChat Work AI Bot message JSON into a unified message dict.

    This is the shared message parsing logic used by both webhook and WebSocket modes.

    Sub-objects of the message (``text``, ``image``, ``voice``, ``from``, ...)
    are read null-safely: a missing key and an explicit JSON ``null`` are both
    treated as an empty object. Voice, video and file payloads are downloaded
    and inlined as ``base64`` only when their reported size does not exceed
    the inline limit (5 MiB); an unreported size counts as 0.

    Args:
        msg_json: The decrypted message JSON from WeChat Work.
        encoding_aes_key: AES key for file decryption.
        logger: Logger instance.

    Returns:
        A dict suitable for constructing a WecomBotEvent.
    """
    max_inline_file_size = 5 * 1024 * 1024

    def _section(container: dict, key: str) -> dict:
        # A missing key and an explicit null both become an empty object.
        return container.get(key) or {}

    async def _safe_download(url: str):
        if not url:
            return None
        return await download_encrypted_file(url, encoding_aes_key, logger)

    async def _inline_if_small(payload: dict, url: str) -> dict:
        # Attach the decrypted content only when the reported size is within the limit.
        if (payload.get('filesize') or 0) <= max_inline_file_size:
            encoded = await _safe_download(url)
            if encoded:
                payload['base64'] = encoded
        return payload

    async def _voice_payload(info: dict) -> dict:
        url = info.get('url')
        return await _inline_if_small(
            {
                'url': url,
                'md5sum': info.get('md5sum') or info.get('md5'),
                'filesize': info.get('filesize') or info.get('size'),
                'sdkfileid': info.get('sdkfileid') or info.get('fileid'),
            },
            url,
        )

    async def _video_payload(info: dict) -> dict:
        url = info.get('url')
        return await _inline_if_small(
            {
                'url': url,
                'filesize': info.get('filesize') or info.get('size'),
                'sdkfileid': info.get('sdkfileid') or info.get('fileid'),
                'md5sum': info.get('md5sum') or info.get('md5'),
                'filename': info.get('filename') or info.get('name'),
            },
            url,
        )

    async def _file_payload(info: dict) -> dict:
        url = info.get('url') or info.get('fileurl')
        return await _inline_if_small(
            {
                'filename': info.get('filename') or info.get('name'),
                'filesize': info.get('filesize') or info.get('size'),
                'md5sum': info.get('md5sum') or info.get('md5'),
                'sdkfileid': info.get('sdkfileid') or info.get('fileid'),
                'download_url': url,
                'extra': info,
            },
            url,
        )

    message_data: dict[str, Any] = {}

    msg_type = msg_json.get('msgtype', '')
    if msg_type:
        message_data['msgtype'] = msg_type

    chat_type = msg_json.get('chattype', '')
    if chat_type in ('single', 'group'):
        message_data['type'] = chat_type

    if msg_type == 'text':
        message_data['content'] = _section(msg_json, 'text').get('content')
    elif msg_type == 'markdown':
        message_data['content'] = _section(msg_json, 'markdown').get('content') or _section(msg_json, 'text').get(
            'content', ''
        )
    elif msg_type == 'image':
        picurl = _section(msg_json, 'image').get('url', '')
        base64_data = await _safe_download(picurl)
        if base64_data:
            message_data['picurl'] = base64_data
            message_data['images'] = [base64_data]
    elif msg_type == 'voice':
        voice_info = _section(msg_json, 'voice')
        message_data['voice'] = await _voice_payload(voice_info)
        # Reuse the transcript when the platform already supplies one.
        if voice_info.get('content'):
            message_data['content'] = voice_info.get('content')
    elif msg_type == 'video':
        message_data['video'] = await _video_payload(_section(msg_json, 'video'))
    elif msg_type == 'file':
        message_data['file'] = await _file_payload(_section(msg_json, 'file'))
    elif msg_type == 'link':
        link = _section(msg_json, 'link')
        message_data['link'] = link
        if not message_data.get('content'):
            title = link.get('title', '')
            desc = link.get('description') or link.get('digest', '')
            message_data['content'] = '\n'.join(filter(None, [title, desc]))
    elif msg_type == 'mixed':
        items = _section(msg_json, 'mixed').get('msg_item') or []
        texts = []
        images = []
        files = []
        voices = []
        videos = []
        links = []
        for item in items:
            item_type = item.get('msgtype')
            if item_type == 'text':
                # A null content would break the join below.
                texts.append(_section(item, 'text').get('content') or '')
            elif item_type == 'image':
                base64_data = await _safe_download(_section(item, 'image').get('url'))
                if base64_data:
                    images.append(base64_data)
            elif item_type == 'file':
                files.append(await _file_payload(_section(item, 'file')))
            elif item_type == 'voice':
                voice_info = _section(item, 'voice')
                if voice_info.get('content'):
                    texts.append(voice_info.get('content'))
                voices.append(await _voice_payload(voice_info))
            elif item_type == 'video':
                videos.append(await _video_payload(_section(item, 'video')))
            elif item_type == 'link':
                links.append(_section(item, 'link'))

        # The singular keys mirror the first element of each list.
        if texts:
            message_data['content'] = ' '.join(texts)
        if images:
            message_data['images'] = images
            message_data['picurl'] = images[0]
        if files:
            message_data['files'] = files
            message_data['file'] = files[0]
        if voices:
            message_data['voices'] = voices
            message_data['voice'] = voices[0]
        if videos:
            message_data['videos'] = videos
            message_data['video'] = videos[0]
        if links:
            message_data['link'] = links[0]
        if items:
            message_data['attachments'] = items
    else:
        message_data['raw_msg'] = msg_json

    # Sender: prefer alias, then name, then the raw user id.
    from_info = _section(msg_json, 'from')
    message_data['userid'] = from_info.get('userid', '')
    message_data['username'] = from_info.get('alias', '') or from_info.get('name', '') or from_info.get('userid', '')

    if chat_type == 'group':
        message_data['chatid'] = msg_json.get('chatid', '')
        # Fall back to the chat id when no chat name is supplied.
        message_data['chatname'] = msg_json.get('chatname', '') or msg_json.get('chatid', '')

    message_data['msgid'] = msg_json.get('msgid', '')

    if msg_json.get('aibotid'):
        message_data['aibotid'] = msg_json.get('aibotid', '')

    return message_data
|
||||||
|
|
||||||
|
|
||||||
class WecomBotClient:
|
class WecomBotClient:
|
||||||
def __init__(self, Token: str, EnCodingAESKey: str, Corpid: str, logger: EventLogger, unified_mode: bool = False):
|
def __init__(self, Token: str, EnCodingAESKey: str, Corpid: str, logger: EventLogger, unified_mode: bool = False):
|
||||||
"""企业微信智能机器人客户端。
|
"""企业微信智能机器人客户端。
|
||||||
@@ -455,196 +702,7 @@ class WecomBotClient:
|
|||||||
return await self._handle_post_initial_response(msg_json, nonce)
|
return await self._handle_post_initial_response(msg_json, nonce)
|
||||||
|
|
||||||
async def get_message(self, msg_json):
|
async def get_message(self, msg_json):
|
||||||
message_data = {}
|
return await parse_wecom_bot_message(msg_json, self.EnCodingAESKey, self.logger)
|
||||||
|
|
||||||
msg_type = msg_json.get('msgtype', '')
|
|
||||||
if msg_type:
|
|
||||||
message_data['msgtype'] = msg_type
|
|
||||||
|
|
||||||
if msg_json.get('chattype', '') == 'single':
|
|
||||||
message_data['type'] = 'single'
|
|
||||||
elif msg_json.get('chattype', '') == 'group':
|
|
||||||
message_data['type'] = 'group'
|
|
||||||
|
|
||||||
max_inline_file_size = 5 * 1024 * 1024 # avoid decoding very large payloads by default
|
|
||||||
|
|
||||||
async def _safe_download(url: str):
|
|
||||||
if not url:
|
|
||||||
return None
|
|
||||||
return await self.download_url_to_base64(url, self.EnCodingAESKey)
|
|
||||||
|
|
||||||
if msg_type == 'text':
|
|
||||||
message_data['content'] = msg_json.get('text', {}).get('content')
|
|
||||||
elif msg_type == 'markdown':
|
|
||||||
message_data['content'] = msg_json.get('markdown', {}).get('content') or msg_json.get('text', {}).get(
|
|
||||||
'content', ''
|
|
||||||
)
|
|
||||||
elif msg_type == 'image':
|
|
||||||
picurl = msg_json.get('image', {}).get('url', '')
|
|
||||||
base64_data = await _safe_download(picurl)
|
|
||||||
if base64_data:
|
|
||||||
message_data['picurl'] = base64_data
|
|
||||||
message_data['images'] = [base64_data]
|
|
||||||
elif msg_type == 'voice':
|
|
||||||
voice_info = msg_json.get('voice', {}) or {}
|
|
||||||
download_url = voice_info.get('url')
|
|
||||||
message_data['voice'] = {
|
|
||||||
'url': download_url,
|
|
||||||
'md5sum': voice_info.get('md5sum') or voice_info.get('md5'),
|
|
||||||
'filesize': voice_info.get('filesize') or voice_info.get('size'),
|
|
||||||
'sdkfileid': voice_info.get('sdkfileid') or voice_info.get('fileid'),
|
|
||||||
}
|
|
||||||
# 企业微信智能转写文本(如果已有)直接复用,避免重复转写
|
|
||||||
if voice_info.get('content'):
|
|
||||||
message_data['content'] = voice_info.get('content')
|
|
||||||
if (message_data['voice'].get('filesize') or 0) <= max_inline_file_size:
|
|
||||||
voice_base64 = await _safe_download(download_url)
|
|
||||||
if voice_base64:
|
|
||||||
message_data['voice']['base64'] = voice_base64
|
|
||||||
elif msg_type == 'video':
|
|
||||||
video_info = msg_json.get('video', {}) or {}
|
|
||||||
download_url = video_info.get('url')
|
|
||||||
video_data = {
|
|
||||||
'url': download_url,
|
|
||||||
'filesize': video_info.get('filesize') or video_info.get('size'),
|
|
||||||
'sdkfileid': video_info.get('sdkfileid') or video_info.get('fileid'),
|
|
||||||
'md5sum': video_info.get('md5sum') or video_info.get('md5'),
|
|
||||||
'filename': video_info.get('filename') or video_info.get('name'),
|
|
||||||
}
|
|
||||||
if (video_data.get('filesize') or 0) <= max_inline_file_size:
|
|
||||||
video_base64 = await _safe_download(download_url)
|
|
||||||
if video_base64:
|
|
||||||
video_data['base64'] = video_base64
|
|
||||||
message_data['video'] = video_data
|
|
||||||
elif msg_type == 'file':
|
|
||||||
file_info = msg_json.get('file', {}) or {}
|
|
||||||
download_url = file_info.get('url') or file_info.get('fileurl')
|
|
||||||
file_data = {
|
|
||||||
'filename': file_info.get('filename') or file_info.get('name'),
|
|
||||||
'filesize': file_info.get('filesize') or file_info.get('size'),
|
|
||||||
'md5sum': file_info.get('md5sum') or file_info.get('md5'),
|
|
||||||
'sdkfileid': file_info.get('sdkfileid') or file_info.get('fileid'),
|
|
||||||
'download_url': download_url,
|
|
||||||
'extra': file_info,
|
|
||||||
}
|
|
||||||
if (file_data.get('filesize') or 0) <= max_inline_file_size:
|
|
||||||
file_base64 = await _safe_download(download_url)
|
|
||||||
if file_base64:
|
|
||||||
file_data['base64'] = file_base64
|
|
||||||
message_data['file'] = file_data
|
|
||||||
elif msg_type == 'link':
|
|
||||||
message_data['link'] = msg_json.get('link', {})
|
|
||||||
if not message_data.get('content'):
|
|
||||||
title = message_data['link'].get('title', '')
|
|
||||||
desc = message_data['link'].get('description') or message_data['link'].get('digest', '')
|
|
||||||
message_data['content'] = '\n'.join(filter(None, [title, desc]))
|
|
||||||
elif msg_type == 'mixed':
|
|
||||||
items = msg_json.get('mixed', {}).get('msg_item', [])
|
|
||||||
texts = []
|
|
||||||
images = []
|
|
||||||
files = []
|
|
||||||
voices = []
|
|
||||||
videos = []
|
|
||||||
links = []
|
|
||||||
for item in items:
|
|
||||||
item_type = item.get('msgtype')
|
|
||||||
if item_type == 'text':
|
|
||||||
texts.append(item.get('text', {}).get('content', ''))
|
|
||||||
elif item_type == 'image':
|
|
||||||
img_url = item.get('image', {}).get('url')
|
|
||||||
base64_data = await _safe_download(img_url)
|
|
||||||
if base64_data:
|
|
||||||
images.append(base64_data)
|
|
||||||
elif item_type == 'file':
|
|
||||||
file_info = item.get('file', {}) or {}
|
|
||||||
download_url = file_info.get('url') or file_info.get('fileurl')
|
|
||||||
file_data = {
|
|
||||||
'filename': file_info.get('filename') or file_info.get('name'),
|
|
||||||
'filesize': file_info.get('filesize') or file_info.get('size'),
|
|
||||||
'md5sum': file_info.get('md5sum') or file_info.get('md5'),
|
|
||||||
'sdkfileid': file_info.get('sdkfileid') or file_info.get('fileid'),
|
|
||||||
'download_url': download_url,
|
|
||||||
'extra': file_info,
|
|
||||||
}
|
|
||||||
if (file_data.get('filesize') or 0) <= max_inline_file_size:
|
|
||||||
file_base64 = await _safe_download(download_url)
|
|
||||||
if file_base64:
|
|
||||||
file_data['base64'] = file_base64
|
|
||||||
files.append(file_data)
|
|
||||||
elif item_type == 'voice':
|
|
||||||
voice_info = item.get('voice', {}) or {}
|
|
||||||
download_url = voice_info.get('url')
|
|
||||||
voice_data = {
|
|
||||||
'url': download_url,
|
|
||||||
'md5sum': voice_info.get('md5sum') or voice_info.get('md5'),
|
|
||||||
'filesize': voice_info.get('filesize') or voice_info.get('size'),
|
|
||||||
'sdkfileid': voice_info.get('sdkfileid') or voice_info.get('fileid'),
|
|
||||||
}
|
|
||||||
if voice_info.get('content'):
|
|
||||||
texts.append(voice_info.get('content'))
|
|
||||||
if (voice_data.get('filesize') or 0) <= max_inline_file_size:
|
|
||||||
voice_base64 = await _safe_download(download_url)
|
|
||||||
if voice_base64:
|
|
||||||
voice_data['base64'] = voice_base64
|
|
||||||
voices.append(voice_data)
|
|
||||||
elif item_type == 'video':
|
|
||||||
video_info = item.get('video', {}) or {}
|
|
||||||
download_url = video_info.get('url')
|
|
||||||
video_data = {
|
|
||||||
'url': download_url,
|
|
||||||
'filesize': video_info.get('filesize') or video_info.get('size'),
|
|
||||||
'sdkfileid': video_info.get('sdkfileid') or video_info.get('fileid'),
|
|
||||||
'md5sum': video_info.get('md5sum') or video_info.get('md5'),
|
|
||||||
'filename': video_info.get('filename') or video_info.get('name'),
|
|
||||||
}
|
|
||||||
if (video_data.get('filesize') or 0) <= max_inline_file_size:
|
|
||||||
video_base64 = await _safe_download(download_url)
|
|
||||||
if video_base64:
|
|
||||||
video_data['base64'] = video_base64
|
|
||||||
videos.append(video_data)
|
|
||||||
elif item_type == 'link':
|
|
||||||
links.append(item.get('link', {}))
|
|
||||||
|
|
||||||
if texts:
|
|
||||||
message_data['content'] = ' '.join(texts) # 拼接所有 text
|
|
||||||
if images:
|
|
||||||
message_data['images'] = images
|
|
||||||
message_data['picurl'] = images[0] # 只保留第一个 image
|
|
||||||
if files:
|
|
||||||
message_data['files'] = files
|
|
||||||
message_data['file'] = files[0]
|
|
||||||
if voices:
|
|
||||||
message_data['voices'] = voices
|
|
||||||
message_data['voice'] = voices[0]
|
|
||||||
if videos:
|
|
||||||
message_data['videos'] = videos
|
|
||||||
message_data['video'] = videos[0]
|
|
||||||
if links:
|
|
||||||
message_data['link'] = links[0]
|
|
||||||
if items:
|
|
||||||
message_data['attachments'] = items
|
|
||||||
else:
|
|
||||||
message_data['raw_msg'] = msg_json
|
|
||||||
|
|
||||||
# Extract user information
|
|
||||||
from_info = msg_json.get('from', {})
|
|
||||||
message_data['userid'] = from_info.get('userid', '')
|
|
||||||
message_data['username'] = (
|
|
||||||
from_info.get('alias', '') or from_info.get('name', '') or from_info.get('userid', '')
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract chat/group information
|
|
||||||
if msg_json.get('chattype', '') == 'group':
|
|
||||||
message_data['chatid'] = msg_json.get('chatid', '')
|
|
||||||
# Try to get group name if available
|
|
||||||
message_data['chatname'] = msg_json.get('chatname', '') or msg_json.get('chatid', '')
|
|
||||||
|
|
||||||
message_data['msgid'] = msg_json.get('msgid', '')
|
|
||||||
|
|
||||||
if msg_json.get('aibotid'):
|
|
||||||
message_data['aibotid'] = msg_json.get('aibotid', '')
|
|
||||||
|
|
||||||
return message_data
|
|
||||||
|
|
||||||
async def _handle_message(self, event: wecombotevent.WecomBotEvent):
|
async def _handle_message(self, event: wecombotevent.WecomBotEvent):
|
||||||
"""
|
"""
|
||||||
@@ -712,39 +770,7 @@ class WecomBotClient:
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
async def download_url_to_base64(self, download_url, encoding_aes_key):
|
async def download_url_to_base64(self, download_url, encoding_aes_key):
|
||||||
async with httpx.AsyncClient() as client:
|
return await download_encrypted_file(download_url, encoding_aes_key, self.logger)
|
||||||
response = await client.get(download_url)
|
|
||||||
if response.status_code != 200:
|
|
||||||
await self.logger.error(f'failed to get file: {response.text}')
|
|
||||||
return None
|
|
||||||
|
|
||||||
encrypted_bytes = response.content
|
|
||||||
|
|
||||||
aes_key = base64.b64decode(encoding_aes_key + '=') # base64 补齐
|
|
||||||
iv = aes_key[:16]
|
|
||||||
|
|
||||||
cipher = AES.new(aes_key, AES.MODE_CBC, iv)
|
|
||||||
decrypted = cipher.decrypt(encrypted_bytes)
|
|
||||||
|
|
||||||
pad_len = decrypted[-1]
|
|
||||||
decrypted = decrypted[:-pad_len]
|
|
||||||
|
|
||||||
if decrypted.startswith(b'\xff\xd8'): # JPEG
|
|
||||||
mime_type = 'image/jpeg'
|
|
||||||
elif decrypted.startswith(b'\x89PNG'): # PNG
|
|
||||||
mime_type = 'image/png'
|
|
||||||
elif decrypted.startswith((b'GIF87a', b'GIF89a')): # GIF
|
|
||||||
mime_type = 'image/gif'
|
|
||||||
elif decrypted.startswith(b'BM'): # BMP
|
|
||||||
mime_type = 'image/bmp'
|
|
||||||
elif decrypted.startswith(b'II*\x00') or decrypted.startswith(b'MM\x00*'): # TIFF
|
|
||||||
mime_type = 'image/tiff'
|
|
||||||
else:
|
|
||||||
mime_type = 'application/octet-stream'
|
|
||||||
|
|
||||||
# 转 base64
|
|
||||||
base64_str = base64.b64encode(decrypted).decode('utf-8')
|
|
||||||
return f'data:{mime_type};base64,{base64_str}'
|
|
||||||
|
|
||||||
async def run_task(self, host: str, port: int, *args, **kwargs):
|
async def run_task(self, host: str, port: int, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
|||||||
596
src/langbot/libs/wecom_ai_bot_api/ws_client.py
Normal file
596
src/langbot/libs/wecom_ai_bot_api/ws_client.py
Normal file
@@ -0,0 +1,596 @@
|
|||||||
|
"""WeChat Work AI Bot WebSocket long connection client.
|
||||||
|
|
||||||
|
Implements the WebSocket protocol for receiving messages and sending replies
|
||||||
|
via a persistent connection to wss://openws.work.weixin.qq.com, as an
|
||||||
|
alternative to the HTTP callback (webhook) mode.
|
||||||
|
|
||||||
|
Protocol reference: https://developer.work.weixin.qq.com/document/path/101463
|
||||||
|
Official Node.js SDK: https://github.com/WecomTeam/aibot-node-sdk
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import secrets
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from langbot.libs.wecom_ai_bot_api import wecombotevent
|
||||||
|
from langbot.libs.wecom_ai_bot_api.api import parse_wecom_bot_message
|
||||||
|
from langbot.pkg.platform.logger import EventLogger
|
||||||
|
|
||||||
|
DEFAULT_WS_URL = 'wss://openws.work.weixin.qq.com'
|
||||||
|
|
||||||
|
# WebSocket frame command constants
|
||||||
|
CMD_SUBSCRIBE = 'aibot_subscribe'
|
||||||
|
CMD_HEARTBEAT = 'ping'
|
||||||
|
CMD_MSG_CALLBACK = 'aibot_msg_callback'
|
||||||
|
CMD_EVENT_CALLBACK = 'aibot_event_callback'
|
||||||
|
CMD_RESPOND_MSG = 'aibot_respond_msg'
|
||||||
|
CMD_RESPOND_WELCOME = 'aibot_respond_welcome_msg'
|
||||||
|
CMD_RESPOND_UPDATE = 'aibot_respond_update_msg'
|
||||||
|
CMD_SEND_MSG = 'aibot_send_msg'
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_req_id(prefix: str) -> str:
|
||||||
|
"""Generate a unique request ID in the format: {prefix}_{timestamp}_{random}."""
|
||||||
|
ts = int(time.time() * 1000)
|
||||||
|
rand = secrets.token_hex(4)
|
||||||
|
return f'{prefix}_{ts}_{rand}'
|
||||||
|
|
||||||
|
|
||||||
|
class WecomBotWsClient:
|
||||||
|
"""WeChat Work AI Bot WebSocket long connection client.
|
||||||
|
|
||||||
|
Provides message receiving, streaming reply, proactive message sending,
|
||||||
|
and event callback handling over a persistent WebSocket connection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
    self,
    bot_id: str,
    secret: str,
    logger: EventLogger,
    encoding_aes_key: str = '',
    ws_url: str = DEFAULT_WS_URL,
    heartbeat_interval: float = 30.0,
    max_reconnect_attempts: int = -1,
    reconnect_base_delay: float = 1.0,
    reconnect_max_delay: float = 30.0,
):
    """Record the connection settings and create empty runtime state.

    No network activity happens here; the socket is opened by ``connect()``.

    Args:
        bot_id: Bot identifier sent in the subscribe (auth) frame.
        secret: Bot secret sent in the subscribe (auth) frame.
        logger: Async event logger used for all diagnostics.
        encoding_aes_key: Key handed to the shared message parser.
        ws_url: WebSocket endpoint to connect to.
        heartbeat_interval: Seconds between heartbeat pings.
        max_reconnect_attempts: Reconnect limit; ``-1`` disables the limit.
        reconnect_base_delay: First reconnect delay in seconds (doubles per attempt).
        reconnect_max_delay: Upper bound for the reconnect delay in seconds.
    """
    # Credentials and collaborators
    self.bot_id = bot_id
    self.secret = secret
    self.encoding_aes_key = encoding_aes_key
    self.logger = logger

    # Connection tuning
    self.ws_url = ws_url
    self.heartbeat_interval = heartbeat_interval
    self.max_reconnect_attempts = max_reconnect_attempts
    self.reconnect_base_delay = reconnect_base_delay
    self.reconnect_max_delay = reconnect_max_delay

    # Live connection state
    self._running = False
    self._session: Optional[aiohttp.ClientSession] = None
    self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
    self._heartbeat_task: Optional[asyncio.Task] = None
    self._reconnect_attempts = 0
    # Heartbeats sent without an ack; the connection is dropped at the max.
    self._missed_pong_count = 0
    self._max_missed_pong = 2

    # Handlers registered through on_message(), keyed by message/event type
    self._message_handlers: dict[str, list[Callable]] = {}
    # msgid -> number of times it was seen (used for deduplication)
    self._msg_id_map: dict[str, int] = {}

    # req_id -> future resolved with the matching ACK frame
    self._pending_acks: dict[str, asyncio.Future] = {}
    # One serial queue (and worker task) per req_id keeps replies ordered
    self._reply_queues: dict[str, asyncio.Queue] = {}
    self._reply_workers: dict[str, asyncio.Task] = {}
    self._reply_ack_timeout = 5.0

    # msg_id -> 'req_id|stream_id' for the active streaming reply
    self._stream_ids: dict[str, str] = {}
    # msg_id -> last content sent, so unchanged chunks can be skipped
    self._stream_last_content: dict[str, str] = {}
|
||||||
|
|
||||||
|
# ── Public API ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def connect(self):
    """Connect to the WebSocket server and keep reconnecting on failure.

    Blocks until ``disconnect()`` is called or, when
    ``max_reconnect_attempts`` is not ``-1``, until that many reconnect
    attempts have been used up. Between attempts the delay doubles from
    ``reconnect_base_delay`` up to ``reconnect_max_delay``.
    """
    # 2 ** 32 times any sensible base delay already dwarfs every realistic
    # maximum delay. Capping the exponent keeps the power small: an
    # unbounded ``2 ** n`` eventually cannot be converted to float and
    # raises OverflowError, which would kill the reconnect loop.
    max_backoff_exponent = 32

    self._running = True
    self._reconnect_attempts = 0

    while self._running:
        try:
            await self._connect_once()
        except Exception:
            if not self._running:
                break
            await self.logger.error(f'WebSocket connection error: {traceback.format_exc()}')

        if not self._running:
            break

        # Give up once the configured attempt budget is exhausted.
        if self.max_reconnect_attempts != -1 and self._reconnect_attempts >= self.max_reconnect_attempts:
            await self.logger.error(f'Max reconnect attempts reached ({self.max_reconnect_attempts}), giving up')
            break

        # Reconnect with (bounded) exponential backoff.
        self._reconnect_attempts += 1
        exponent = min(self._reconnect_attempts - 1, max_backoff_exponent)
        delay = min(
            self.reconnect_base_delay * (2**exponent),
            self.reconnect_max_delay,
        )
        await self.logger.info(f'Reconnecting in {delay:.1f}s (attempt {self._reconnect_attempts})...')
        await asyncio.sleep(delay)
|
||||||
|
|
||||||
|
async def disconnect(self):
    """Stop the client: cancel background tasks and close socket and session."""
    self._running = False

    heartbeat = self._heartbeat_task
    if heartbeat and not heartbeat.done():
        heartbeat.cancel()

    for worker in self._reply_workers.values():
        if worker.done():
            continue
        worker.cancel()

    ws = self._ws
    if ws and not ws.closed:
        await ws.close()
    self._ws = None

    session = self._session
    if session and not session.closed:
        await session.close()
    self._session = None
|
||||||
|
|
||||||
|
def on_message(self, msg_type: str) -> Callable:
    """Return a decorator that registers a handler under ``msg_type``.

    Handlers are appended to the list kept for that key and the decorated
    function is returned unchanged.

    Args:
        msg_type: 'single', 'group', or specific message type.
    """

    def decorator(func: Callable[[wecombotevent.WecomBotEvent], Any]):
        self._message_handlers.setdefault(msg_type, []).append(func)
        return func

    return decorator
|
||||||
|
|
||||||
|
async def reply_stream(
    self,
    req_id: str,
    stream_id: str,
    content: str,
    finish: bool = False,
) -> Optional[dict]:
    """Send one frame of a streaming reply.

    Args:
        req_id: The req_id of the frame being answered; passed through as-is.
        stream_id: Identifier of the streaming session this chunk belongs to.
        content: Text of the chunk.
        finish: True when this is the last chunk of the stream.

    Returns:
        Whatever ``_send_reply`` returns: the ACK frame dict, or None.
    """
    stream_part = {'id': stream_id, 'finish': finish, 'content': content}
    payload = {'msgtype': 'stream', 'stream': stream_part}
    return await self._send_reply(req_id, payload)
|
||||||
|
|
||||||
|
async def reply_text(self, req_id: str, content: str) -> Optional[dict]:
    """Send a single, non-streaming reply as a markdown message.

    Args:
        req_id: The req_id of the frame being answered.
        content: Reply text.

    Returns:
        Whatever ``_send_reply`` returns: the ACK frame dict, or None.
    """
    payload = {'msgtype': 'markdown', 'markdown': {'content': content}}
    return await self._send_reply(req_id, payload)
|
||||||
|
|
||||||
|
async def send_message(self, chat_id: str, content: str, msgtype: str = 'markdown') -> Optional[dict]:
    """Proactively send a message to a chat (not a reply to a callback).

    A fresh req_id is generated for the frame. Content is attached only for
    the 'markdown' and 'text' message types; for any other ``msgtype`` the
    body carries just the chat ID and the type.

    Args:
        chat_id: The chat ID (userid for single chat, chatid for group chat).
        content: The message content.
        msgtype: Message type, 'markdown' by default.

    Returns:
        Whatever ``_send_reply`` returns: the ACK frame dict, or None.
    """
    payload: dict[str, Any] = {'chatid': chat_id, 'msgtype': msgtype}
    if msgtype in ('markdown', 'text'):
        payload[msgtype] = {'content': content}
    request_id = _generate_req_id(CMD_SEND_MSG)
    return await self._send_reply(request_id, payload, cmd=CMD_SEND_MSG)
|
||||||
|
|
||||||
|
async def push_stream_chunk(self, msg_id: str, content: str, is_final: bool = False) -> bool:
    """Forward a streaming chunk to the stream session opened for ``msg_id``.

    Non-final chunks whose content equals the last content sent are skipped.
    After a final chunk the session bookkeeping for ``msg_id`` is removed.

    Args:
        msg_id: ID of the message being answered.
        content: Content to send for this chunk.
        is_final: Whether this chunk closes the stream.

    Returns:
        False when no stream session is registered for ``msg_id`` or when
        sending raised; True otherwise (including skipped duplicates).
    """
    session_key = self._stream_ids.get(msg_id)
    if not session_key:
        return False

    # The mapping value is stored as 'req_id|stream_id'.
    req_id, stream_id = session_key.split('|', 1)
    try:
        if not is_final:
            # Nothing new to show (e.g. during tool-call argument streaming).
            if content == self._stream_last_content.get(msg_id):
                return True

        await self.reply_stream(req_id, stream_id, content, finish=is_final)

        if is_final:
            self._stream_ids.pop(msg_id, None)
            self._stream_last_content.pop(msg_id, None)
        else:
            self._stream_last_content[msg_id] = content
        return True
    except Exception:
        await self.logger.error(f'Failed to push stream chunk: {traceback.format_exc()}')
        return False
|
||||||
|
|
||||||
|
async def set_message(self, msg_id: str, content: str):
    """Send ``content`` as the final stream chunk for ``msg_id``.

    Delegates to ``push_stream_chunk`` with ``is_final=True`` and logs a
    warning when that call reports the chunk was not handled.
    """
    if await self.push_stream_chunk(msg_id, content, is_final=True):
        return
    await self.logger.warning(f'No active stream for msg_id={msg_id}, message dropped')
|
||||||
|
|
||||||
|
# ── Connection lifecycle ────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _connect_once(self):
    """Run one connection cycle: connect, authenticate, then listen.

    Returns when authentication fails or the listen loop ends; exceptions
    from the socket propagate to the caller. The socket and the HTTP
    session are always closed before returning.
    """
    await self.logger.info(f'Connecting to {self.ws_url}...')

    self._session = aiohttp.ClientSession()
    try:
        self._ws = await self._session.ws_connect(self.ws_url)
        self._missed_pong_count = 0
        await self.logger.info('WebSocket connected, sending auth...')

        await self._send_auth()

        # Wait for auth response
        auth_ok = await self._wait_for_auth()
        if not auth_ok:
            await self.logger.error('Authentication failed')
            return

        await self.logger.info('Authenticated successfully')
        # Reset the backoff counter only once the session is fully usable.
        # Resetting it right after the socket opens would let repeated
        # authentication failures retry forever at the base delay and
        # never reach max_reconnect_attempts.
        self._reconnect_attempts = 0

        # Start heartbeat
        self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())

        try:
            await self._listen_loop()
        finally:
            if self._heartbeat_task and not self._heartbeat_task.done():
                self._heartbeat_task.cancel()
            # Wake up anyone still waiting for an ACK on this connection.
            self._clear_pending_acks('Connection closed')
    finally:
        if self._ws and not self._ws.closed:
            await self._ws.close()
        self._ws = None
        if self._session and not self._session.closed:
            await self._session.close()
        self._session = None
|
||||||
|
|
||||||
|
async def _send_auth(self):
    """Send the subscribe frame carrying the bot ID and secret."""
    credentials = {'bot_id': self.bot_id, 'secret': self.secret}
    headers = {'req_id': _generate_req_id(CMD_SUBSCRIBE)}
    await self._send_frame({'cmd': CMD_SUBSCRIBE, 'headers': headers, 'body': credentials})
|
||||||
|
|
||||||
|
async def _wait_for_auth(self) -> bool:
    """Read the next frame (up to 10 seconds) and check it is a successful auth reply.

    Returns:
        True only for a text frame whose req_id starts with the subscribe
        command name and whose ``errcode`` is 0; False in every other case
        (error reply, closed socket, unexpected frame type, timeout).
    """
    try:
        msg = await asyncio.wait_for(self._ws.receive(), timeout=10.0)
    except asyncio.TimeoutError:
        await self.logger.error('Auth response timeout')
        return False

    if msg.type == aiohttp.WSMsgType.TEXT:
        frame = json.loads(msg.data)
        req_id = frame.get('headers', {}).get('req_id', '')
        accepted = req_id.startswith(CMD_SUBSCRIBE) and frame.get('errcode') == 0
        if not accepted:
            await self.logger.error(f'Auth response: errcode={frame.get("errcode")}, errmsg={frame.get("errmsg")}')
            return False
        return True

    closed_types = (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING)
    if msg.type in closed_types:
        await self.logger.error(f'WebSocket closed during auth: {msg.type}')
    else:
        await self.logger.error(f'Unexpected message type during auth: {msg.type}')
    return False
|
||||||
|
|
||||||
|
async def _heartbeat_loop(self):
    """Send a heartbeat frame every ``heartbeat_interval`` seconds.

    The missed-ack counter is incremented before each ping (it is reset
    when an ack frame is handled). Once it reaches ``_max_missed_pong``
    the socket is closed and the loop ends. Cancellation ends the loop
    quietly.
    """

    def connection_alive() -> bool:
        return bool(self._running and self._ws and not self._ws.closed)

    try:
        while connection_alive():
            await asyncio.sleep(self.heartbeat_interval)
            # State may have changed while sleeping.
            if not connection_alive():
                break

            if self._missed_pong_count >= self._max_missed_pong:
                await self.logger.warning(
                    f'No heartbeat ack for {self._missed_pong_count} consecutive pings, connection considered dead'
                )
                await self._ws.close()
                break

            self._missed_pong_count += 1
            ping = {'cmd': CMD_HEARTBEAT, 'headers': {'req_id': _generate_req_id(CMD_HEARTBEAT)}}
            try:
                await self._send_frame(ping)
            except Exception:
                break
    except asyncio.CancelledError:
        pass
|
||||||
|
|
||||||
|
async def _listen_loop(self):
    """Read frames from the socket and hand each decoded frame to ``_handle_frame``.

    Text and binary messages are both decoded as JSON. Errors while
    decoding or handling a single frame are logged and do not stop the
    loop; an error/closed/closing message or a cleared ``_running`` flag does.
    """
    terminal_types = (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING)

    async for msg in self._ws:
        if not self._running:
            break

        kind = msg.type
        if kind in terminal_types:
            await self.logger.warning(f'WebSocket connection closed: {msg.type}')
            break

        if kind == aiohttp.WSMsgType.TEXT:
            try:
                await self._handle_frame(json.loads(msg.data))
            except json.JSONDecodeError:
                await self.logger.error(f'Failed to parse WebSocket message: {str(msg.data)[:200]}')
            except Exception:
                await self.logger.error(f'Error handling frame: {traceback.format_exc()}')
            continue

        if kind == aiohttp.WSMsgType.BINARY:
            try:
                await self._handle_frame(json.loads(msg.data))
            except Exception:
                await self.logger.error(f'Error handling binary frame: {traceback.format_exc()}')
|
||||||
|
|
||||||
|
# ── Frame handling ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _handle_frame(self, frame: dict):
    """Route an incoming frame to the appropriate handler.

    Message and event callbacks are processed in background tasks so the
    listen loop is never blocked by handler work. Frames without one of
    those commands are treated as responses and matched by req_id: first
    against pending ACK futures, then as heartbeat acks. Anything else is
    logged as unknown.
    """

    def spawn(coro) -> None:
        # The event loop keeps only weak references to tasks, so a task
        # whose result is discarded can be garbage-collected before it
        # finishes. Hold a strong reference until the task is done. The
        # set is created lazily so this method does not depend on any
        # other code initialising it.
        tasks = getattr(self, '_callback_tasks', None)
        if tasks is None:
            tasks = self._callback_tasks = set()
        task = asyncio.create_task(coro)
        tasks.add(task)
        task.add_done_callback(tasks.discard)

    cmd = frame.get('cmd', '')

    # Message push
    if cmd == CMD_MSG_CALLBACK:
        spawn(self._handle_message_callback(frame))
        return

    # Event push
    if cmd == CMD_EVENT_CALLBACK:
        spawn(self._handle_event_callback(frame))
        return

    # No cmd -> response/ACK frame, dispatch by req_id
    req_id = frame.get('headers', {}).get('req_id', '')

    # Check pending ACKs first
    if req_id in self._pending_acks:
        future = self._pending_acks.pop(req_id)
        if not future.done():
            future.set_result(frame)
        return

    # Heartbeat response: only a successful ack clears the missed counter.
    if req_id.startswith(CMD_HEARTBEAT):
        if frame.get('errcode') == 0:
            self._missed_pong_count = 0
        return

    # Unknown frame
    await self.logger.warning(f'Unknown frame: {json.dumps(frame, ensure_ascii=False)[:200]}')
|
||||||
|
|
||||||
|
async def _handle_message_callback(self, frame: dict):
    """Turn a message callback frame into an event and dispatch it.

    The frame body is parsed with the shared parser. A new stream ID is
    generated and, when the message has an ID, remembered as
    ``'req_id|stream_id'`` so later stream chunks can find it. Any error
    is logged instead of raised.
    """
    try:
        req_id = frame.get('headers', {}).get('req_id', '')
        payload = frame.get('body', {})

        # Parse message using shared logic
        message_data = await parse_wecom_bot_message(payload, self.encoding_aes_key, self.logger)
        if not message_data:
            return

        stream_id = _generate_req_id('stream')
        msg_id = message_data.get('msgid', '')
        if msg_id:
            self._stream_ids[msg_id] = f'{req_id}|{stream_id}'
        message_data.update({'stream_id': stream_id, 'req_id': req_id})

        await self._dispatch_event(wecombotevent.WecomBotEvent(message_data))
    except Exception:
        await self.logger.error(f'Error in message callback: {traceback.format_exc()}')
|
||||||
|
|
||||||
|
async def _handle_event_callback(self, frame: dict):
    """Turn an event callback frame into an event and run its handlers.

    Handlers registered under the concrete event type run first, then the
    handlers registered under the generic 'event' key. Any error is logged
    instead of raised.
    """
    try:
        payload = frame.get('body', {})
        event_info = payload.get('event', {})
        event_type = event_info.get('eventtype', '')
        sender = payload.get('from', {})

        message_data = {
            'msgtype': 'event',
            'type': payload.get('chattype', 'single'),
            'event': event_info,
            'eventtype': event_type,
            'msgid': payload.get('msgid', ''),
            'aibotid': payload.get('aibotid', ''),
            'req_id': frame.get('headers', {}).get('req_id', ''),
            'userid': sender.get('userid', ''),
            'username': sender.get('alias', '') or sender.get('userid', ''),
        }
        if payload.get('chatid'):
            message_data['chatid'] = payload.get('chatid', '')

        event = wecombotevent.WecomBotEvent(message_data)

        # Event-specific handlers first, then the generic 'event' handlers.
        for key in (event_type, 'event'):
            if key not in self._message_handlers:
                continue
            for handler in self._message_handlers[key]:
                await handler(event)
    except Exception:
        await self.logger.error(f'Error in event callback: {traceback.format_exc()}')
|
||||||
|
|
||||||
|
async def _dispatch_event(self, event: wecombotevent.WecomBotEvent):
|
||||||
|
"""Dispatch a message event to registered handlers with deduplication."""
|
||||||
|
try:
|
||||||
|
message_id = event.message_id
|
||||||
|
if message_id in self._msg_id_map:
|
||||||
|
self._msg_id_map[message_id] += 1
|
||||||
|
return
|
||||||
|
self._msg_id_map[message_id] = 1
|
||||||
|
|
||||||
|
msg_type = event.type
|
||||||
|
if msg_type in self._message_handlers:
|
||||||
|
for handler in self._message_handlers[msg_type]:
|
||||||
|
await handler(event)
|
||||||
|
except Exception:
|
||||||
|
await self.logger.error(f'Error dispatching event: {traceback.format_exc()}')
|
||||||
|
|
||||||
|
# ── Reply sending with serial queue ─────────────────────────────
|
||||||
|
|
||||||
|
async def _send_reply(
    self,
    req_id: str,
    body: dict,
    cmd: str = CMD_RESPOND_MSG,
) -> Optional[dict]:
    """Queue a reply frame for ``req_id`` and wait for its outcome.

    Each req_id has its own queue and worker task, so frames sharing a
    req_id are sent one after another in submission order.

    Returns:
        None immediately when there is no open socket; otherwise the value
        the worker stores in the frame's future (the ACK dict or None).
    """
    if not self._ws or self._ws.closed:
        return None

    # Ensure serial delivery per req_id: create the queue and its worker on first use.
    queue = self._reply_queues.get(req_id)
    if queue is None:
        queue = self._reply_queues[req_id] = asyncio.Queue()
        self._reply_workers[req_id] = asyncio.create_task(self._reply_queue_worker(req_id))

    outcome: asyncio.Future = asyncio.get_event_loop().create_future()
    frame = {'cmd': cmd, 'headers': {'req_id': req_id}, 'body': body}
    await queue.put((frame, outcome))
    return await outcome
|
||||||
|
|
||||||
|
async def _reply_queue_worker(self, req_id: str):
|
||||||
|
"""Process reply queue items serially for a given req_id."""
|
||||||
|
queue = self._reply_queues[req_id]
|
||||||
|
try:
|
||||||
|
while self._running:
|
||||||
|
try:
|
||||||
|
frame, future = await asyncio.wait_for(queue.get(), timeout=60.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# Queue idle, clean up worker
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
ack = await self._send_and_wait_ack(frame)
|
||||||
|
if not future.done():
|
||||||
|
future.set_result(ack)
|
||||||
|
except Exception as e:
|
||||||
|
if not future.done():
|
||||||
|
future.set_exception(e)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
self._reply_queues.pop(req_id, None)
|
||||||
|
self._reply_workers.pop(req_id, None)
|
||||||
|
|
||||||
|
async def _send_and_wait_ack(self, frame: dict) -> Optional[dict]:
|
||||||
|
"""Send a frame and wait for the corresponding ACK."""
|
||||||
|
req_id = frame['headers']['req_id']
|
||||||
|
ack_future: asyncio.Future = asyncio.get_event_loop().create_future()
|
||||||
|
self._pending_acks[req_id] = ack_future
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._send_frame(frame)
|
||||||
|
result = await asyncio.wait_for(ack_future, timeout=self._reply_ack_timeout)
|
||||||
|
if result.get('errcode', 0) != 0:
|
||||||
|
await self.logger.warning(
|
||||||
|
f'Reply ACK error: errcode={result.get("errcode")}, errmsg={result.get("errmsg")}'
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self._pending_acks.pop(req_id, None)
|
||||||
|
await self.logger.warning(f'Reply ACK timeout ({self._reply_ack_timeout}s) for req_id={req_id}')
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _send_frame(self, frame: dict):
|
||||||
|
"""Send a JSON frame over the WebSocket connection."""
|
||||||
|
if self._ws and not self._ws.closed:
|
||||||
|
await self._ws.send_str(json.dumps(frame, ensure_ascii=False))
|
||||||
|
|
||||||
|
def _clear_pending_acks(self, reason: str):
|
||||||
|
"""Reject all pending ACK futures on disconnection."""
|
||||||
|
for req_id, future in self._pending_acks.items():
|
||||||
|
if not future.done():
|
||||||
|
future.set_exception(ConnectionError(reason))
|
||||||
|
self._pending_acks.clear()
|
||||||
@@ -4,6 +4,7 @@ import base64
|
|||||||
import binascii
|
import binascii
|
||||||
import httpx
|
import httpx
|
||||||
import traceback
|
import traceback
|
||||||
|
from urllib.parse import quote
|
||||||
from quart import Quart
|
from quart import Quart
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
from typing import Callable, Dict, Any
|
from typing import Callable, Dict, Any
|
||||||
@@ -67,6 +68,31 @@ class WecomClient:
|
|||||||
await self.logger.error(f'获取accesstoken失败:{response.json()}')
|
await self.logger.error(f'获取accesstoken失败:{response.json()}')
|
||||||
raise Exception(f'未获取access token: {data}')
|
raise Exception(f'未获取access token: {data}')
|
||||||
|
|
||||||
|
async def get_user_info(self, userid: str) -> dict:
    """
    Get user information by user ID using the application secret.

    If the API reports an invalid or expired access token (errcode 40014
    or 42001) the token is refreshed and the request is retried once. A
    single retry, rather than open-ended recursion, keeps a persistently
    rejected token from recursing without bound.

    Args:
        userid: The user ID to look up.

    Returns:
        dict: The API response for the user, or an empty dict when the
        API reports an error.
    """
    if not await self.check_access_token():
        self.access_token = await self.get_access_token(self.secret)

    token_errors = (40014, 42001)
    data: dict = {}
    for attempt in range(2):
        url = self.base_url + '/user/get?access_token=' + self.access_token + '&userid=' + quote(userid)
        async with httpx.AsyncClient() as client:
            response = await client.get(url)
            data = response.json()

        if attempt == 0 and data.get('errcode') in token_errors:
            # Token rejected: fetch a fresh one and try exactly once more.
            self.access_token = await self.get_access_token(self.secret)
            continue
        break

    if data.get('errcode', 0) != 0:
        await self.logger.error(f'获取用户信息失败:{data}')
        return {}
    return data
|
||||||
|
|
||||||
async def get_users(self):
|
async def get_users(self):
|
||||||
if not self.check_access_token_for_contacts():
|
if not self.check_access_token_for_contacts():
|
||||||
self.access_token_for_contacts = await self.get_access_token(self.secret_for_contacts)
|
self.access_token_for_contacts = await self.get_access_token(self.secret_for_contacts)
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from typing import Callable
|
|||||||
from .wecomcsevent import WecomCSEvent
|
from .wecomcsevent import WecomCSEvent
|
||||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
import aiofiles
|
import aiofiles
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
class WecomCSClient:
|
class WecomCSClient:
|
||||||
@@ -34,6 +35,10 @@ class WecomCSClient:
|
|||||||
self.unified_mode = unified_mode
|
self.unified_mode = unified_mode
|
||||||
self.app = Quart(__name__)
|
self.app = Quart(__name__)
|
||||||
|
|
||||||
|
# Customer info cache: {external_userid: (info_dict, timestamp)}
|
||||||
|
self._customer_cache: dict[str, tuple[dict, float]] = {}
|
||||||
|
self._cache_ttl = 60 # Cache TTL in seconds (1 minute)
|
||||||
|
|
||||||
# 只有在非统一模式下才注册独立路由
|
# 只有在非统一模式下才注册独立路由
|
||||||
if not self.unified_mode:
|
if not self.unified_mode:
|
||||||
self.app.add_url_rule(
|
self.app.add_url_rule(
|
||||||
@@ -378,3 +383,53 @@ class WecomCSClient:
|
|||||||
async def get_media_id(self, image: platform_message.Image):
|
async def get_media_id(self, image: platform_message.Image):
|
||||||
media_id = await self.upload_to_work(image=image)
|
media_id = await self.upload_to_work(image=image)
|
||||||
return media_id
|
return media_id
|
||||||
|
|
||||||
|
async def get_customer_info(self, external_userid: str) -> dict | None:
    """
    Get customer information by external_userid with caching.

    Results are cached for ``_cache_ttl`` seconds to avoid repeated API
    calls for the same user; an expired entry is dropped on lookup. If
    the API reports an invalid or expired access token (errcode 40014 or
    42001) the token is refreshed and the request is retried once, which
    keeps a persistently rejected token from recursing without bound.

    Args:
        external_userid: The external user ID of the customer.

    Returns:
        The first entry of the API's ``customer_list``, or None when the
        API reports an error or returns no customers.
    """
    # Check cache first
    current_time = time.time()
    cached = self._customer_cache.get(external_userid)
    if cached is not None:
        cached_info, cached_time = cached
        if current_time - cached_time < self._cache_ttl:
            return cached_info
        # Expired: remove it so stale entries do not accumulate.
        del self._customer_cache[external_userid]

    # Cache miss or expired, fetch from API
    if not await self.check_access_token():
        self.access_token = await self.get_access_token(self.secret)

    payload = {
        'external_userid_list': [external_userid],
    }

    token_errors = (40014, 42001)
    data: dict = {}
    for attempt in range(2):
        url = f'{self.base_url}/kf/customer/batchget?access_token={self.access_token}'
        async with httpx.AsyncClient() as client:
            response = await client.post(url, json=payload)
            data = response.json()

        if attempt == 0 and data.get('errcode') in token_errors:
            # Token rejected: fetch a fresh one and try exactly once more.
            self.access_token = await self.get_access_token(self.secret)
            continue
        break

    if data.get('errcode', 0) != 0:
        if self.logger:
            await self.logger.warning(f'Failed to get customer info: {data}')
        return None

    customer_list = data.get('customer_list', [])
    if customer_list:
        customer_info = customer_list[0]
        # Store in cache
        self._customer_cache[external_userid] = (customer_info, current_time)
        return customer_info
    return None
|
||||||
|
|||||||
@@ -13,9 +13,9 @@ from .. import group
|
|||||||
@group.group_class('files', '/api/v1/files')
|
@group.group_class('files', '/api/v1/files')
|
||||||
class FilesRouterGroup(group.RouterGroup):
|
class FilesRouterGroup(group.RouterGroup):
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
@self.route('/image/<image_key>', methods=['GET'], auth_type=group.AuthType.NONE)
|
@self.route('/image/<path:image_key>', methods=['GET'], auth_type=group.AuthType.NONE)
|
||||||
async def _(image_key: str) -> quart.Response:
|
async def _(image_key: str) -> quart.Response:
|
||||||
if '/' in image_key or '\\' in image_key:
|
if '..' in image_key or '\\' in image_key:
|
||||||
return quart.Response(status=404)
|
return quart.Response(status=404)
|
||||||
|
|
||||||
if not await self.ap.storage_mgr.storage_provider.exists(image_key):
|
if not await self.ap.storage_mgr.storage_provider.exists(image_key):
|
||||||
|
|||||||
@@ -13,7 +13,10 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
|
|||||||
|
|
||||||
elif quart.request.method == 'POST':
|
elif quart.request.method == 'POST':
|
||||||
json_data = await quart.request.json
|
json_data = await quart.request.json
|
||||||
knowledge_base_uuid = await self.ap.knowledge_service.create_knowledge_base(json_data)
|
try:
|
||||||
|
knowledge_base_uuid = await self.ap.knowledge_service.create_knowledge_base(json_data)
|
||||||
|
except ValueError as e:
|
||||||
|
return self.http_status(400, -1, str(e))
|
||||||
return self.success(data={'uuid': knowledge_base_uuid})
|
return self.success(data={'uuid': knowledge_base_uuid})
|
||||||
|
|
||||||
return self.http_status(405, -1, 'Method not allowed')
|
return self.http_status(405, -1, 'Method not allowed')
|
||||||
@@ -39,7 +42,7 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
|
|||||||
elif quart.request.method == 'PUT':
|
elif quart.request.method == 'PUT':
|
||||||
json_data = await quart.request.json
|
json_data = await quart.request.json
|
||||||
await self.ap.knowledge_service.update_knowledge_base(knowledge_base_uuid, json_data)
|
await self.ap.knowledge_service.update_knowledge_base(knowledge_base_uuid, json_data)
|
||||||
return self.success({})
|
return self.success(data={'uuid': knowledge_base_uuid})
|
||||||
|
|
||||||
elif quart.request.method == 'DELETE':
|
elif quart.request.method == 'DELETE':
|
||||||
await self.ap.knowledge_service.delete_knowledge_base(knowledge_base_uuid)
|
await self.ap.knowledge_service.delete_knowledge_base(knowledge_base_uuid)
|
||||||
@@ -65,8 +68,12 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
|
|||||||
if not file_id:
|
if not file_id:
|
||||||
return self.http_status(400, -1, 'File ID is required')
|
return self.http_status(400, -1, 'File ID is required')
|
||||||
|
|
||||||
|
parser_plugin_id = json_data.get('parser_plugin_id')
|
||||||
|
|
||||||
# 调用服务层方法将文件与知识库关联
|
# 调用服务层方法将文件与知识库关联
|
||||||
task_id = await self.ap.knowledge_service.store_file(knowledge_base_uuid, file_id)
|
task_id = await self.ap.knowledge_service.store_file(
|
||||||
|
knowledge_base_uuid, file_id, parser_plugin_id=parser_plugin_id
|
||||||
|
)
|
||||||
return self.success(
|
return self.success(
|
||||||
{
|
{
|
||||||
'task_id': task_id,
|
'task_id': task_id,
|
||||||
@@ -90,5 +97,13 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
|
|||||||
async def retrieve_knowledge_base(knowledge_base_uuid: str) -> str:
|
async def retrieve_knowledge_base(knowledge_base_uuid: str) -> str:
|
||||||
json_data = await quart.request.json
|
json_data = await quart.request.json
|
||||||
query = json_data.get('query')
|
query = json_data.get('query')
|
||||||
results = await self.ap.knowledge_service.retrieve_knowledge_base(knowledge_base_uuid, query)
|
|
||||||
|
if not query or not query.strip():
|
||||||
|
return self.http_status(400, -1, 'Query is required and cannot be empty')
|
||||||
|
|
||||||
|
# Extract retrieval_settings to allow dynamic control over Knowledge Engine behavior (e.g. top_k, filters)
|
||||||
|
retrieval_settings = json_data.get('retrieval_settings', {})
|
||||||
|
results = await self.ap.knowledge_service.retrieve_knowledge_base(
|
||||||
|
knowledge_base_uuid, query, retrieval_settings
|
||||||
|
)
|
||||||
return self.success(data={'results': results})
|
return self.success(data={'results': results})
|
||||||
|
|||||||
@@ -0,0 +1,45 @@
|
|||||||
|
import quart
|
||||||
|
from urllib.parse import unquote
|
||||||
|
from ... import group
|
||||||
|
|
||||||
|
|
||||||
|
@group.group_class('knowledge_engines', '/api/v1/knowledge/engines')
class KnowledgeEnginesRouterGroup(group.RouterGroup):
    """HTTP routes that expose plugin-provided Knowledge Engines and their settings schemas."""

    @staticmethod
    def _parse_plugin_id(raw_plugin_id: str) -> 'str | None':
        """Decode a path-captured plugin ID and check its ``author/name`` shape.

        Args:
            raw_plugin_id: Value captured by the ``<path:plugin_id>`` converter.

        Returns:
            The percent-decoded plugin ID, or None when the author or the name
            part is missing (for example ``name``, ``/name`` or ``author/``).
        """
        plugin_id = unquote(raw_plugin_id)
        author, separator, name = plugin_id.partition('/')
        # A bare "'/' in plugin_id" test would let IDs with an empty half through.
        if not (separator and author and name):
            return None
        return plugin_id

    async def initialize(self) -> None:
        """Register the engine listing and per-engine schema routes."""

        @self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
        async def list_knowledge_engines() -> quart.Response:
            """List all available Knowledge Engines from plugins.

            Returns a list of Knowledge Engines with their capabilities and configuration schemas.
            This is used by the frontend to render the knowledge base creation wizard.
            """
            engines = await self.ap.knowledge_service.list_knowledge_engines()
            return self.success(data={'engines': engines})

        @self.route(
            '/<path:plugin_id>/creation-schema', methods=['GET'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY
        )
        async def get_engine_creation_schema(plugin_id: str) -> quart.Response:
            """Get creation settings schema for a specific Knowledge Engine.

            plugin_id is in 'author/name' format, captured via <path:> converter.
            """
            parsed_plugin_id = self._parse_plugin_id(plugin_id)
            if parsed_plugin_id is None:
                return self.http_status(400, -1, 'Invalid plugin_id format. Expected author/name.')
            schema = await self.ap.knowledge_service.get_engine_creation_schema(parsed_plugin_id)
            return self.success(data={'schema': schema})

        @self.route(
            '/<path:plugin_id>/retrieval-schema', methods=['GET'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY
        )
        async def get_engine_retrieval_schema(plugin_id: str) -> quart.Response:
            """Get retrieval settings schema for a specific Knowledge Engine.

            plugin_id is in 'author/name' format, captured via <path:> converter.
            """
            parsed_plugin_id = self._parse_plugin_id(plugin_id)
            if parsed_plugin_id is None:
                return self.http_status(400, -1, 'Invalid plugin_id format. Expected author/name.')
            schema = await self.ap.knowledge_service.get_engine_retrieval_schema(parsed_plugin_id)
            return self.success(data={'schema': schema})
|
||||||
@@ -1,61 +0,0 @@
|
|||||||
import quart
|
|
||||||
from ... import group
|
|
||||||
|
|
||||||
|
|
||||||
@group.group_class('external_knowledge_base', '/api/v1/knowledge/external-bases')
class ExternalKnowledgeBaseRouterGroup(group.RouterGroup):
    """HTTP routes for managing and querying external knowledge bases."""

    async def initialize(self) -> None:
        """Register the retriever listing, collection, item and retrieve routes."""

        @self.route('/retrievers', methods=['GET'])
        async def list_knowledge_retrievers() -> quart.Response:
            """List all available knowledge retrievers from plugins."""
            available = await self.ap.plugin_connector.list_knowledge_retrievers()
            return self.success(data={'retrievers': available})

        @self.route('', methods=['POST', 'GET'])
        async def handle_external_knowledge_bases() -> quart.Response:
            """List external knowledge bases (GET) or create one (POST)."""
            if quart.request.method == 'GET':
                bases = await self.ap.external_kb_service.get_external_knowledge_bases()
                return self.success(data={'bases': bases})

            if quart.request.method == 'POST':
                payload = await quart.request.json
                new_uuid = await self.ap.external_kb_service.create_external_knowledge_base(payload)
                return self.success(data={'uuid': new_uuid})

            return self.http_status(405, -1, 'Method not allowed')

        @self.route(
            '/<kb_uuid>',
            methods=['GET', 'DELETE', 'PUT'],
        )
        async def handle_specific_external_knowledge_base(kb_uuid: str) -> quart.Response:
            """Fetch (GET), update (PUT) or delete (DELETE) one external knowledge base."""
            service = self.ap.external_kb_service

            if quart.request.method == 'GET':
                base = await service.get_external_knowledge_base(kb_uuid)
                if base is None:
                    return self.http_status(404, -1, 'external knowledge base not found')
                return self.success(
                    data={
                        'base': base,
                    }
                )

            if quart.request.method == 'PUT':
                payload = await quart.request.json
                await service.update_external_knowledge_base(kb_uuid, payload)
                return self.success({})

            if quart.request.method == 'DELETE':
                await service.delete_external_knowledge_base(kb_uuid)
                return self.success({})

        @self.route(
            '/<kb_uuid>/retrieve',
            methods=['POST'],
        )
        async def retrieve_external_knowledge_base(kb_uuid: str) -> str:
            """Run the request's 'query' against the given external knowledge base."""
            payload = await quart.request.json
            hits = await self.ap.external_kb_service.retrieve_external_knowledge_base(
                kb_uuid, payload.get('query')
            )
            return self.success(data={'results': hits})
|
|
||||||
@@ -0,0 +1,372 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import quart
|
||||||
|
import sqlalchemy
|
||||||
|
|
||||||
|
from ... import group
|
||||||
|
from ......core import taskmgr
|
||||||
|
from ......entity.persistence import metadata as persistence_metadata
|
||||||
|
from langbot_plugin.runtime.plugin.mgr import PluginInstallSource
|
||||||
|
|
||||||
|
LANGRAG_PLUGIN_AUTHOR = 'langbot-team'
|
||||||
|
LANGRAG_PLUGIN_NAME = 'LangRAG'
|
||||||
|
LANGRAG_PLUGIN_ID = f'{LANGRAG_PLUGIN_AUTHOR}/{LANGRAG_PLUGIN_NAME}'
|
||||||
|
DEFAULT_SPACE_URL = 'https://space.langbot.app'
|
||||||
|
|
||||||
|
# Old Retriever plugin_name -> New Connector plugin_name
|
||||||
|
EXTERNAL_PLUGIN_NAME_MAPPING = {
|
||||||
|
'DifyDatasetsRetriever': 'DifyDatasetsConnector',
|
||||||
|
'RAGFlowRetriever': 'RAGFlowConnector',
|
||||||
|
'FastGPTRetriever': 'FastGPTConnector',
|
||||||
|
}
|
||||||
|
|
||||||
|
# Per-plugin: which old retriever_config fields belong to creation_settings.
|
||||||
|
# Remaining fields go to retrieval_settings.
|
||||||
|
# None means ALL fields go to creation_settings (no retrieval_schema).
|
||||||
|
EXTERNAL_PLUGIN_CREATION_FIELDS: dict[str, set[str] | None] = {
|
||||||
|
'langbot-team/DifyDatasetsConnector': {'api_base_url', 'dify_apikey', 'dataset_id'},
|
||||||
|
'langbot-team/RAGFlowConnector': {'api_base_url', 'api_key', 'dataset_ids'},
|
||||||
|
'langbot-team/FastGPTConnector': None, # all fields -> creation_settings
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@group.group_class('knowledge/migration', '/api/v1/knowledge/migration')
|
||||||
|
class KnowledgeMigrationRouterGroup(group.RouterGroup):
|
||||||
|
async def _get_migration_flag(self) -> bool:
    """Report whether the ``rag_plugin_migration_needed`` metadata row holds ``'true'``."""
    flag_query = sqlalchemy.select(persistence_metadata.Metadata).where(
        persistence_metadata.Metadata.key == 'rag_plugin_migration_needed'
    )
    flag_row = (await self.ap.persistence_mgr.execute_async(flag_query)).first()
    # A missing row means the flag was never set.
    if flag_row is None:
        return False
    return flag_row.value == 'true'
|
||||||
|
|
||||||
|
async def _set_migration_flag(self, value: str):
    """Write ``value`` into the existing ``rag_plugin_migration_needed`` metadata row."""
    metadata_table = persistence_metadata.Metadata
    statement = sqlalchemy.update(metadata_table)
    statement = statement.where(metadata_table.key == 'rag_plugin_migration_needed')
    statement = statement.values(value=value)
    await self.ap.persistence_mgr.execute_async(statement)
|
||||||
|
|
||||||
|
async def _table_exists(self, table_name: str) -> bool:
    """Tell whether ``table_name`` exists, using the catalog of the active database."""
    if self.ap.persistence_mgr.db.name != 'postgresql':
        # Every non-PostgreSQL backend is queried through sqlite_master.
        sqlite_lookup = sqlalchemy.text(
            "SELECT name FROM sqlite_master WHERE type='table' AND name=:table_name;"
        ).bindparams(table_name=table_name)
        found = await self.ap.persistence_mgr.execute_async(sqlite_lookup)
        return found.first() is not None

    pg_lookup = sqlalchemy.text(
        'SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = :table_name);'
    ).bindparams(table_name=table_name)
    found = await self.ap.persistence_mgr.execute_async(pg_lookup)
    return found.scalar()
|
||||||
|
|
||||||
|
async def _install_plugin_from_marketplace(
    self, plugin_id: str, task_context: taskmgr.TaskContext, space_url: str
) -> None:
    """Look up the latest marketplace version of ``plugin_id`` and request its installation.

    Args:
        plugin_id: Plugin identifier in ``author/name`` form.
        task_context: Task context that receives progress traces.
        space_url: Base URL of the marketplace service.

    Raises:
        Exception: When the marketplace response carries no ``latest_version``.
    """
    author, name = plugin_id.split('/', 1)
    self.ap.logger.info(f'RAG migration: installing plugin {plugin_id} from marketplace...')
    task_context.trace(f'Installing plugin {plugin_id} from marketplace...')

    detail_url = f'{space_url}/api/v1/marketplace/plugins/{author}/{name}'
    async with httpx.AsyncClient(trust_env=True, timeout=15) as client:
        resp = await client.get(detail_url)
        resp.raise_for_status()
        plugin_info = resp.json().get('data', {}).get('plugin', {})
        latest_version = plugin_info.get('latest_version')
        if not latest_version:
            raise Exception(f'Could not determine latest version for {plugin_id}')

    install_request = {
        'plugin_author': author,
        'plugin_name': name,
        'plugin_version': latest_version,
    }
    await self.ap.plugin_connector.install_plugin(
        PluginInstallSource.MARKETPLACE,
        install_request,
        task_context=task_context,
    )
    self.ap.logger.info(f'RAG migration: plugin {plugin_id} install request sent.')
|
||||||
|
|
||||||
|
async def _execute_rag_migration(self, task_context: taskmgr.TaskContext, install_plugin: bool = True):
    """Execute RAG migration: install required plugins and restore backup data.

    Steps:
        1. Optionally install LangRAG plus every connector plugin referenced
           by ``external_knowledge_bases`` from the marketplace.
        2. Wait for those plugins to show up as knowledge engines.
        3. Restore internal knowledge bases from ``knowledge_bases_backup``.
        4. Restore external knowledge bases from ``external_knowledge_bases``.
        5. Clear the migration flag.

    Plugin-side problems are collected as warnings and traced; database
    errors propagate to the caller.

    Args:
        task_context: Task context that receives progress traces.
        install_plugin: When False, steps 1 and 2 are skipped and the engines
            that are currently available are used as-is.
    """
    warnings: list[str] = []

    # Collect all plugins we need: LangRAG (always) + connector plugins (from external KBs)
    has_external = await self._table_exists('external_knowledge_bases')
    needed_plugins = await self._collect_needed_plugins(has_external)
    self.ap.logger.info(f'RAG migration: plugins needed: {list(needed_plugins.keys())}')

    if install_plugin:
        await self._install_needed_plugins(needed_plugins, task_context)
        engine_id_set = await self._wait_for_engines(needed_plugins, task_context, warnings)
    else:
        try:
            engine_id_set = await self._list_engine_ids()
        except Exception:
            # Best effort: with no engine list every external KB is reported as "plugin not installed".
            engine_id_set = set()

    await self._restore_internal_kbs(task_context, warnings)
    await self._restore_external_kbs(task_context, warnings, has_external, engine_id_set)

    # Step 5: Clear migration flag
    await self._set_migration_flag('false')
    task_context.trace('RAG migration completed.', action='done')

    if warnings:
        task_context.trace(f'Completed with {len(warnings)} warning(s).')

async def _collect_needed_plugins(self, has_external: bool) -> dict[str, str]:
    """Map every plugin ID the migration depends on to its plugin name."""
    needed_plugins: dict[str, str] = {
        LANGRAG_PLUGIN_ID: LANGRAG_PLUGIN_NAME,
    }
    if not has_external:
        return needed_plugins

    result = await self.ap.persistence_mgr.execute_async(
        sqlalchemy.text('SELECT DISTINCT plugin_author, plugin_name FROM external_knowledge_bases;')
    )
    for row in result.fetchall():
        plugin_author = row[0] or ''
        plugin_name = row[1] or ''
        mapped_name = EXTERNAL_PLUGIN_NAME_MAPPING.get(plugin_name, plugin_name)
        needed_plugins.setdefault(f'{plugin_author}/{mapped_name}', mapped_name)
    return needed_plugins

async def _list_engine_ids(self) -> set[str]:
    """Return the plugin IDs of all knowledge engines reported by the plugin connector."""
    engines = await self.ap.plugin_connector.list_knowledge_engines()
    return {e.get('plugin_id') for e in engines}

async def _install_needed_plugins(self, needed_plugins: dict[str, str], task_context: taskmgr.TaskContext) -> None:
    """Step 1: request installation of every needed plugin; failures are logged, not raised."""
    task_context.trace('Installing required plugins...', action='install-plugin')
    space_url = self.ap.instance_config.data.get('space', {}).get('url', DEFAULT_SPACE_URL).rstrip('/')

    for plugin_id in needed_plugins:
        try:
            await self._install_plugin_from_marketplace(plugin_id, task_context, space_url)
        except Exception as e:
            self.ap.logger.warning(f'RAG migration: plugin {plugin_id} install returned: {e}')
            task_context.trace(f'Plugin install note ({plugin_id}): {e}')

async def _wait_for_engines(
    self,
    needed_plugins: dict[str, str],
    task_context: taskmgr.TaskContext,
    warnings: list[str],
    max_retries: int = 30,
    retry_interval: float = 2,
) -> set[str]:
    """Step 2: poll until every needed plugin is listed as a knowledge engine.

    Returns the last engine ID set seen. On timeout a warning is appended to
    ``warnings`` instead of raising.
    """
    task_context.trace(
        f'Waiting for plugins to become available: {list(needed_plugins.keys())}...',
        action='wait-plugin',
    )
    engine_id_set: set[str] = set()
    for attempt in range(max_retries):
        try:
            engine_id_set = await self._list_engine_ids()
        except Exception as e:
            # Keep polling with the previous set, but leave a trace of the failure.
            self.ap.logger.debug(f'RAG migration: listing knowledge engines failed: {e}')

        if all(pid in engine_id_set for pid in needed_plugins):
            self.ap.logger.info(f'RAG migration: all plugins ready: {engine_id_set}')
            task_context.trace('All required plugins are ready.')
            return engine_id_set

        # Sleep only between attempts; waiting after the last one gains nothing.
        if attempt < max_retries - 1:
            await asyncio.sleep(retry_interval)

    still_missing = [pid for pid in needed_plugins if pid not in engine_id_set]
    warning = f'Plugin(s) {still_missing} did not become available after {max_retries} retries'
    self.ap.logger.warning(f'RAG migration: {warning}')
    warnings.append(warning)
    task_context.trace(warning)
    return engine_id_set

async def _insert_knowledge_base_row(
    self,
    *,
    kb_uuid,
    name,
    description,
    emoji,
    created_at,
    updated_at,
    plugin_id: str,
    creation_settings: dict,
    retrieval_settings: dict,
) -> None:
    """Insert one migrated row into ``knowledge_bases``; the KB uuid doubles as collection_id."""
    await self.ap.persistence_mgr.execute_async(
        sqlalchemy.text(
            'INSERT INTO knowledge_bases '
            '(uuid, name, description, emoji, created_at, updated_at, '
            'knowledge_engine_plugin_id, collection_id, creation_settings, retrieval_settings) '
            'VALUES (:uuid, :name, :description, :emoji, :created_at, :updated_at, '
            ':plugin_id, :collection_id, :creation_settings, :retrieval_settings);'
        ).bindparams(
            uuid=kb_uuid,
            name=name,
            description=description,
            emoji=emoji,
            created_at=created_at,
            updated_at=updated_at,
            plugin_id=plugin_id,
            collection_id=kb_uuid,
            creation_settings=json.dumps(creation_settings),
            retrieval_settings=json.dumps(retrieval_settings),
        )
    )

async def _restore_internal_kbs(self, task_context: taskmgr.TaskContext, warnings: list[str]) -> None:
    """Step 3: recreate internal knowledge bases from ``knowledge_bases_backup`` under LangRAG."""
    task_context.trace('Restoring internal knowledge bases...', action='restore-internal')
    if not await self._table_exists('knowledge_bases_backup'):
        return

    result = await self.ap.persistence_mgr.execute_async(sqlalchemy.text('SELECT * FROM knowledge_bases_backup;'))
    rows = result.fetchall()
    columns = result.keys()

    for row in rows:
        row_dict = dict(zip(columns, row))
        kb_uuid = row_dict.get('uuid')
        name = row_dict.get('name', 'Untitled')
        embedding_model_uuid = row_dict.get('embedding_model_uuid', '')
        config = {'embedding_model_uuid': embedding_model_uuid}

        await self._insert_knowledge_base_row(
            kb_uuid=kb_uuid,
            name=name,
            description=row_dict.get('description', ''),
            emoji=row_dict.get('emoji', '\U0001f4da'),
            created_at=row_dict.get('created_at'),
            updated_at=row_dict.get('updated_at'),
            plugin_id=LANGRAG_PLUGIN_ID,
            creation_settings=config,
            retrieval_settings={'top_k': row_dict.get('top_k', 5)},
        )

        try:
            await self.ap.plugin_connector.rag_on_kb_create(LANGRAG_PLUGIN_ID, kb_uuid, config)
            task_context.trace(f'Restored internal KB: {name} ({kb_uuid})')
        except Exception as e:
            # The row is already stored; only the plugin notification failed.
            warning = f'Failed to notify plugin for KB {name} ({kb_uuid}): {e}'
            warnings.append(warning)
            task_context.trace(warning)

    await self.ap.rag_mgr.load_knowledge_bases_from_db()

async def _restore_external_kbs(
    self,
    task_context: taskmgr.TaskContext,
    warnings: list[str],
    has_external: bool,
    engine_id_set: set[str],
) -> None:
    """Step 4: recreate external knowledge bases under their mapped connector plugins."""
    task_context.trace('Restoring external knowledge bases...', action='restore-external')
    if not has_external:
        return

    result = await self.ap.persistence_mgr.execute_async(sqlalchemy.text('SELECT * FROM external_knowledge_bases;'))
    rows = result.fetchall()
    columns = result.keys()

    self.ap.logger.info(
        f'RAG migration: {len(rows)} external KB(s) to restore. Available engines: {engine_id_set}'
    )
    task_context.trace(f'Found {len(rows)} external KB(s). Available engines: {engine_id_set}')

    for row in rows:
        row_dict = dict(zip(columns, row))
        kb_uuid = row_dict.get('uuid')
        name = row_dict.get('name', 'Untitled')
        # Normalise NULLs to '' so the ID matches the one built when collecting needed plugins.
        plugin_author = row_dict.get('plugin_author') or ''
        plugin_name = row_dict.get('plugin_name') or ''
        retriever_config = row_dict.get('retriever_config', {})
        created_at = row_dict.get('created_at')

        mapped_plugin_name = EXTERNAL_PLUGIN_NAME_MAPPING.get(plugin_name, plugin_name)
        external_plugin_id = f'{plugin_author}/{mapped_plugin_name}'

        self.ap.logger.info(
            f'RAG migration: processing external KB "{name}" ({kb_uuid}), '
            f'plugin: {plugin_author}/{plugin_name} -> {external_plugin_id}'
        )

        if isinstance(retriever_config, str):
            try:
                retriever_config = json.loads(retriever_config)
            except (json.JSONDecodeError, TypeError):
                retriever_config = {}
        # A NULL column or non-object JSON cannot be split into settings; treat it as empty.
        if not isinstance(retriever_config, dict):
            retriever_config = {}

        creation_fields = EXTERNAL_PLUGIN_CREATION_FIELDS.get(external_plugin_id)
        if creation_fields is None:
            creation_settings_dict = retriever_config
            retrieval_settings_dict = {}
        else:
            creation_settings_dict = {k: v for k, v in retriever_config.items() if k in creation_fields}
            retrieval_settings_dict = {k: v for k, v in retriever_config.items() if k not in creation_fields}

        await self._insert_knowledge_base_row(
            kb_uuid=kb_uuid,
            name=name,
            description=row_dict.get('description', ''),
            emoji=row_dict.get('emoji', '\U0001f517'),
            created_at=created_at,
            # The old table's creation time is reused as the update time.
            updated_at=created_at,
            plugin_id=external_plugin_id,
            creation_settings=creation_settings_dict,
            retrieval_settings=retrieval_settings_dict,
        )

        if external_plugin_id not in engine_id_set:
            warning = (
                f'External KB "{name}" ({kb_uuid}) record saved, but plugin {external_plugin_id} '
                f'is not installed yet. Install the connector plugin to use it.'
            )
            warnings.append(warning)
            task_context.trace(warning)
            continue

        try:
            await self.ap.plugin_connector.rag_on_kb_create(external_plugin_id, kb_uuid, creation_settings_dict)
            task_context.trace(f'Restored external KB: {name} ({kb_uuid})')
        except Exception as e:
            warning = f'Failed to notify plugin for external KB {name} ({kb_uuid}): {e}'
            warnings.append(warning)
            task_context.trace(warning)

    await self.ap.rag_mgr.load_knowledge_bases_from_db()
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
    """Register the migration status, execute and dismiss routes."""

    async def count_rows_if_present(table_name: str, count_statement: str) -> int:
        # Return the row count of a table, or 0 when the table does not exist.
        if not await self._table_exists(table_name):
            return 0
        counted = await self.ap.persistence_mgr.execute_async(sqlalchemy.text(count_statement))
        return counted.scalar() or 0

    @self.route('/status', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
    async def _() -> str:
        # Report whether migration is pending and how many KBs it would restore.
        needed = await self._get_migration_flag()
        internal_kb_count = 0
        external_kb_count = 0

        if needed:
            internal_kb_count = await count_rows_if_present(
                'knowledge_bases_backup', 'SELECT COUNT(*) FROM knowledge_bases_backup;'
            )
            external_kb_count = await count_rows_if_present(
                'external_knowledge_bases', 'SELECT COUNT(*) FROM external_knowledge_bases;'
            )

        status = {
            'needed': needed,
            'internal_kb_count': internal_kb_count,
            'external_kb_count': external_kb_count,
        }
        return self.success(data=status)

    @self.route('/execute', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
    async def _() -> str:
        # Start the migration as a user task and hand back its task id.
        if not await self._get_migration_flag():
            return self.http_status(400, -1, 'RAG migration is not needed')

        body = await quart.request.get_json(silent=True) or {}
        should_install = body.get('install_plugin', True)

        ctx = taskmgr.TaskContext.new()
        migration = self._execute_rag_migration(task_context=ctx, install_plugin=should_install)
        wrapper = self.ap.task_mgr.create_user_task(
            migration,
            kind='rag-migration',
            name='rag-migration-execute',
            label='Migrating knowledge bases to plugin architecture',
            context=ctx,
        )
        return self.success(data={'task_id': wrapper.id})

    @self.route('/dismiss', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
    async def _() -> str:
        # Clear the migration flag without restoring anything.
        if not await self._get_migration_flag():
            return self.http_status(400, -1, 'RAG migration is not needed')

        await self._set_migration_flag('false')
        return self.success()
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
import quart
|
||||||
|
from ... import group
|
||||||
|
|
||||||
|
|
||||||
|
@group.group_class('parsers', '/api/v1/knowledge/parsers')
class ParsersRouterGroup(group.RouterGroup):
    """HTTP route exposing the document parsers provided by plugins."""

    async def initialize(self) -> None:
        """Register the parser listing route."""

        @self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
        async def list_parsers() -> quart.Response:
            """List all available parsers from plugins.

            Optional query parameter `mime_type` to filter parsers by supported MIME type.
            """
            requested_mime = quart.request.args.get('mime_type')
            available = await self.ap.knowledge_service.list_parsers(requested_mime)
            return self.success(data={'parsers': available})
|
||||||
@@ -68,7 +68,7 @@ class PipelinesRouterGroup(group.RouterGroup):
|
|||||||
return self.http_status(404, -1, 'pipeline not found')
|
return self.http_status(404, -1, 'pipeline not found')
|
||||||
|
|
||||||
# Only include plugins with pipeline-related components (Command, EventListener, Tool)
|
# Only include plugins with pipeline-related components (Command, EventListener, Tool)
|
||||||
# Plugins that only have KnowledgeRetriever components are not suitable for pipeline extensions
|
# Plugins that only have KnowledgeEngine components are not suitable for pipeline extensions
|
||||||
pipeline_component_kinds = ['Command', 'EventListener', 'Tool']
|
pipeline_component_kinds = ['Command', 'EventListener', 'Tool']
|
||||||
plugins = await self.ap.plugin_connector.list_plugins(component_kinds=pipeline_component_kinds)
|
plugins = await self.ap.plugin_connector.list_plugins(component_kinds=pipeline_component_kinds)
|
||||||
mcp_servers = await self.ap.mcp_service.get_mcp_servers(contain_runtime_info=True)
|
mcp_servers = await self.ap.mcp_service.get_mcp_servers(contain_runtime_info=True)
|
||||||
|
|||||||
@@ -70,12 +70,17 @@ class BotService:
|
|||||||
'lark',
|
'lark',
|
||||||
]:
|
]:
|
||||||
webhook_prefix = self.ap.instance_config.data['api'].get('webhook_prefix', 'http://127.0.0.1:5300')
|
webhook_prefix = self.ap.instance_config.data['api'].get('webhook_prefix', 'http://127.0.0.1:5300')
|
||||||
|
extra_webhook_prefix = self.ap.instance_config.data['api'].get('extra_webhook_prefix', '')
|
||||||
webhook_url = f'/bots/{bot_uuid}'
|
webhook_url = f'/bots/{bot_uuid}'
|
||||||
adapter_runtime_values['webhook_url'] = webhook_url
|
adapter_runtime_values['webhook_url'] = webhook_url
|
||||||
adapter_runtime_values['webhook_full_url'] = f'{webhook_prefix}{webhook_url}'
|
adapter_runtime_values['webhook_full_url'] = f'{webhook_prefix}{webhook_url}'
|
||||||
|
adapter_runtime_values['extra_webhook_full_url'] = (
|
||||||
|
f'{extra_webhook_prefix}{webhook_url}' if extra_webhook_prefix else ''
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
adapter_runtime_values['webhook_url'] = None
|
adapter_runtime_values['webhook_url'] = None
|
||||||
adapter_runtime_values['webhook_full_url'] = None
|
adapter_runtime_values['webhook_full_url'] = None
|
||||||
|
adapter_runtime_values['extra_webhook_full_url'] = None
|
||||||
|
|
||||||
persistence_bot['adapter_runtime_values'] = adapter_runtime_values
|
persistence_bot['adapter_runtime_values'] = adapter_runtime_values
|
||||||
|
|
||||||
|
|||||||
@@ -1,80 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from ....core import app
|
|
||||||
import sqlalchemy
|
|
||||||
from langbot.pkg.entity.persistence import rag as persistence_rag
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
|
|
||||||
class ExternalKBService:
    """Persistence and runtime operations for external knowledge bases."""

    ap: app.Application

    def __init__(self, ap: app.Application) -> None:
        self.ap = ap

    # External Knowledge Base methods
    async def get_external_knowledge_bases(self) -> list[dict]:
        """Return every stored external knowledge base as a serialized dict."""
        query = sqlalchemy.select(persistence_rag.ExternalKnowledgeBase)
        result = await self.ap.persistence_mgr.execute_async(query)
        records = result.all()
        serialized = []
        for record in records:
            serialized.append(
                self.ap.persistence_mgr.serialize_model(persistence_rag.ExternalKnowledgeBase, record)
            )
        return serialized

    async def get_external_knowledge_base(self, kb_uuid: str) -> dict | None:
        """Return the serialized external knowledge base with ``kb_uuid``, or None if absent."""
        query = sqlalchemy.select(persistence_rag.ExternalKnowledgeBase).where(
            persistence_rag.ExternalKnowledgeBase.uuid == kb_uuid
        )
        record = (await self.ap.persistence_mgr.execute_async(query)).first()
        if record is not None:
            return self.ap.persistence_mgr.serialize_model(persistence_rag.ExternalKnowledgeBase, record)
        return None

    async def create_external_knowledge_base(self, kb_data: dict) -> str:
        """Store ``kb_data`` under a fresh uuid, load it into the runtime and return the uuid.

        The generated uuid is written into ``kb_data`` itself.
        """
        new_uuid = str(uuid.uuid4())
        kb_data['uuid'] = new_uuid
        insert_statement = sqlalchemy.insert(persistence_rag.ExternalKnowledgeBase).values(kb_data)
        await self.ap.persistence_mgr.execute_async(insert_statement)

        stored = await self.get_external_knowledge_base(kb_data['uuid'])
        await self.ap.rag_mgr.load_external_knowledge_base(stored)

        return kb_data['uuid']

    async def retrieve_external_knowledge_base(self, kb_uuid: str, query: str) -> list[dict]:
        """Retrieve external knowledge base"""
        runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
        if runtime_kb is None:
            raise Exception('Knowledge base not found')
        # top_k is just a placeholder for external knowledge base
        hits = await runtime_kb.retrieve(query, 5)
        return [hit.model_dump() for hit in hits]

    async def update_external_knowledge_base(self, kb_uuid: str, kb_data: dict) -> None:
        """Apply ``kb_data`` to the stored record and reload it into the runtime.

        Any 'uuid' key is removed from ``kb_data`` before the update.
        """
        kb_data.pop('uuid', None)

        update_statement = (
            sqlalchemy.update(persistence_rag.ExternalKnowledgeBase)
            .values(kb_data)
            .where(persistence_rag.ExternalKnowledgeBase.uuid == kb_uuid)
        )
        await self.ap.persistence_mgr.execute_async(update_statement)
        await self.ap.rag_mgr.remove_knowledge_base_from_runtime(kb_uuid)

        refreshed = await self.get_external_knowledge_base(kb_uuid)
        await self.ap.rag_mgr.load_external_knowledge_base(refreshed)

    async def delete_external_knowledge_base(self, kb_uuid: str) -> None:
        """Delete the stored record and drop the knowledge base from the runtime."""
        delete_statement = sqlalchemy.delete(persistence_rag.ExternalKnowledgeBase).where(
            persistence_rag.ExternalKnowledgeBase.uuid == kb_uuid
        )
        await self.ap.persistence_mgr.execute_async(delete_statement)

        await self.ap.rag_mgr.delete_knowledge_base(kb_uuid)
|
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import uuid
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
|
||||||
from ....core import app
|
from ....core import app
|
||||||
@@ -17,64 +16,77 @@ class KnowledgeService:
|
|||||||
|
|
||||||
async def get_knowledge_bases(self) -> list[dict]:
|
async def get_knowledge_bases(self) -> list[dict]:
|
||||||
"""获取所有知识库"""
|
"""获取所有知识库"""
|
||||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_rag.KnowledgeBase))
|
return await self.ap.rag_mgr.get_all_knowledge_base_details()
|
||||||
knowledge_bases = result.all()
|
|
||||||
return [
|
|
||||||
self.ap.persistence_mgr.serialize_model(persistence_rag.KnowledgeBase, knowledge_base)
|
|
||||||
for knowledge_base in knowledge_bases
|
|
||||||
]
|
|
||||||
|
|
||||||
async def get_knowledge_base(self, kb_uuid: str) -> dict | None:
|
async def get_knowledge_base(self, kb_uuid: str) -> dict | None:
|
||||||
"""获取知识库"""
|
"""获取知识库"""
|
||||||
result = await self.ap.persistence_mgr.execute_async(
|
return await self.ap.rag_mgr.get_knowledge_base_details(kb_uuid)
|
||||||
sqlalchemy.select(persistence_rag.KnowledgeBase).where(persistence_rag.KnowledgeBase.uuid == kb_uuid)
|
|
||||||
)
|
|
||||||
knowledge_base = result.first()
|
|
||||||
if knowledge_base is None:
|
|
||||||
return None
|
|
||||||
return self.ap.persistence_mgr.serialize_model(persistence_rag.KnowledgeBase, knowledge_base)
|
|
||||||
|
|
||||||
async def create_knowledge_base(self, kb_data: dict) -> str:
|
async def create_knowledge_base(self, kb_data: dict) -> str:
|
||||||
"""创建知识库"""
|
"""创建知识库"""
|
||||||
kb_data['uuid'] = str(uuid.uuid4())
|
# In new architecture, we delegate entirely to RAGManager which uses plugins.
|
||||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.KnowledgeBase).values(kb_data))
|
# Legacy internal KB creation is removed.
|
||||||
|
|
||||||
kb = await self.get_knowledge_base(kb_data['uuid'])
|
knowledge_engine_plugin_id = kb_data.get('knowledge_engine_plugin_id')
|
||||||
|
if not knowledge_engine_plugin_id:
|
||||||
|
raise ValueError('knowledge_engine_plugin_id is required')
|
||||||
|
|
||||||
await self.ap.rag_mgr.load_knowledge_base(kb)
|
kb = await self.ap.rag_mgr.create_knowledge_base(
|
||||||
|
name=kb_data.get('name', 'Untitled'),
|
||||||
return kb_data['uuid']
|
knowledge_engine_plugin_id=knowledge_engine_plugin_id,
|
||||||
|
creation_settings=kb_data.get('creation_settings', {}),
|
||||||
|
retrieval_settings=kb_data.get('retrieval_settings', {}),
|
||||||
|
description=kb_data.get('description', ''),
|
||||||
|
)
|
||||||
|
return kb.uuid
|
||||||
|
|
||||||
async def update_knowledge_base(self, kb_uuid: str, kb_data: dict) -> None:
|
async def update_knowledge_base(self, kb_uuid: str, kb_data: dict) -> None:
|
||||||
"""更新知识库"""
|
"""更新知识库"""
|
||||||
if 'uuid' in kb_data:
|
# Filter to only mutable fields
|
||||||
del kb_data['uuid']
|
filtered_data = {k: v for k, v in kb_data.items() if k in persistence_rag.KnowledgeBase.MUTABLE_FIELDS}
|
||||||
|
|
||||||
if 'embedding_model_uuid' in kb_data:
|
if not filtered_data:
|
||||||
del kb_data['embedding_model_uuid']
|
return
|
||||||
|
|
||||||
await self.ap.persistence_mgr.execute_async(
|
await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.update(persistence_rag.KnowledgeBase)
|
sqlalchemy.update(persistence_rag.KnowledgeBase)
|
||||||
.values(kb_data)
|
.values(filtered_data)
|
||||||
.where(persistence_rag.KnowledgeBase.uuid == kb_uuid)
|
.where(persistence_rag.KnowledgeBase.uuid == kb_uuid)
|
||||||
)
|
)
|
||||||
await self.ap.rag_mgr.remove_knowledge_base_from_runtime(kb_uuid)
|
await self.ap.rag_mgr.remove_knowledge_base_from_runtime(kb_uuid)
|
||||||
|
|
||||||
kb = await self.get_knowledge_base(kb_uuid)
|
kb = await self.get_knowledge_base(kb_uuid)
|
||||||
|
if kb is None:
|
||||||
|
raise Exception('Knowledge base not found after update')
|
||||||
|
|
||||||
await self.ap.rag_mgr.load_knowledge_base(kb)
|
await self.ap.rag_mgr.load_knowledge_base(kb)
|
||||||
|
|
||||||
async def store_file(self, kb_uuid: str, file_id: str) -> int:
|
async def _check_doc_capability(self, kb_uuid: str, operation: str) -> None:
|
||||||
|
"""Check if the KB's Knowledge Engine supports document operations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
kb_uuid: Knowledge base UUID.
|
||||||
|
operation: Human-readable operation name for error messages.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If the KB does not support doc_ingestion.
|
||||||
|
"""
|
||||||
|
kb_info = await self.ap.rag_mgr.get_knowledge_base_details(kb_uuid)
|
||||||
|
if not kb_info:
|
||||||
|
raise Exception('Knowledge base not found')
|
||||||
|
capabilities = kb_info.get('knowledge_engine', {}).get('capabilities', [])
|
||||||
|
if 'doc_ingestion' not in capabilities:
|
||||||
|
raise Exception(f'This knowledge base does not support {operation}')
|
||||||
|
|
||||||
|
async def store_file(self, kb_uuid: str, file_id: str, parser_plugin_id: str | None = None) -> str:
|
||||||
"""存储文件"""
|
"""存储文件"""
|
||||||
# await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.File).values(kb_id=kb_uuid, file_id=file_id))
|
|
||||||
# await self.ap.rag_mgr.store_file(file_id)
|
|
||||||
runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
|
runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
|
||||||
if runtime_kb is None:
|
if runtime_kb is None:
|
||||||
raise Exception('Knowledge base not found')
|
raise Exception('Knowledge base not found')
|
||||||
# Only internal KBs support file storage
|
|
||||||
if runtime_kb.get_type() != 'internal':
|
await self._check_doc_capability(kb_uuid, 'document upload')
|
||||||
raise Exception('Only internal knowledge bases support file storage')
|
|
||||||
result = await runtime_kb.store_file(file_id)
|
result = await runtime_kb.store_file(file_id, parser_plugin_id=parser_plugin_id)
|
||||||
|
|
||||||
# Update the KB's updated_at timestamp
|
# Update the KB's updated_at timestamp
|
||||||
await self.ap.persistence_mgr.execute_async(
|
await self.ap.persistence_mgr.execute_async(
|
||||||
@@ -85,14 +97,18 @@ class KnowledgeService:
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def retrieve_knowledge_base(self, kb_uuid: str, query: str) -> list[dict]:
|
async def retrieve_knowledge_base(
|
||||||
|
self, kb_uuid: str, query: str, retrieval_settings: dict | None = None
|
||||||
|
) -> list[dict]:
|
||||||
"""检索知识库"""
|
"""检索知识库"""
|
||||||
runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
|
runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
|
||||||
if runtime_kb is None:
|
if runtime_kb is None:
|
||||||
raise Exception('Knowledge base not found')
|
raise Exception('Knowledge base not found')
|
||||||
return [
|
|
||||||
result.model_dump() for result in await runtime_kb.retrieve(query, runtime_kb.knowledge_base_entity.top_k)
|
# Pass retrieval_settings
|
||||||
]
|
results = await runtime_kb.retrieve(query, settings=retrieval_settings)
|
||||||
|
|
||||||
|
return [result.model_dump() for result in results]
|
||||||
|
|
||||||
async def get_files_by_knowledge_base(self, kb_uuid: str) -> list[dict]:
|
async def get_files_by_knowledge_base(self, kb_uuid: str) -> list[dict]:
|
||||||
"""获取知识库文件"""
|
"""获取知识库文件"""
|
||||||
@@ -107,9 +123,9 @@ class KnowledgeService:
|
|||||||
runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
|
runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
|
||||||
if runtime_kb is None:
|
if runtime_kb is None:
|
||||||
raise Exception('Knowledge base not found')
|
raise Exception('Knowledge base not found')
|
||||||
# Only internal KBs support file deletion
|
|
||||||
if runtime_kb.get_type() != 'internal':
|
await self._check_doc_capability(kb_uuid, 'document deletion')
|
||||||
raise Exception('Only internal knowledge bases support file deletion')
|
|
||||||
await runtime_kb.delete_file(file_id)
|
await runtime_kb.delete_file(file_id)
|
||||||
|
|
||||||
# Update the KB's updated_at timestamp
|
# Update the KB's updated_at timestamp
|
||||||
@@ -121,13 +137,14 @@ class KnowledgeService:
|
|||||||
|
|
||||||
async def delete_knowledge_base(self, kb_uuid: str) -> None:
|
async def delete_knowledge_base(self, kb_uuid: str) -> None:
|
||||||
"""删除知识库"""
|
"""删除知识库"""
|
||||||
await self.ap.rag_mgr.delete_knowledge_base(kb_uuid)
|
# Delete from DB first to commit the deletion, then clean up runtime/plugin (best-effort)
|
||||||
|
|
||||||
await self.ap.persistence_mgr.execute_async(
|
await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.delete(persistence_rag.KnowledgeBase).where(persistence_rag.KnowledgeBase.uuid == kb_uuid)
|
sqlalchemy.delete(persistence_rag.KnowledgeBase).where(persistence_rag.KnowledgeBase.uuid == kb_uuid)
|
||||||
)
|
)
|
||||||
|
|
||||||
# delete files
|
# delete files
|
||||||
|
# NOTE: Chunk cleanup is for legacy (pre-plugin) KBs that stored chunks locally.
|
||||||
|
# For plugin-based Knowledge Engines, the Chunk table is not populated, so this is a no-op.
|
||||||
files = await self.ap.persistence_mgr.execute_async(
|
files = await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.select(persistence_rag.File).where(persistence_rag.File.kb_id == kb_uuid)
|
sqlalchemy.select(persistence_rag.File).where(persistence_rag.File.kb_id == kb_uuid)
|
||||||
)
|
)
|
||||||
@@ -140,3 +157,53 @@ class KnowledgeService:
|
|||||||
await self.ap.persistence_mgr.execute_async(
|
await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.delete(persistence_rag.File).where(persistence_rag.File.uuid == file.uuid)
|
sqlalchemy.delete(persistence_rag.File).where(persistence_rag.File.uuid == file.uuid)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Remove from runtime and notify plugin (best-effort, DB is already cleaned up)
|
||||||
|
await self.ap.rag_mgr.delete_knowledge_base(kb_uuid)
|
||||||
|
|
||||||
|
# ================= Knowledge Engine Discovery =================
|
||||||
|
|
||||||
|
async def list_knowledge_engines(self) -> list[dict]:
|
||||||
|
"""List all available Knowledge Engines from plugins."""
|
||||||
|
engines = []
|
||||||
|
|
||||||
|
if not self.ap.plugin_connector.is_enable_plugin:
|
||||||
|
return engines
|
||||||
|
|
||||||
|
# Get KnowledgeEngine plugins
|
||||||
|
try:
|
||||||
|
knowledge_engines = await self.ap.plugin_connector.list_knowledge_engines()
|
||||||
|
engines.extend(knowledge_engines)
|
||||||
|
except Exception as e:
|
||||||
|
self.ap.logger.warning(f'Failed to list Knowledge Engines from plugins: {e}')
|
||||||
|
|
||||||
|
return engines
|
||||||
|
|
||||||
|
async def list_parsers(self, mime_type: str | None = None) -> list[dict]:
|
||||||
|
"""List available parsers, optionally filtered by MIME type."""
|
||||||
|
if not self.ap.plugin_connector.is_enable_plugin:
|
||||||
|
return []
|
||||||
|
try:
|
||||||
|
parsers = await self.ap.plugin_connector.list_parsers()
|
||||||
|
if mime_type:
|
||||||
|
parsers = [p for p in parsers if mime_type in p.get('supported_mime_types', [])]
|
||||||
|
return parsers
|
||||||
|
except Exception as e:
|
||||||
|
self.ap.logger.warning(f'Failed to list parsers: {e}')
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_engine_creation_schema(self, plugin_id: str) -> dict:
|
||||||
|
"""Get creation settings schema for a specific Knowledge Engine."""
|
||||||
|
try:
|
||||||
|
return await self.ap.plugin_connector.get_rag_creation_schema(plugin_id)
|
||||||
|
except Exception as e:
|
||||||
|
self.ap.logger.warning(f'Failed to get creation schema for {plugin_id}: {e}')
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def get_engine_retrieval_schema(self, plugin_id: str) -> dict:
|
||||||
|
"""Get retrieval settings schema for a specific Knowledge Engine."""
|
||||||
|
try:
|
||||||
|
return await self.ap.plugin_connector.get_rag_retrieval_schema(plugin_id)
|
||||||
|
except Exception as e:
|
||||||
|
self.ap.logger.warning(f'Failed to get retrieval schema for {plugin_id}: {e}')
|
||||||
|
return {}
|
||||||
|
|||||||
@@ -105,11 +105,16 @@ class LLMModelsService:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
pipeline = result.first()
|
pipeline = result.first()
|
||||||
if pipeline is not None and pipeline.config['ai']['local-agent']['model'] == '':
|
if pipeline is not None:
|
||||||
pipeline_config = pipeline.config
|
model_config = pipeline.config.get('ai', {}).get('local-agent', {}).get('model', {})
|
||||||
pipeline_config['ai']['local-agent']['model'] = model_data['uuid']
|
if not model_config.get('primary', ''):
|
||||||
pipeline_data = {'config': pipeline_config}
|
pipeline_config = pipeline.config
|
||||||
await self.ap.pipeline_service.update_pipeline(pipeline.uuid, pipeline_data)
|
pipeline_config['ai']['local-agent']['model'] = {
|
||||||
|
'primary': model_data['uuid'],
|
||||||
|
'fallbacks': [],
|
||||||
|
}
|
||||||
|
pipeline_data = {'config': pipeline_config}
|
||||||
|
await self.ap.pipeline_service.update_pipeline(pipeline.uuid, pipeline_data)
|
||||||
|
|
||||||
return model_data['uuid']
|
return model_data['uuid']
|
||||||
|
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ class MonitoringService:
|
|||||||
level: str = 'info',
|
level: str = 'info',
|
||||||
platform: str | None = None,
|
platform: str | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
|
user_name: str | None = None,
|
||||||
runner_name: str | None = None,
|
runner_name: str | None = None,
|
||||||
variables: str | None = None,
|
variables: str | None = None,
|
||||||
role: str = 'user',
|
role: str = 'user',
|
||||||
@@ -49,6 +50,7 @@ class MonitoringService:
|
|||||||
'level': level,
|
'level': level,
|
||||||
'platform': platform,
|
'platform': platform,
|
||||||
'user_id': user_id,
|
'user_id': user_id,
|
||||||
|
'user_name': user_name,
|
||||||
'runner_name': runner_name,
|
'runner_name': runner_name,
|
||||||
'variables': variables,
|
'variables': variables,
|
||||||
'role': role,
|
'role': role,
|
||||||
@@ -152,6 +154,7 @@ class MonitoringService:
|
|||||||
pipeline_name: str,
|
pipeline_name: str,
|
||||||
platform: str | None = None,
|
platform: str | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
|
user_name: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Record a new session"""
|
"""Record a new session"""
|
||||||
session_data = {
|
session_data = {
|
||||||
@@ -166,6 +169,7 @@ class MonitoringService:
|
|||||||
'is_active': True,
|
'is_active': True,
|
||||||
'platform': platform,
|
'platform': platform,
|
||||||
'user_id': user_id,
|
'user_id': user_id,
|
||||||
|
'user_name': user_name,
|
||||||
}
|
}
|
||||||
|
|
||||||
await self.ap.persistence_mgr.execute_async(
|
await self.ap.persistence_mgr.execute_async(
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import aiohttp
|
from langbot.pkg.utils import httpclient
|
||||||
import typing
|
import typing
|
||||||
import datetime
|
import datetime
|
||||||
import time
|
import time
|
||||||
@@ -99,49 +99,49 @@ class SpaceService:
|
|||||||
space_config = self._get_space_config()
|
space_config = self._get_space_config()
|
||||||
space_url = space_config['url']
|
space_url = space_config['url']
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
session = httpclient.get_session()
|
||||||
async with session.post(
|
async with session.post(
|
||||||
f'{space_url}/api/v1/accounts/oauth/token',
|
f'{space_url}/api/v1/accounts/oauth/token',
|
||||||
json={'code': code, 'instance_id': constants.instance_id},
|
json={'code': code, 'instance_id': constants.instance_id},
|
||||||
) as response:
|
) as response:
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
raise ValueError(f'Failed to exchange OAuth code: {await response.text()}')
|
raise ValueError(f'Failed to exchange OAuth code: {await response.text()}')
|
||||||
data = await response.json()
|
data = await response.json()
|
||||||
if data.get('code') != 0:
|
if data.get('code') != 0:
|
||||||
raise ValueError(f'Failed to exchange OAuth code: {data.get("msg")}')
|
raise ValueError(f'Failed to exchange OAuth code: {data.get("msg")}')
|
||||||
return data.get('data', {})
|
return data.get('data', {})
|
||||||
|
|
||||||
async def refresh_token(self, refresh_token: str) -> typing.Dict:
|
async def refresh_token(self, refresh_token: str) -> typing.Dict:
|
||||||
"""Refresh Space access token"""
|
"""Refresh Space access token"""
|
||||||
space_config = self._get_space_config()
|
space_config = self._get_space_config()
|
||||||
space_url = space_config['url']
|
space_url = space_config['url']
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
session = httpclient.get_session()
|
||||||
async with session.post(
|
async with session.post(
|
||||||
f'{space_url}/api/v1/accounts/token/refresh', json={'refresh_token': refresh_token}
|
f'{space_url}/api/v1/accounts/token/refresh', json={'refresh_token': refresh_token}
|
||||||
) as response:
|
) as response:
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
raise ValueError(f'Failed to refresh token: {await response.text()}')
|
raise ValueError(f'Failed to refresh token: {await response.text()}')
|
||||||
data = await response.json()
|
data = await response.json()
|
||||||
if data.get('code') != 0:
|
if data.get('code') != 0:
|
||||||
raise ValueError(f'Failed to refresh token: {data.get("msg")}')
|
raise ValueError(f'Failed to refresh token: {data.get("msg")}')
|
||||||
return data.get('data', {})
|
return data.get('data', {})
|
||||||
|
|
||||||
async def get_user_info_raw(self, access_token: str) -> typing.Dict:
|
async def get_user_info_raw(self, access_token: str) -> typing.Dict:
|
||||||
"""Get user info from Space using access token (no validation)"""
|
"""Get user info from Space using access token (no validation)"""
|
||||||
space_config = self._get_space_config()
|
space_config = self._get_space_config()
|
||||||
space_url = space_config['url']
|
space_url = space_config['url']
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
session = httpclient.get_session()
|
||||||
async with session.get(
|
async with session.get(
|
||||||
f'{space_url}/api/v1/accounts/me', headers={'Authorization': f'Bearer {access_token}'}
|
f'{space_url}/api/v1/accounts/me', headers={'Authorization': f'Bearer {access_token}'}
|
||||||
) as response:
|
) as response:
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
raise ValueError(f'Failed to get user info: {await response.text()}')
|
raise ValueError(f'Failed to get user info: {await response.text()}')
|
||||||
data = await response.json()
|
data = await response.json()
|
||||||
if data.get('code') != 0:
|
if data.get('code') != 0:
|
||||||
raise ValueError(f'Failed to get user info: {data.get("msg")}')
|
raise ValueError(f'Failed to get user info: {data.get("msg")}')
|
||||||
return data.get('data', {})
|
return data.get('data', {})
|
||||||
|
|
||||||
# === API calls with token validation ===
|
# === API calls with token validation ===
|
||||||
|
|
||||||
@@ -178,12 +178,12 @@ class SpaceService:
|
|||||||
space_config = self._get_space_config()
|
space_config = self._get_space_config()
|
||||||
space_url = space_config['url']
|
space_url = space_config['url']
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
session = httpclient.get_session()
|
||||||
async with session.get(f'{space_url}/api/v1/models') as response:
|
async with session.get(f'{space_url}/api/v1/models') as response:
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
raise ValueError(f'Failed to get models: {await response.text()}')
|
raise ValueError(f'Failed to get models: {await response.text()}')
|
||||||
data = await response.json()
|
data = await response.json()
|
||||||
if data.get('code') != 0:
|
if data.get('code') != 0:
|
||||||
raise ValueError(f'Failed to get models: {data.get("msg")}')
|
raise ValueError(f'Failed to get models: {data.get("msg")}')
|
||||||
models_data = data.get('data', {}).get('models', [])
|
models_data = data.get('data', {}).get('models', [])
|
||||||
return [SpaceModel.model_validate(model_dict) for model_dict in models_data]
|
return [SpaceModel.model_validate(model_dict) for model_dict in models_data]
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from ..platform import botmgr as im_mgr
|
|||||||
from ..platform.webhook_pusher import WebhookPusher
|
from ..platform.webhook_pusher import WebhookPusher
|
||||||
from ..provider.session import sessionmgr as llm_session_mgr
|
from ..provider.session import sessionmgr as llm_session_mgr
|
||||||
from ..provider.modelmgr import modelmgr as llm_model_mgr
|
from ..provider.modelmgr import modelmgr as llm_model_mgr
|
||||||
|
|
||||||
from langbot.pkg.provider.tools import toolmgr as llm_tool_mgr
|
from langbot.pkg.provider.tools import toolmgr as llm_tool_mgr
|
||||||
from ..config import manager as config_mgr
|
from ..config import manager as config_mgr
|
||||||
from ..command import cmdmgr
|
from ..command import cmdmgr
|
||||||
@@ -29,14 +30,15 @@ from ..api.http.service import knowledge as knowledge_service
|
|||||||
from ..api.http.service import mcp as mcp_service
|
from ..api.http.service import mcp as mcp_service
|
||||||
from ..api.http.service import apikey as apikey_service
|
from ..api.http.service import apikey as apikey_service
|
||||||
from ..api.http.service import webhook as webhook_service
|
from ..api.http.service import webhook as webhook_service
|
||||||
from ..api.http.service import external_kb as external_kb_service
|
|
||||||
from ..api.http.service import monitoring as monitoring_service
|
from ..api.http.service import monitoring as monitoring_service
|
||||||
|
|
||||||
from ..discover import engine as discover_engine
|
from ..discover import engine as discover_engine
|
||||||
from ..storage import mgr as storagemgr
|
from ..storage import mgr as storagemgr
|
||||||
from ..utils import logcache
|
from ..utils import logcache
|
||||||
from . import taskmgr
|
from . import taskmgr
|
||||||
from . import entities as core_entities
|
from . import entities as core_entities
|
||||||
from ..rag.knowledge import kbmgr as rag_mgr
|
from ..rag.knowledge import kbmgr as rag_mgr
|
||||||
|
from ..rag.service import RAGRuntimeService
|
||||||
from ..vector import mgr as vectordb_mgr
|
from ..vector import mgr as vectordb_mgr
|
||||||
from ..telemetry import telemetry as telemetry_module
|
from ..telemetry import telemetry as telemetry_module
|
||||||
from ..survey import manager as survey_module
|
from ..survey import manager as survey_module
|
||||||
@@ -63,6 +65,7 @@ class Application:
|
|||||||
model_mgr: llm_model_mgr.ModelManager = None
|
model_mgr: llm_model_mgr.ModelManager = None
|
||||||
|
|
||||||
rag_mgr: rag_mgr.RAGManager = None
|
rag_mgr: rag_mgr.RAGManager = None
|
||||||
|
rag_runtime_service: RAGRuntimeService = None
|
||||||
|
|
||||||
# TODO move to pipeline
|
# TODO move to pipeline
|
||||||
tool_mgr: llm_tool_mgr.ToolManager = None
|
tool_mgr: llm_tool_mgr.ToolManager = None
|
||||||
@@ -138,8 +141,6 @@ class Application:
|
|||||||
|
|
||||||
knowledge_service: knowledge_service.KnowledgeService = None
|
knowledge_service: knowledge_service.KnowledgeService = None
|
||||||
|
|
||||||
external_kb_service: external_kb_service.ExternalKBService = None
|
|
||||||
|
|
||||||
mcp_service: mcp_service.MCPService = None
|
mcp_service: mcp_service.MCPService = None
|
||||||
|
|
||||||
apikey_service: apikey_service.ApiKeyService = None
|
apikey_service: apikey_service.ApiKeyService = None
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import importlib.util
|
||||||
import pip
|
import pip
|
||||||
import os
|
import os
|
||||||
from ...utils import pkgmgr
|
from ...utils import pkgmgr
|
||||||
@@ -49,9 +50,10 @@ async def check_deps() -> list[str]:
|
|||||||
|
|
||||||
missing_deps = []
|
missing_deps = []
|
||||||
for dep in required_deps:
|
for dep in required_deps:
|
||||||
try:
|
# Use find_spec instead of __import__ to avoid actually loading
|
||||||
__import__(dep)
|
# all modules into memory. find_spec only checks if the module
|
||||||
except ImportError:
|
# can be found, without executing module-level code.
|
||||||
|
if importlib.util.find_spec(dep) is None:
|
||||||
missing_deps.append(dep)
|
missing_deps.append(dep)
|
||||||
return missing_deps
|
return missing_deps
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from ...provider.session import sessionmgr as llm_session_mgr
|
|||||||
from ...provider.modelmgr import modelmgr as llm_model_mgr
|
from ...provider.modelmgr import modelmgr as llm_model_mgr
|
||||||
from ...provider.tools import toolmgr as llm_tool_mgr
|
from ...provider.tools import toolmgr as llm_tool_mgr
|
||||||
from ...rag.knowledge import kbmgr as rag_mgr
|
from ...rag.knowledge import kbmgr as rag_mgr
|
||||||
|
from ...rag.service import RAGRuntimeService
|
||||||
from ...platform import botmgr as im_mgr
|
from ...platform import botmgr as im_mgr
|
||||||
from ...platform.webhook_pusher import WebhookPusher
|
from ...platform.webhook_pusher import WebhookPusher
|
||||||
from ...persistence import mgr as persistencemgr
|
from ...persistence import mgr as persistencemgr
|
||||||
@@ -26,7 +27,6 @@ from ...api.http.service import knowledge as knowledge_service
|
|||||||
from ...api.http.service import mcp as mcp_service
|
from ...api.http.service import mcp as mcp_service
|
||||||
from ...api.http.service import apikey as apikey_service
|
from ...api.http.service import apikey as apikey_service
|
||||||
from ...api.http.service import webhook as webhook_service
|
from ...api.http.service import webhook as webhook_service
|
||||||
from ...api.http.service import external_kb as external_kb_service
|
|
||||||
from ...api.http.service import monitoring as monitoring_service
|
from ...api.http.service import monitoring as monitoring_service
|
||||||
from ...discover import engine as discover_engine
|
from ...discover import engine as discover_engine
|
||||||
from ...storage import mgr as storagemgr
|
from ...storage import mgr as storagemgr
|
||||||
@@ -73,9 +73,6 @@ class BuildAppStage(stage.BootingStage):
|
|||||||
knowledge_service_inst = knowledge_service.KnowledgeService(ap)
|
knowledge_service_inst = knowledge_service.KnowledgeService(ap)
|
||||||
ap.knowledge_service = knowledge_service_inst
|
ap.knowledge_service = knowledge_service_inst
|
||||||
|
|
||||||
external_kb_service_inst = external_kb_service.ExternalKBService(ap)
|
|
||||||
ap.external_kb_service = external_kb_service_inst
|
|
||||||
|
|
||||||
mcp_service_inst = mcp_service.MCPService(ap)
|
mcp_service_inst = mcp_service.MCPService(ap)
|
||||||
ap.mcp_service = mcp_service_inst
|
ap.mcp_service = mcp_service_inst
|
||||||
|
|
||||||
@@ -152,6 +149,9 @@ class BuildAppStage(stage.BootingStage):
|
|||||||
await rag_mgr_inst.initialize()
|
await rag_mgr_inst.initialize()
|
||||||
ap.rag_mgr = rag_mgr_inst
|
ap.rag_mgr = rag_mgr_inst
|
||||||
|
|
||||||
|
# Initialize RAG Runtime Service for plugins
|
||||||
|
ap.rag_runtime_service = RAGRuntimeService(ap)
|
||||||
|
|
||||||
# 初始化向量数据库管理器
|
# 初始化向量数据库管理器
|
||||||
vectordb_mgr_inst = vectordb_mgr.VectorDBManager(ap)
|
vectordb_mgr_inst = vectordb_mgr.VectorDBManager(ap)
|
||||||
await vectordb_mgr_inst.initialize()
|
await vectordb_mgr_inst.initialize()
|
||||||
|
|||||||
@@ -74,20 +74,26 @@ def _apply_env_overrides_to_config(cfg: dict) -> dict:
|
|||||||
current = cfg
|
current = cfg
|
||||||
|
|
||||||
for i, key in enumerate(keys):
|
for i, key in enumerate(keys):
|
||||||
if not isinstance(current, dict) or key not in current:
|
if not isinstance(current, dict):
|
||||||
break
|
break
|
||||||
|
|
||||||
if i == len(keys) - 1:
|
if i == len(keys) - 1:
|
||||||
# At the final key - check if it's a scalar value
|
# At the final key
|
||||||
if isinstance(current[key], (dict, list)):
|
if key in current:
|
||||||
# Skip dict and list types
|
if isinstance(current[key], (dict, list)):
|
||||||
pass
|
# Skip dict and list types
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# Valid scalar value - convert and set it
|
||||||
|
converted_value = convert_value(env_value, current[key])
|
||||||
|
current[key] = converted_value
|
||||||
else:
|
else:
|
||||||
# Valid scalar value - convert and set it
|
# Key doesn't exist yet - create it as string
|
||||||
converted_value = convert_value(env_value, current[key])
|
current[key] = env_value
|
||||||
current[key] = converted_value
|
|
||||||
else:
|
else:
|
||||||
# Navigate deeper
|
# Navigate deeper - create intermediate dict if needed
|
||||||
|
if key not in current:
|
||||||
|
current[key] = {}
|
||||||
current = current[key]
|
current = current[key]
|
||||||
|
|
||||||
return cfg
|
return cfg
|
||||||
@@ -146,16 +152,50 @@ class LoadConfigStage(stage.BootingStage):
|
|||||||
await ap.instance_config.dump_config()
|
await ap.instance_config.dump_config()
|
||||||
|
|
||||||
# load or generate instance id
|
# load or generate instance id
|
||||||
ap.instance_id = await config.load_json_config(
|
# Priority:
|
||||||
'data/labels/instance_id.json',
|
# 1. system.instance_id from config.yaml (can be set via SYSTEM__INSTANCE_ID env var)
|
||||||
template_data={
|
# 2. data/labels/instance_id.json (if file exists)
|
||||||
'instance_id': f'instance_{str(uuid.uuid4())}',
|
# 3. Generate new and save to file
|
||||||
'instance_create_ts': int(time.time()),
|
config_instance_id = ap.instance_config.data.get('system', {}).get('instance_id', '')
|
||||||
},
|
|
||||||
completion=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
constants.instance_id = ap.instance_id.data['instance_id']
|
if config_instance_id:
|
||||||
|
# Use the instance_id from config.yaml
|
||||||
|
constants.instance_id = config_instance_id
|
||||||
|
# Still load/create the file for backward compat, but don't use its value
|
||||||
|
ap.instance_id = await config.load_json_config(
|
||||||
|
'data/labels/instance_id.json',
|
||||||
|
template_data={
|
||||||
|
'instance_id': f'instance_{str(uuid.uuid4())}',
|
||||||
|
'instance_create_ts': int(time.time()),
|
||||||
|
},
|
||||||
|
completion=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Try loading file-based instance id
|
||||||
|
instance_id_path = os.path.join('data', 'labels', 'instance_id.json')
|
||||||
|
if os.path.exists(instance_id_path):
|
||||||
|
# File exists, read it
|
||||||
|
ap.instance_id = await config.load_json_config(
|
||||||
|
'data/labels/instance_id.json',
|
||||||
|
template_data={
|
||||||
|
'instance_id': '',
|
||||||
|
'instance_create_ts': 0,
|
||||||
|
},
|
||||||
|
completion=False,
|
||||||
|
)
|
||||||
|
constants.instance_id = ap.instance_id.data['instance_id']
|
||||||
|
else:
|
||||||
|
# Neither config nor file, generate new and save to file
|
||||||
|
new_id = f'instance_{str(uuid.uuid4())}'
|
||||||
|
ap.instance_id = await config.load_json_config(
|
||||||
|
'data/labels/instance_id.json',
|
||||||
|
template_data={
|
||||||
|
'instance_id': new_id,
|
||||||
|
'instance_create_ts': int(time.time()),
|
||||||
|
},
|
||||||
|
completion=False,
|
||||||
|
)
|
||||||
|
constants.instance_id = new_id
|
||||||
constants.edition = ap.instance_config.data.get('system', {}).get('edition', 'community')
|
constants.edition = ap.instance_config.data.get('system', {}).get('edition', 'community')
|
||||||
|
|
||||||
print(f'LangBot instance id: {constants.instance_id}')
|
print(f'LangBot instance id: {constants.instance_id}')
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ class MonitoringMessage(Base):
|
|||||||
level = sqlalchemy.Column(sqlalchemy.String(50), nullable=False) # info, warning, error, debug
|
level = sqlalchemy.Column(sqlalchemy.String(50), nullable=False) # info, warning, error, debug
|
||||||
platform = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
platform = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||||
user_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
user_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||||
|
user_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) # User display name
|
||||||
runner_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) # Runner name for this query
|
runner_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) # Runner name for this query
|
||||||
variables = sqlalchemy.Column(sqlalchemy.Text, nullable=True) # Query variables as JSON string
|
variables = sqlalchemy.Column(sqlalchemy.Text, nullable=True) # Query variables as JSON string
|
||||||
role = sqlalchemy.Column(sqlalchemy.String(50), nullable=True, default='user') # user, assistant
|
role = sqlalchemy.Column(sqlalchemy.String(50), nullable=True, default='user') # user, assistant
|
||||||
@@ -64,6 +65,7 @@ class MonitoringSession(Base):
|
|||||||
is_active = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=True, index=True)
|
is_active = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=True, index=True)
|
||||||
platform = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
platform = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||||
user_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
user_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||||
|
user_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) # User display name
|
||||||
|
|
||||||
|
|
||||||
class MonitoringError(Base):
|
class MonitoringError(Base):
|
||||||
|
|||||||
@@ -10,8 +10,21 @@ class KnowledgeBase(Base):
|
|||||||
emoji = sqlalchemy.Column(sqlalchemy.String(10), nullable=True, default='📚')
|
emoji = sqlalchemy.Column(sqlalchemy.String(10), nullable=True, default='📚')
|
||||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, default=sqlalchemy.func.now())
|
created_at = sqlalchemy.Column(sqlalchemy.DateTime, default=sqlalchemy.func.now())
|
||||||
updated_at = sqlalchemy.Column(sqlalchemy.DateTime, default=sqlalchemy.func.now(), onupdate=sqlalchemy.func.now())
|
updated_at = sqlalchemy.Column(sqlalchemy.DateTime, default=sqlalchemy.func.now(), onupdate=sqlalchemy.func.now())
|
||||||
embedding_model_uuid = sqlalchemy.Column(sqlalchemy.String, default='')
|
# New fields for plugin-based RAG
|
||||||
top_k = sqlalchemy.Column(sqlalchemy.Integer, default=5)
|
knowledge_engine_plugin_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
||||||
|
collection_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
||||||
|
creation_settings = sqlalchemy.Column(sqlalchemy.JSON, nullable=True, default=None)
|
||||||
|
retrieval_settings = sqlalchemy.Column(sqlalchemy.JSON, nullable=True, default=None)
|
||||||
|
|
||||||
|
# Field sets for different operations
|
||||||
|
MUTABLE_FIELDS = {'name', 'description', 'retrieval_settings'}
|
||||||
|
"""Fields that can be updated after creation."""
|
||||||
|
|
||||||
|
CREATE_FIELDS = MUTABLE_FIELDS | {'uuid', 'knowledge_engine_plugin_id', 'collection_id', 'creation_settings'}
|
||||||
|
"""Fields used when creating a new knowledge base."""
|
||||||
|
|
||||||
|
ALL_DB_FIELDS = CREATE_FIELDS | {'emoji', 'created_at', 'updated_at'}
|
||||||
|
"""All fields stored in database (for loading from DB row)."""
|
||||||
|
|
||||||
|
|
||||||
class File(Base):
|
class File(Base):
|
||||||
@@ -29,16 +42,3 @@ class Chunk(Base):
|
|||||||
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
|
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
|
||||||
file_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
file_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||||
text = sqlalchemy.Column(sqlalchemy.Text)
|
text = sqlalchemy.Column(sqlalchemy.Text)
|
||||||
|
|
||||||
|
|
||||||
class ExternalKnowledgeBase(Base):
|
|
||||||
__tablename__ = 'external_knowledge_bases'
|
|
||||||
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
|
|
||||||
name = sqlalchemy.Column(sqlalchemy.String, index=True)
|
|
||||||
description = sqlalchemy.Column(sqlalchemy.Text)
|
|
||||||
emoji = sqlalchemy.Column(sqlalchemy.String(10), nullable=True, default='🔗')
|
|
||||||
plugin_author = sqlalchemy.Column(sqlalchemy.String, nullable=False)
|
|
||||||
plugin_name = sqlalchemy.Column(sqlalchemy.String, nullable=False)
|
|
||||||
retriever_name = sqlalchemy.Column(sqlalchemy.String, nullable=False)
|
|
||||||
retriever_config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
|
|
||||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, default=sqlalchemy.func.now())
|
|
||||||
|
|||||||
@@ -0,0 +1,161 @@
|
|||||||
|
import sqlalchemy
|
||||||
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
|
@migration.migration_class(20)
class DBMigrateKnowledgeEnginePluginArchitecture(migration.DBMigration):
    """Migrate to unified Knowledge Engine plugin architecture.

    Changes:
    - Backup existing knowledge_bases data to knowledge_bases_backup
    - Clear knowledge_bases table and add new plugin architecture columns
    - Drop old columns (PostgreSQL only; SQLite leaves them unmapped)
    - Preserve external_knowledge_bases table as-is for future migration
    - Set rag_plugin_migration_needed flag in metadata if old data exists
    """

    async def upgrade(self):
        """Run the migration steps in order.

        The data checks run first because ``_clear_knowledge_bases`` deletes
        every row; the migration flag is written last and only when old
        internal or external knowledge-base data was found.
        """
        has_internal_data = await self._backup_knowledge_bases()
        has_external_data = await self._check_external_knowledge_bases()
        await self._clear_knowledge_bases()
        await self._add_columns_to_knowledge_bases()
        await self._drop_old_columns()
        if has_internal_data or has_external_data:
            await self._set_migration_flag()

    async def _get_table_columns(self, table_name: str) -> list[str]:
        """Return the column names of ``table_name`` (SQLite and PostgreSQL).

        Raises:
            ValueError: on the SQLite path, if ``table_name`` is not a plain
                identifier (it has to be interpolated into the PRAGMA).
        """
        if self.ap.persistence_mgr.db.name == 'postgresql':
            result = await self.ap.persistence_mgr.execute_async(
                sqlalchemy.text(
                    'SELECT column_name FROM information_schema.columns WHERE table_name = :table_name;'
                ).bindparams(table_name=table_name)
            )
            return [row[0] for row in result.fetchall()]

        # SQLite PRAGMA does not support bind parameters; validate identifier.
        if not table_name.isidentifier():
            raise ValueError(f'Invalid table name: {table_name}')
        result = await self.ap.persistence_mgr.execute_async(sqlalchemy.text(f'PRAGMA table_info({table_name});'))
        # Column 1 of each PRAGMA table_info row is the column name.
        return [row[1] for row in result.fetchall()]

    async def _table_exists(self, table_name: str) -> bool:
        """Return True if a table called ``table_name`` exists."""
        if self.ap.persistence_mgr.db.name == 'postgresql':
            result = await self.ap.persistence_mgr.execute_async(
                sqlalchemy.text(
                    'SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = :table_name);'
                ).bindparams(table_name=table_name)
            )
            # Normalise to a real bool so both dialect branches honour the annotation.
            return bool(result.scalar())

        result = await self.ap.persistence_mgr.execute_async(
            sqlalchemy.text("SELECT name FROM sqlite_master WHERE type='table' AND name=:table_name;").bindparams(
                table_name=table_name
            )
        )
        return result.first() is not None

    async def _backup_knowledge_bases(self) -> bool:
        """Copy knowledge_bases into knowledge_bases_backup.

        Returns:
            True if old knowledge-base data exists: either rows were backed up
            now, or the table is already empty but a backup table from an
            earlier, interrupted run of this migration is still present.
        """
        result = await self.ap.persistence_mgr.execute_async(sqlalchemy.text('SELECT COUNT(*) FROM knowledge_bases;'))
        count = result.scalar()
        backup_exists = await self._table_exists('knowledge_bases_backup')

        if not count:
            # Nothing to copy. If a backup table is already there, a previous run
            # got as far as clearing knowledge_bases before failing; keep that
            # backup untouched and still report old data so the migration flag
            # is written on this retry.
            if backup_exists:
                self.ap.logger.info('knowledge_bases is empty; keeping existing knowledge_bases_backup table.')
            return backup_exists

        # Drop backup table if it already exists (from a previous failed migration)
        if backup_exists:
            await self.ap.persistence_mgr.execute_async(sqlalchemy.text('DROP TABLE knowledge_bases_backup;'))

        await self.ap.persistence_mgr.execute_async(
            sqlalchemy.text('CREATE TABLE knowledge_bases_backup AS SELECT * FROM knowledge_bases;')
        )
        self.ap.logger.info(
            'Backed up %d knowledge base(s) to knowledge_bases_backup table.',
            count,
        )
        return True

    async def _check_external_knowledge_bases(self) -> bool:
        """Check if external_knowledge_bases table exists and has data.

        The table is preserved as-is (not dropped) for future migration.
        """
        if not await self._table_exists('external_knowledge_bases'):
            return False

        result = await self.ap.persistence_mgr.execute_async(
            sqlalchemy.text('SELECT COUNT(*) FROM external_knowledge_bases;')
        )
        count = result.scalar()
        if count > 0:
            self.ap.logger.info(
                'Found %d external knowledge base(s) in external_knowledge_bases table. '
                'Table preserved for future migration.',
                count,
            )
        return count > 0

    async def _clear_knowledge_bases(self):
        """Clear all rows from knowledge_bases table (preserve table structure)."""
        await self.ap.persistence_mgr.execute_async(sqlalchemy.text('DELETE FROM knowledge_bases;'))

    async def _add_columns_to_knowledge_bases(self):
        """Add new RAG plugin architecture columns to knowledge_bases table.

        Columns that already exist are skipped, so the step can be repeated.
        """
        columns = await self._get_table_columns('knowledge_bases')

        new_columns = {
            'knowledge_engine_plugin_id': 'VARCHAR',
            'collection_id': 'VARCHAR',
            'creation_settings': 'TEXT',  # JSON stored as TEXT for SQLite compatibility
            'retrieval_settings': 'TEXT',
        }

        for col_name, col_type in new_columns.items():
            if col_name not in columns:
                # Names and types come from the literal dict above, never from input.
                await self.ap.persistence_mgr.execute_async(
                    sqlalchemy.text(f'ALTER TABLE knowledge_bases ADD COLUMN {col_name} {col_type};')
                )

    async def _drop_old_columns(self):
        """Drop embedding_model_uuid and top_k columns (PostgreSQL only).

        SQLite does not support DROP COLUMN in older versions, so we leave the
        columns in place — the SQLAlchemy entity simply won't map them.
        """
        if self.ap.persistence_mgr.db.name != 'postgresql':
            return

        columns = await self._get_table_columns('knowledge_bases')

        for legacy_column in ('embedding_model_uuid', 'top_k'):
            if legacy_column in columns:
                await self.ap.persistence_mgr.execute_async(
                    sqlalchemy.text(f'ALTER TABLE knowledge_bases DROP COLUMN {legacy_column};')
                )

    async def _set_migration_flag(self):
        """Set rag_plugin_migration_needed flag in metadata table."""
        # Check if the key already exists
        result = await self.ap.persistence_mgr.execute_async(
            sqlalchemy.text("SELECT value FROM metadata WHERE key = 'rag_plugin_migration_needed';")
        )
        row = result.first()
        if row is not None:
            await self.ap.persistence_mgr.execute_async(
                sqlalchemy.text("UPDATE metadata SET value = 'true' WHERE key = 'rag_plugin_migration_needed';")
            )
        else:
            await self.ap.persistence_mgr.execute_async(
                sqlalchemy.text("INSERT INTO metadata (key, value) VALUES ('rag_plugin_migration_needed', 'true');")
            )
        self.ap.logger.info('Set rag_plugin_migration_needed=true in metadata.')

    async def downgrade(self):
        """Downgrade (no-op: this migration is not reverted)."""
        pass
|
||||||
@@ -0,0 +1,74 @@
|
|||||||
|
from .. import migration
|
||||||
|
|
||||||
|
import sqlalchemy
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
@migration.migration_class(21)
class DBMigrateMergeExceptionHandling(migration.DBMigration):
    """Merge hide-exception and block-failed-request-output into a single exception-handling select option,
    and add failure-hint field.

    Conversion logic:
    - block-failed-request-output=true -> exception-handling: hide
    - hide-exception=true -> exception-handling: show-hint
    - hide-exception=false -> exception-handling: show-error

    The conversion is idempotent: a pipeline that already has
    ``exception-handling`` (or ``failure-hint``) keeps its stored value.
    """

    async def upgrade(self):
        """Rewrite ``output.misc`` of every row in legacy_pipelines."""
        result = await self.ap.persistence_mgr.execute_async(
            sqlalchemy.text('SELECT uuid, config FROM legacy_pipelines')
        )
        pipelines = result.fetchall()

        current_version = self.ap.ver_mgr.get_current_version()

        # Only the cast on :config differs between the two dialects.
        if self.ap.persistence_mgr.db.name == 'postgresql':
            update_sql = 'UPDATE legacy_pipelines SET config = :config::jsonb, for_version = :for_version WHERE uuid = :uuid'
        else:
            update_sql = 'UPDATE legacy_pipelines SET config = :config, for_version = :for_version WHERE uuid = :uuid'

        for pipeline_row in pipelines:
            pipeline_uuid = pipeline_row[0]
            # The config column may come back as a JSON string or as an already-decoded object.
            config = json.loads(pipeline_row[1]) if isinstance(pipeline_row[1], str) else pipeline_row[1]

            misc = config.setdefault('output', {}).setdefault('misc', {})

            # Derive the new value from the legacy fields only when it has not been
            # set yet. On a re-run 'hide-exception' is already gone, and falling back
            # to its default would turn a migrated 'show-error' into 'show-hint'.
            if 'exception-handling' not in misc:
                hide_exception = misc.get('hide-exception', True)
                block_failed = misc.get('block-failed-request-output', False)

                if block_failed:
                    exception_handling = 'hide'
                elif hide_exception:
                    exception_handling = 'show-hint'
                else:
                    exception_handling = 'show-error'

                misc['exception-handling'] = exception_handling

            # Add failure-hint with default value, without clobbering an existing one
            misc.setdefault('failure-hint', 'Request failed.')

            # Remove legacy field
            # NOTE(review): 'block-failed-request-output' is left in the config, as in the
            # original migration — confirm whether anything still reads it before removing.
            misc.pop('hide-exception', None)

            await self.ap.persistence_mgr.execute_async(
                sqlalchemy.text(update_sql),
                {'config': json.dumps(config), 'for_version': current_version, 'uuid': pipeline_uuid},
            )

    async def downgrade(self):
        """Downgrade (no-op: the merged option is not split back)."""
        pass
|
||||||
@@ -0,0 +1,73 @@
|
|||||||
|
import sqlalchemy
|
||||||
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
|
@migration.migration_class(22)
class DBMigrateMonitoringUserId(migration.DBMigration):
    """Ensure the user identity columns exist on the monitoring tables.

    Adds ``user_id`` and ``user_name`` to ``monitoring_sessions`` and
    ``user_name`` to ``monitoring_messages``. Every addition is skipped when
    the column is already present, so the migration can be run repeatedly.

    NOTE(review): the original notes describe the ``user_name`` additions as a
    safety net "in case migration 21 failed or was skipped"; confirm which
    earlier migration was meant.
    """

    async def _table_exists(self, table_name: str) -> bool:
        """Report whether ``table_name`` exists (SQLite and PostgreSQL)."""
        run = self.ap.persistence_mgr.execute_async

        if self.ap.persistence_mgr.db.name != 'postgresql':
            lookup = sqlalchemy.text(
                "SELECT name FROM sqlite_master WHERE type='table' AND name=:table_name;"
            ).bindparams(table_name=table_name)
            found = await run(lookup)
            return found.first() is not None

        lookup = sqlalchemy.text(
            'SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = :table_name);'
        ).bindparams(table_name=table_name)
        found = await run(lookup)
        return bool(found.scalar())

    async def _get_table_columns(self, table_name: str) -> list[str]:
        """List the column names of ``table_name`` (SQLite and PostgreSQL).

        Raises:
            ValueError: on the SQLite path when ``table_name`` is not a plain
                identifier, since it is interpolated into the PRAGMA statement.
        """
        run = self.ap.persistence_mgr.execute_async

        if self.ap.persistence_mgr.db.name == 'postgresql':
            query = sqlalchemy.text(
                'SELECT column_name FROM information_schema.columns WHERE table_name = :table_name;'
            ).bindparams(table_name=table_name)
            name_index = 0
        else:
            if not table_name.isidentifier():
                raise ValueError(f'Invalid table name: {table_name}')
            query = sqlalchemy.text(f'PRAGMA table_info({table_name});')
            # PRAGMA table_info rows carry the column name in position 1.
            name_index = 1

        rows = await run(query)
        return [row[name_index] for row in rows.fetchall()]

    async def _add_column_if_not_exists(self, table_name: str, column_name: str, column_type: str):
        """Add ``column_name`` of ``column_type`` to ``table_name`` unless it is already there."""
        existing = await self._get_table_columns(table_name)

        if column_name not in existing:
            await self.ap.persistence_mgr.execute_async(
                sqlalchemy.text(f'ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type};')
            )
            self.ap.logger.info('Added %s column to %s table.', column_name, table_name)
            return

        self.ap.logger.debug('%s column already exists in %s.', column_name, table_name)

    async def upgrade(self):
        """Add the missing columns; do nothing if monitoring_sessions is absent."""
        sessions_present = await self._table_exists('monitoring_sessions')
        if not sessions_present:
            self.ap.logger.warning('monitoring_sessions table does not exist, skipping migration.')
            return

        # monitoring_sessions gets both identity columns, in this order.
        for session_column in ('user_id', 'user_name'):
            await self._add_column_if_not_exists('monitoring_sessions', session_column, 'VARCHAR(255)')

        # monitoring_messages is optional and only needs user_name.
        if await self._table_exists('monitoring_messages'):
            await self._add_column_if_not_exists('monitoring_messages', 'user_name', 'VARCHAR(255)')

    async def downgrade(self):
        """No-op: the added columns are left in place."""
        pass
|
||||||
@@ -0,0 +1,102 @@
|
|||||||
|
from .. import migration
|
||||||
|
|
||||||
|
import sqlalchemy
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
@migration.migration_class(23)
class DBMigrateModelFallbackConfig(migration.DBMigration):
    """Convert model field from plain UUID string to object with primary/fallbacks"""

    async def _write_pipeline_config(self, pipeline_uuid, config, for_version):
        """Persist ``config`` and ``for_version`` for one legacy_pipelines row.

        PostgreSQL gets an explicit ``::jsonb`` cast on the serialized config;
        every other dialect stores the JSON text as-is.
        """
        if self.ap.persistence_mgr.db.name == 'postgresql':
            statement = 'UPDATE legacy_pipelines SET config = :config::jsonb, for_version = :for_version WHERE uuid = :uuid'
        else:
            statement = 'UPDATE legacy_pipelines SET config = :config, for_version = :for_version WHERE uuid = :uuid'

        await self.ap.persistence_mgr.execute_async(
            sqlalchemy.text(statement),
            {'config': json.dumps(config), 'for_version': for_version, 'uuid': pipeline_uuid},
        )

    async def upgrade(self):
        """Wrap each string ``ai.local-agent.model`` into ``{'primary': ..., 'fallbacks': []}``.

        A leftover ``fallback-models`` key is removed as well. Rows without an
        ``ai.local-agent`` section, and rows needing neither change, are not written.
        """
        rows = (
            await self.ap.persistence_mgr.execute_async(sqlalchemy.text('SELECT uuid, config FROM legacy_pipelines'))
        ).fetchall()
        target_version = self.ap.ver_mgr.get_current_version()

        for row in rows:
            pipeline_uuid, raw_config = row[0], row[1]
            # The column may arrive as JSON text or as an already-decoded object.
            config = json.loads(raw_config) if isinstance(raw_config, str) else raw_config

            if not ('ai' in config and 'local-agent' in config['ai']):
                continue

            agent_cfg = config['ai']['local-agent']

            # A missing model is treated like an empty string and wrapped too.
            current_model = agent_cfg.get('model', '')
            needs_wrapping = isinstance(current_model, str)
            if needs_wrapping:
                agent_cfg['model'] = {'primary': current_model, 'fallbacks': []}

            has_leftover = 'fallback-models' in agent_cfg
            if has_leftover:
                del agent_cfg['fallback-models']

            if needs_wrapping or has_leftover:
                await self._write_pipeline_config(pipeline_uuid, config, target_version)

    async def downgrade(self):
        """Collapse each object-valued model back to its ``primary`` string.

        Rows whose model is not a dict are left untouched; the fallbacks list
        is discarded.
        """
        rows = (
            await self.ap.persistence_mgr.execute_async(sqlalchemy.text('SELECT uuid, config FROM legacy_pipelines'))
        ).fetchall()
        target_version = self.ap.ver_mgr.get_current_version()

        for row in rows:
            pipeline_uuid, raw_config = row[0], row[1]
            config = json.loads(raw_config) if isinstance(raw_config, str) else raw_config

            if not ('ai' in config and 'local-agent' in config['ai']):
                continue

            agent_cfg = config['ai']['local-agent']
            current_model = agent_cfg.get('model', '')
            if not isinstance(current_model, dict):
                continue

            agent_cfg['model'] = current_model.get('primary', '')
            await self._write_pipeline_config(pipeline_uuid, config, target_version)
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
from .. import migration
|
||||||
|
|
||||||
|
import sqlalchemy
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
@migration.migration_class(24)
class DBMigrateWecomBotWebSocketMode(migration.DBMigration):
    """Add the enable-webhook field to existing wecombot adapter configs.

    A bot whose config carries non-empty ``Token``, ``EncodingAESKey`` and
    ``Corpid`` values gets ``enable-webhook=True``; any other bot gets
    ``enable-webhook=False``. Configs that already contain the key are left
    untouched.
    """

    async def upgrade(self):
        """Fill in ``enable-webhook`` for every wecombot row in ``bots``."""
        selected = await self.ap.persistence_mgr.execute_async(
            sqlalchemy.text("SELECT uuid, adapter_config FROM bots WHERE adapter = 'wecombot'")
        )

        for row in selected.fetchall():
            bot_uuid, raw_config = row[0], row[1]
            # The column may arrive as JSON text or as an already-decoded object.
            adapter_config = json.loads(raw_config) if isinstance(raw_config, str) else raw_config

            if 'enable-webhook' in adapter_config:
                continue

            # Webhook mode is kept only when all three webhook credentials are set.
            adapter_config['enable-webhook'] = all(
                adapter_config.get(field) for field in ('Token', 'EncodingAESKey', 'Corpid')
            )

            # PostgreSQL needs the explicit ::jsonb cast on the serialized config.
            if self.ap.persistence_mgr.db.name == 'postgresql':
                statement = 'UPDATE bots SET adapter_config = :config::jsonb WHERE uuid = :uuid'
            else:
                statement = 'UPDATE bots SET adapter_config = :config WHERE uuid = :uuid'

            await self.ap.persistence_mgr.execute_async(
                sqlalchemy.text(statement),
                {'config': json.dumps(adapter_config), 'uuid': bot_uuid},
            )

    async def downgrade(self):
        """No-op: the added key is left in place."""
        pass
|
||||||
@@ -1,10 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
|
|
||||||
from .. import entities
|
from .. import entities
|
||||||
from .. import filter as filter_model
|
from .. import filter as filter_model
|
||||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
from langbot.pkg.utils import httpclient
|
||||||
|
|
||||||
BAIDU_EXAMINE_URL = 'https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}'
|
BAIDU_EXAMINE_URL = 'https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}'
|
||||||
BAIDU_EXAMINE_TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token'
|
BAIDU_EXAMINE_TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token'
|
||||||
@@ -15,50 +14,50 @@ class BaiduCloudExamine(filter_model.ContentFilter):
|
|||||||
"""百度云内容审核"""
|
"""百度云内容审核"""
|
||||||
|
|
||||||
async def _get_token(self) -> str:
    """Request an access token from the Baidu OAuth endpoint.

    The API key and secret are read from the ``baidu-cloud-examine`` section
    of the pipeline config and sent as a client-credentials grant.

    Returns:
        The ``access_token`` field of the JSON response.
    """
    http = httpclient.get_session()
    credentials = {
        'grant_type': 'client_credentials',
        'client_id': self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-key'],
        'client_secret': self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-secret'],
    }
    async with http.post(BAIDU_EXAMINE_TOKEN_URL, params=credentials) as resp:
        payload = await resp.json()
        return payload['access_token']
|
||||||
|
|
||||||
async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
    """Check ``message`` with the Baidu Cloud text-censor API.

    Args:
        query: The pipeline query being processed (not read here).
        message: The text to examine.

    Returns:
        A PASS result when Baidu's conclusion is exactly '合规' (compliant).
        Otherwise a BLOCK result: with an empty user notice when the API
        reported an error, or with a notice asking the user to revise the
        message for any other conclusion.
    """
    from urllib.parse import urlencode

    session = httpclient.get_session()
    async with session.post(
        BAIDU_EXAMINE_URL.format(await self._get_token()),
        headers={
            'Content-Type': 'application/x-www-form-urlencoded',
            'Accept': 'application/json',
        },
        # The body is declared as form-urlencoded, so the text must be escaped.
        # Unescaped '&', '%' or '+' would truncate or alter the text that gets
        # examined, letting content slip past the filter.
        data=urlencode({'text': message}).encode('utf-8'),
    ) as resp:
        result = await resp.json()

    if 'error_code' in result:
        # The check itself failed: block, but show nothing to the user.
        return entities.FilterResult(
            level=entities.ResultLevel.BLOCK,
            replacement=message,
            user_notice='',
            console_notice=f'百度云判定出错,错误信息:{result["error_msg"]}',
        )

    conclusion = result['conclusion']

    # Exact comparison: `conclusion in ('合规')` was a substring test against a
    # plain string (no tuple), so '', '合' and '规' were all treated as compliant.
    if conclusion == '合规':
        return entities.FilterResult(
            level=entities.ResultLevel.PASS,
            replacement=message,
            user_notice='',
            console_notice=f'百度云判定结果:{conclusion}',
        )

    return entities.FilterResult(
        level=entities.ResultLevel.BLOCK,
        replacement=message,
        user_notice='消息中存在不合适的内容, 请修改',
        console_notice=f'百度云判定结果:{conclusion}',
    )
|
|
||||||
|
|||||||
105
src/langbot/pkg/pipeline/config_coercion.py
Normal file
105
src/langbot/pkg/pipeline/config_coercion.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Lookup table: metadata type name -> callable that converts a raw value.
# 'boolean' is not listed here; _coerce_value routes it to _coerce_bool.
_COERCE_MAP = {
    'integer': int,
    'number': float,
    'float': float,
}
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_bool(v):
    """Convert *v* to a bool.

    Args:
        v: The raw value. Bools are returned unchanged. Strings are matched
            case-insensitively, ignoring surrounding whitespace, against
            'true'/'1'/'yes'/'on' and 'false'/'0'/'no'/'off'. Any other
            type is converted with ``bool(v)``.

    Returns:
        The boolean interpretation of *v*.

    Raises:
        ValueError: If *v* is a string that is not one of the recognised
            spellings.
    """
    if isinstance(v, bool):
        return v
    if isinstance(v, str):
        normalized = v.strip().lower()
        # A string like '0' or 'no' that is left unconverted would be truthy
        # wherever the config value is later tested, so recognise the common
        # spellings instead of rejecting them.
        if normalized in ('true', '1', 'yes', 'on'):
            return True
        if normalized in ('false', '0', 'no', 'off'):
            return False
        raise ValueError(f'Cannot convert string {v!r} to bool')
    return bool(v)
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_value(value, expected_type: str):
    """Convert a single value to the expected metadata type.

    Args:
        value: The raw config value. ``None`` is returned untouched.
        expected_type: Metadata type name. 'boolean' is delegated to
            ``_coerce_bool``; other names are looked up in ``_COERCE_MAP``.
            A name with no registered converter leaves the value unchanged.

    Returns:
        The converted value, or the original value if no conversion is needed.

    Raises:
        ValueError: If the value cannot be converted, including numeric
            values that overflow the target type (e.g. an infinite float
            requested as 'integer').
        TypeError: If the converter does not accept the value's type.
    """
    if value is None:
        return value

    if expected_type == 'boolean':
        # _coerce_bool already returns bool inputs unchanged.
        return _coerce_bool(value)

    coerce_fn = _COERCE_MAP.get(expected_type)
    if coerce_fn is None:
        return value

    # bool is a subclass of int; exclude it so True/False go through the
    # converter (-> 1 / 1.0) instead of being treated as "already numeric".
    is_bool = isinstance(value, bool)

    # Already the correct type
    if expected_type == 'integer' and isinstance(value, int) and not is_bool:
        return value

    try:
        if expected_type in ('number', 'float') and isinstance(value, (int, float)) and not is_bool:
            return float(value)
        return coerce_fn(value)
    except OverflowError as exc:
        # int(float('inf')) and float(<huge int>) raise OverflowError, which
        # is not a ValueError; normalise it so a handler written for
        # ValueError/TypeError sees the failure.
        raise ValueError(f'Cannot convert {value!r} to {expected_type}') from exc
|
||||||
|
|
||||||
|
|
||||||
|
def coerce_pipeline_config(
    config: dict,
    *metadata_list: dict,
) -> None:
    """Coerce pipeline config values according to metadata type definitions.

    Walks each metadata dict (trigger, safety, ai, output) and converts
    config values in-place so that strings coming from the JSON column are
    cast to their declared types (integer, number/float, boolean).

    A value that cannot be converted is left unchanged and a warning is
    logged; the function itself does not raise for bad values.

    Args:
        config: The pipeline config dict to modify in-place.
        *metadata_list: Metadata dicts loaded from the YAML templates. Each
            is expected to carry a 'name' and a 'stages' list whose entries
            have a 'name' and a 'config' list of field definitions
            ('name' + 'type').
    """
    for meta in metadata_list:
        section_name = meta.get('name')
        if not section_name or section_name not in config:
            continue

        section = config[section_name]
        if not isinstance(section, dict):
            continue

        # `or []` rather than a .get default: a key that is present but null
        # (an empty `stages:` entry loads as None) would otherwise be iterated.
        for stage_def in meta.get('stages') or []:
            stage_name = stage_def.get('name')
            if not stage_name or stage_name not in section:
                continue

            stage_config = section[stage_name]
            if not isinstance(stage_config, dict):
                continue

            for field_def in stage_def.get('config') or []:
                field_name = field_def.get('name')
                field_type = field_def.get('type')
                if not field_name or not field_type or field_name not in stage_config:
                    continue

                old_value = stage_config[field_name]
                try:
                    new_value = _coerce_value(old_value, field_type)
                    # Only write back when the conversion produced a new object.
                    if new_value is not old_value:
                        stage_config[field_name] = new_value
                except (ValueError, TypeError, OverflowError) as e:
                    # OverflowError covers numeric conversions such as
                    # int(float('inf')); treat it like any other bad value.
                    logger.warning(
                        'Failed to coerce config %s.%s.%s (%r) to %s: %s',
                        section_name,
                        stage_name,
                        field_name,
                        old_value,
                        field_type,
                        e,
                    )
|
||||||
@@ -34,6 +34,15 @@ class MonitoringHelper:
|
|||||||
# Check if session exists, if not, record session start
|
# Check if session exists, if not, record session start
|
||||||
session_id = f'{query.launcher_type}_{query.launcher_id}'
|
session_id = f'{query.launcher_type}_{query.launcher_id}'
|
||||||
|
|
||||||
|
# Get sender name from message event
|
||||||
|
sender_name = None
|
||||||
|
if hasattr(query, 'message_event'):
|
||||||
|
if hasattr(query.message_event, 'sender'):
|
||||||
|
if hasattr(query.message_event.sender, 'nickname'):
|
||||||
|
sender_name = query.message_event.sender.nickname
|
||||||
|
elif hasattr(query.message_event.sender, 'member_name'):
|
||||||
|
sender_name = query.message_event.sender.member_name
|
||||||
|
|
||||||
# Try to record message
|
# Try to record message
|
||||||
# Use JSON serialization to preserve message chain structure (including image URLs, etc.)
|
# Use JSON serialization to preserve message chain structure (including image URLs, etc.)
|
||||||
if hasattr(query, 'message_chain') and hasattr(query.message_chain, 'model_dump'):
|
if hasattr(query, 'message_chain') and hasattr(query.message_chain, 'model_dump'):
|
||||||
@@ -57,6 +66,7 @@ class MonitoringHelper:
|
|||||||
if hasattr(query.launcher_type, 'value')
|
if hasattr(query.launcher_type, 'value')
|
||||||
else str(query.launcher_type),
|
else str(query.launcher_type),
|
||||||
user_id=query.sender_id,
|
user_id=query.sender_id,
|
||||||
|
user_name=sender_name,
|
||||||
runner_name=runner_name,
|
runner_name=runner_name,
|
||||||
variables=None, # Will be updated in record_query_success
|
variables=None, # Will be updated in record_query_success
|
||||||
)
|
)
|
||||||
@@ -80,6 +90,7 @@ class MonitoringHelper:
|
|||||||
if hasattr(query.launcher_type, 'value')
|
if hasattr(query.launcher_type, 'value')
|
||||||
else str(query.launcher_type),
|
else str(query.launcher_type),
|
||||||
user_id=query.sender_id,
|
user_id=query.sender_id,
|
||||||
|
user_name=sender_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
return message_id
|
return message_id
|
||||||
@@ -128,6 +139,15 @@ class MonitoringHelper:
|
|||||||
try:
|
try:
|
||||||
session_id = f'{query.launcher_type}_{query.launcher_id}'
|
session_id = f'{query.launcher_type}_{query.launcher_id}'
|
||||||
|
|
||||||
|
# Get sender name from message event
|
||||||
|
sender_name = None
|
||||||
|
if hasattr(query, 'message_event'):
|
||||||
|
if hasattr(query.message_event, 'sender'):
|
||||||
|
if hasattr(query.message_event.sender, 'nickname'):
|
||||||
|
sender_name = query.message_event.sender.nickname
|
||||||
|
elif hasattr(query.message_event.sender, 'member_name'):
|
||||||
|
sender_name = query.message_event.sender.member_name
|
||||||
|
|
||||||
# Extract response content from resp_message_chain
|
# Extract response content from resp_message_chain
|
||||||
if hasattr(query, 'resp_message_chain') and query.resp_message_chain:
|
if hasattr(query, 'resp_message_chain') and query.resp_message_chain:
|
||||||
# Serialize the last response message chain
|
# Serialize the last response message chain
|
||||||
@@ -162,6 +182,7 @@ class MonitoringHelper:
|
|||||||
if hasattr(query.launcher_type, 'value')
|
if hasattr(query.launcher_type, 'value')
|
||||||
else str(query.launcher_type),
|
else str(query.launcher_type),
|
||||||
user_id=query.sender_id,
|
user_id=query.sender_id,
|
||||||
|
user_name=sender_name,
|
||||||
runner_name=runner_name,
|
runner_name=runner_name,
|
||||||
role='assistant',
|
role='assistant',
|
||||||
)
|
)
|
||||||
@@ -183,6 +204,15 @@ class MonitoringHelper:
|
|||||||
try:
|
try:
|
||||||
session_id = f'{query.launcher_type}_{query.launcher_id}'
|
session_id = f'{query.launcher_type}_{query.launcher_id}'
|
||||||
|
|
||||||
|
# Get sender name from message event
|
||||||
|
sender_name = None
|
||||||
|
if hasattr(query, 'message_event'):
|
||||||
|
if hasattr(query.message_event, 'sender'):
|
||||||
|
if hasattr(query.message_event.sender, 'nickname'):
|
||||||
|
sender_name = query.message_event.sender.nickname
|
||||||
|
elif hasattr(query.message_event.sender, 'member_name'):
|
||||||
|
sender_name = query.message_event.sender.member_name
|
||||||
|
|
||||||
# Record error message
|
# Record error message
|
||||||
message_id = await ap.monitoring_service.record_message(
|
message_id = await ap.monitoring_service.record_message(
|
||||||
bot_id=bot_id,
|
bot_id=bot_id,
|
||||||
@@ -197,6 +227,7 @@ class MonitoringHelper:
|
|||||||
if hasattr(query.launcher_type, 'value')
|
if hasattr(query.launcher_type, 'value')
|
||||||
else str(query.launcher_type),
|
else str(query.launcher_type),
|
||||||
user_id=query.sender_id,
|
user_id=query.sender_id,
|
||||||
|
user_name=sender_name,
|
||||||
runner_name=runner_name,
|
runner_name=runner_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
|||||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||||
import langbot_plugin.api.entities.events as events
|
import langbot_plugin.api.entities.events as events
|
||||||
from ..utils import importutil
|
from ..utils import importutil
|
||||||
|
from .config_coercion import coerce_pipeline_config
|
||||||
|
|
||||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
@@ -420,6 +421,14 @@ class PipelineManager:
|
|||||||
elif isinstance(pipeline_entity, dict):
|
elif isinstance(pipeline_entity, dict):
|
||||||
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity)
|
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity)
|
||||||
|
|
||||||
|
coerce_pipeline_config(
|
||||||
|
pipeline_entity.config,
|
||||||
|
getattr(self.ap, 'pipeline_config_meta_trigger', {'name': 'trigger', 'stages': []}),
|
||||||
|
getattr(self.ap, 'pipeline_config_meta_safety', {'name': 'safety', 'stages': []}),
|
||||||
|
getattr(self.ap, 'pipeline_config_meta_ai', {'name': 'ai', 'stages': []}),
|
||||||
|
getattr(self.ap, 'pipeline_config_meta_output', {'name': 'output', 'stages': []}),
|
||||||
|
)
|
||||||
|
|
||||||
# initialize stage containers according to pipeline_entity.stages
|
# initialize stage containers according to pipeline_entity.stages
|
||||||
stage_containers: list[StageInstContainer] = []
|
stage_containers: list[StageInstContainer] = []
|
||||||
for stage_name in pipeline_entity.stages:
|
for stage_name in pipeline_entity.stages:
|
||||||
|
|||||||
@@ -36,17 +36,36 @@ class PreProcessor(stage.PipelineStage):
|
|||||||
session = await self.ap.sess_mgr.get_session(query)
|
session = await self.ap.sess_mgr.get_session(query)
|
||||||
|
|
||||||
# When not local-agent, llm_model is None
|
# When not local-agent, llm_model is None
|
||||||
try:
|
llm_model = None
|
||||||
llm_model = (
|
if selected_runner == 'local-agent':
|
||||||
await self.ap.model_mgr.get_model_by_uuid(query.pipeline_config['ai']['local-agent']['model'])
|
# Read model config — new format is { primary: str, fallbacks: [str] },
|
||||||
if selected_runner == 'local-agent'
|
# but handle legacy plain string for backward compatibility
|
||||||
else None
|
model_config = query.pipeline_config['ai']['local-agent'].get('model', {})
|
||||||
)
|
if isinstance(model_config, str):
|
||||||
except ValueError:
|
# Legacy format: plain UUID string
|
||||||
self.ap.logger.warning(
|
primary_uuid = model_config
|
||||||
f'LLM model {query.pipeline_config["ai"]["local-agent"]["model"] + " "}not found or not configured'
|
fallback_uuids = []
|
||||||
)
|
else:
|
||||||
llm_model = None
|
primary_uuid = model_config.get('primary', '')
|
||||||
|
fallback_uuids = model_config.get('fallbacks', [])
|
||||||
|
|
||||||
|
if primary_uuid:
|
||||||
|
try:
|
||||||
|
llm_model = await self.ap.model_mgr.get_model_by_uuid(primary_uuid)
|
||||||
|
except ValueError:
|
||||||
|
self.ap.logger.warning(f'LLM model {primary_uuid} not found or not configured')
|
||||||
|
|
||||||
|
# Resolve fallback model UUIDs
|
||||||
|
if fallback_uuids:
|
||||||
|
valid_fallbacks = []
|
||||||
|
for fb_uuid in fallback_uuids:
|
||||||
|
try:
|
||||||
|
await self.ap.model_mgr.get_model_by_uuid(fb_uuid)
|
||||||
|
valid_fallbacks.append(fb_uuid)
|
||||||
|
except ValueError:
|
||||||
|
self.ap.logger.warning(f'Fallback model {fb_uuid} not found, skipping')
|
||||||
|
if valid_fallbacks:
|
||||||
|
query.variables['_fallback_model_uuids'] = valid_fallbacks
|
||||||
|
|
||||||
conversation = await self.ap.sess_mgr.get_conversation(
|
conversation = await self.ap.sess_mgr.get_conversation(
|
||||||
query,
|
query,
|
||||||
@@ -61,20 +80,28 @@ class PreProcessor(stage.PipelineStage):
|
|||||||
query.prompt = conversation.prompt.copy()
|
query.prompt = conversation.prompt.copy()
|
||||||
query.messages = conversation.messages.copy()
|
query.messages = conversation.messages.copy()
|
||||||
|
|
||||||
if selected_runner == 'local-agent' and llm_model:
|
if selected_runner == 'local-agent':
|
||||||
query.use_funcs = []
|
query.use_funcs = []
|
||||||
query.use_llm_model_uuid = llm_model.model_entity.uuid
|
if llm_model:
|
||||||
|
query.use_llm_model_uuid = llm_model.model_entity.uuid
|
||||||
|
|
||||||
if llm_model.model_entity.abilities.__contains__('func_call'):
|
if llm_model.model_entity.abilities.__contains__('func_call'):
|
||||||
# Get bound plugins and MCP servers for filtering tools
|
# Get bound plugins and MCP servers for filtering tools
|
||||||
|
bound_plugins = query.variables.get('_pipeline_bound_plugins', None)
|
||||||
|
bound_mcp_servers = query.variables.get('_pipeline_bound_mcp_servers', None)
|
||||||
|
query.use_funcs = await self.ap.tool_mgr.get_all_tools(bound_plugins, bound_mcp_servers)
|
||||||
|
|
||||||
|
self.ap.logger.debug(f'Bound plugins: {bound_plugins}')
|
||||||
|
self.ap.logger.debug(f'Bound MCP servers: {bound_mcp_servers}')
|
||||||
|
self.ap.logger.debug(f'Use funcs: {query.use_funcs}')
|
||||||
|
|
||||||
|
# If primary model doesn't support func_call but fallback models exist,
|
||||||
|
# load tools anyway since fallback models may support them
|
||||||
|
if not query.use_funcs and query.variables.get('_fallback_model_uuids'):
|
||||||
bound_plugins = query.variables.get('_pipeline_bound_plugins', None)
|
bound_plugins = query.variables.get('_pipeline_bound_plugins', None)
|
||||||
bound_mcp_servers = query.variables.get('_pipeline_bound_mcp_servers', None)
|
bound_mcp_servers = query.variables.get('_pipeline_bound_mcp_servers', None)
|
||||||
query.use_funcs = await self.ap.tool_mgr.get_all_tools(bound_plugins, bound_mcp_servers)
|
query.use_funcs = await self.ap.tool_mgr.get_all_tools(bound_plugins, bound_mcp_servers)
|
||||||
|
|
||||||
self.ap.logger.debug(f'Bound plugins: {bound_plugins}')
|
|
||||||
self.ap.logger.debug(f'Bound MCP servers: {bound_mcp_servers}')
|
|
||||||
self.ap.logger.debug(f'Use funcs: {query.use_funcs}')
|
|
||||||
|
|
||||||
sender_name = ''
|
sender_name = ''
|
||||||
|
|
||||||
if isinstance(query.message_event, platform_events.GroupMessage):
|
if isinstance(query.message_event, platform_events.GroupMessage):
|
||||||
@@ -149,6 +176,16 @@ class PreProcessor(stage.PipelineStage):
|
|||||||
query.variables['user_message_text'] = plain_text
|
query.variables['user_message_text'] = plain_text
|
||||||
|
|
||||||
query.user_message = provider_message.Message(role='user', content=content_list)
|
query.user_message = provider_message.Message(role='user', content=content_list)
|
||||||
|
|
||||||
|
# Extract knowledge base UUIDs into query variables so plugins can modify them
|
||||||
|
# during PromptPreProcessing before the runner performs retrieval.
|
||||||
|
kb_uuids = query.pipeline_config['ai']['local-agent'].get('knowledge-bases', [])
|
||||||
|
if not kb_uuids:
|
||||||
|
old_kb_uuid = query.pipeline_config['ai']['local-agent'].get('knowledge-base', '')
|
||||||
|
if old_kb_uuid and old_kb_uuid != '__none__':
|
||||||
|
kb_uuids = [old_kb_uuid]
|
||||||
|
query.variables['_knowledge_base_uuids'] = list(kb_uuids)
|
||||||
|
|
||||||
# =========== 触发事件 PromptPreProcessing
|
# =========== 触发事件 PromptPreProcessing
|
||||||
|
|
||||||
event = events.PromptPreProcessing(
|
event = events.PromptPreProcessing(
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from ... import entities
|
|||||||
from ....provider import runner as runner_module
|
from ....provider import runner as runner_module
|
||||||
|
|
||||||
import langbot_plugin.api.entities.events as events
|
import langbot_plugin.api.entities.events as events
|
||||||
from ....utils import importutil, constants
|
from ....utils import importutil, constants, runner as runner_utils
|
||||||
from ....provider import runners
|
from ....provider import runners
|
||||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
@@ -149,12 +149,19 @@ class ChatMessageHandler(handler.MessageHandler):
|
|||||||
self.ap.logger.error(f'Conversation({query.query_id}) Request Failed: {error_info}')
|
self.ap.logger.error(f'Conversation({query.query_id}) Request Failed: {error_info}')
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
hide_exception_info = query.pipeline_config['output']['misc']['hide-exception']
|
exception_handling = query.pipeline_config['output']['misc'].get('exception-handling', 'show-hint')
|
||||||
|
|
||||||
|
if exception_handling == 'show-error':
|
||||||
|
user_notice = f'{e}'
|
||||||
|
elif exception_handling == 'show-hint':
|
||||||
|
user_notice = query.pipeline_config['output']['misc'].get('failure-hint', 'Request failed.')
|
||||||
|
else: # hide
|
||||||
|
user_notice = None
|
||||||
|
|
||||||
yield entities.StageProcessResult(
|
yield entities.StageProcessResult(
|
||||||
result_type=entities.ResultType.INTERRUPT,
|
result_type=entities.ResultType.INTERRUPT,
|
||||||
new_query=query,
|
new_query=query,
|
||||||
user_notice='请求失败' if hide_exception_info else f'{e}',
|
user_notice=user_notice,
|
||||||
error_notice=f'{e}',
|
error_notice=f'{e}',
|
||||||
debug_notice=traceback.format_exc(),
|
debug_notice=traceback.format_exc(),
|
||||||
)
|
)
|
||||||
@@ -185,10 +192,15 @@ class ChatMessageHandler(handler.MessageHandler):
|
|||||||
|
|
||||||
pipeline_plugins = query.variables.get('_pipeline_bound_plugins', None)
|
pipeline_plugins = query.variables.get('_pipeline_bound_plugins', None)
|
||||||
|
|
||||||
|
runner_category = runner_utils.get_runner_category_from_runner(
|
||||||
|
runner_name, runner, query.pipeline_config
|
||||||
|
)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
'query_id': query.query_id,
|
'query_id': query.query_id,
|
||||||
'adapter': adapter_name,
|
'adapter': adapter_name,
|
||||||
'runner': runner_name,
|
'runner': runner_name,
|
||||||
|
'runner_category': runner_category,
|
||||||
'duration_ms': duration_ms,
|
'duration_ms': duration_ms,
|
||||||
'model_name': model_name,
|
'model_name': model_name,
|
||||||
'version': constants.semantic_version,
|
'version': constants.semantic_version,
|
||||||
|
|||||||
@@ -282,6 +282,8 @@ class PlatformManager:
|
|||||||
return runtime_bot
|
return runtime_bot
|
||||||
|
|
||||||
async def get_bot_by_uuid(self, bot_uuid: str) -> RuntimeBot | None:
|
async def get_bot_by_uuid(self, bot_uuid: str) -> RuntimeBot | None:
|
||||||
|
if self.websocket_proxy_bot and self.websocket_proxy_bot.bot_entity.uuid == bot_uuid:
|
||||||
|
return self.websocket_proxy_bot
|
||||||
for bot in self.bots:
|
for bot in self.bots:
|
||||||
if bot.bot_entity.uuid == bot_uuid:
|
if bot.bot_entity.uuid == bot_uuid:
|
||||||
return bot
|
return bot
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import io
|
|||||||
import asyncio
|
import asyncio
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
import aiohttp
|
from langbot.pkg.utils import httpclient
|
||||||
import pydantic
|
import pydantic
|
||||||
|
|
||||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||||
@@ -622,23 +622,23 @@ class DiscordMessageConverter(abstract_platform_adapter.AbstractMessageConverter
|
|||||||
image_bytes = base64.b64decode(base64_data)
|
image_bytes = base64.b64decode(base64_data)
|
||||||
elif ele.url:
|
elif ele.url:
|
||||||
# 从URL下载图片
|
# 从URL下载图片
|
||||||
async with aiohttp.ClientSession() as session:
|
session = httpclient.get_session()
|
||||||
async with session.get(ele.url) as response:
|
async with session.get(ele.url) as response:
|
||||||
image_bytes = await response.read()
|
image_bytes = await response.read()
|
||||||
# 从URL或Content-Type推断文件类型
|
# 从URL或Content-Type推断文件类型
|
||||||
content_type = response.headers.get('Content-Type', '')
|
content_type = response.headers.get('Content-Type', '')
|
||||||
if 'jpeg' in content_type or 'jpg' in content_type:
|
if 'jpeg' in content_type or 'jpg' in content_type:
|
||||||
filename = f'{uuid.uuid4()}.jpg'
|
filename = f'{uuid.uuid4()}.jpg'
|
||||||
elif 'gif' in content_type:
|
elif 'gif' in content_type:
|
||||||
filename = f'{uuid.uuid4()}.gif'
|
filename = f'{uuid.uuid4()}.gif'
|
||||||
elif 'webp' in content_type:
|
elif 'webp' in content_type:
|
||||||
filename = f'{uuid.uuid4()}.webp'
|
filename = f'{uuid.uuid4()}.webp'
|
||||||
elif ele.url.lower().endswith(('.jpg', '.jpeg')):
|
elif ele.url.lower().endswith(('.jpg', '.jpeg')):
|
||||||
filename = f'{uuid.uuid4()}.jpg'
|
filename = f'{uuid.uuid4()}.jpg'
|
||||||
elif ele.url.lower().endswith('.gif'):
|
elif ele.url.lower().endswith('.gif'):
|
||||||
filename = f'{uuid.uuid4()}.gif'
|
filename = f'{uuid.uuid4()}.gif'
|
||||||
elif ele.url.lower().endswith('.webp'):
|
elif ele.url.lower().endswith('.webp'):
|
||||||
filename = f'{uuid.uuid4()}.webp'
|
filename = f'{uuid.uuid4()}.webp'
|
||||||
elif ele.path:
|
elif ele.path:
|
||||||
# 从文件路径读取图片
|
# 从文件路径读取图片
|
||||||
# 确保路径没有空字节
|
# 确保路径没有空字节
|
||||||
@@ -702,9 +702,9 @@ class DiscordMessageConverter(abstract_platform_adapter.AbstractMessageConverter
|
|||||||
file_base64 = ele.base64.split(',')[-1]
|
file_base64 = ele.base64.split(',')[-1]
|
||||||
file_bytes = base64.b64decode(file_base64)
|
file_bytes = base64.b64decode(file_base64)
|
||||||
elif ele.url:
|
elif ele.url:
|
||||||
async with aiohttp.ClientSession() as session:
|
session = httpclient.get_session()
|
||||||
async with session.get(ele.url) as response:
|
async with session.get(ele.url) as response:
|
||||||
file_bytes = await response.read()
|
file_bytes = await response.read()
|
||||||
if file_bytes:
|
if file_bytes:
|
||||||
files.append(discord.File(fp=io.BytesIO(file_bytes), filename=filename))
|
files.append(discord.File(fp=io.BytesIO(file_bytes), filename=filename))
|
||||||
elif isinstance(ele, platform_message.File):
|
elif isinstance(ele, platform_message.File):
|
||||||
@@ -717,9 +717,9 @@ class DiscordMessageConverter(abstract_platform_adapter.AbstractMessageConverter
|
|||||||
else:
|
else:
|
||||||
file_bytes = base64.b64decode(ele.base64)
|
file_bytes = base64.b64decode(ele.base64)
|
||||||
elif ele.url:
|
elif ele.url:
|
||||||
async with aiohttp.ClientSession() as session:
|
session = httpclient.get_session()
|
||||||
async with session.get(ele.url) as response:
|
async with session.get(ele.url) as response:
|
||||||
file_bytes = await response.read()
|
file_bytes = await response.read()
|
||||||
if file_bytes:
|
if file_bytes:
|
||||||
files.append(discord.File(fp=io.BytesIO(file_bytes), filename=filename))
|
files.append(discord.File(fp=io.BytesIO(file_bytes), filename=filename))
|
||||||
elif isinstance(ele, platform_message.Forward):
|
elif isinstance(ele, platform_message.Forward):
|
||||||
@@ -775,12 +775,12 @@ class DiscordMessageConverter(abstract_platform_adapter.AbstractMessageConverter
|
|||||||
|
|
||||||
# attachments
|
# attachments
|
||||||
for attachment in message.attachments:
|
for attachment in message.attachments:
|
||||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
session = httpclient.get_session(trust_env=True)
|
||||||
async with session.get(attachment.url) as response:
|
async with session.get(attachment.url) as response:
|
||||||
image_data = await response.read()
|
image_data = await response.read()
|
||||||
image_base64 = base64.b64encode(image_data).decode('utf-8')
|
image_base64 = base64.b64encode(image_data).decode('utf-8')
|
||||||
image_format = response.headers['Content-Type']
|
image_format = response.headers['Content-Type']
|
||||||
element_list.append(platform_message.Image(base64=f'data:{image_format};base64,{image_base64}'))
|
element_list.append(platform_message.Image(base64=f'data:{image_format};base64,{image_base64}'))
|
||||||
|
|
||||||
return platform_message.MessageChain(element_list)
|
return platform_message.MessageChain(element_list)
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ import traceback
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
|
from langbot.pkg.utils import httpclient
|
||||||
import websockets
|
import websockets
|
||||||
import pydantic
|
import pydantic
|
||||||
|
|
||||||
@@ -120,16 +122,16 @@ class KookMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
|||||||
if content:
|
if content:
|
||||||
# Download image and convert to base64
|
# Download image and convert to base64
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
session = httpclient.get_session()
|
||||||
async with session.get(content) as response:
|
async with session.get(content) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
image_bytes = await response.read()
|
image_bytes = await response.read()
|
||||||
image_base64 = base64.b64encode(image_bytes).decode('utf-8')
|
image_base64 = base64.b64encode(image_bytes).decode('utf-8')
|
||||||
# Detect image format
|
# Detect image format
|
||||||
content_type = response.headers.get('Content-Type', 'image/png')
|
content_type = response.headers.get('Content-Type', 'image/png')
|
||||||
components.append(
|
components.append(
|
||||||
platform_message.Image(base64=f'data:{content_type};base64,{image_base64}')
|
platform_message.Image(base64=f'data:{content_type};base64,{image_base64}')
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
# If download fails, just add as plain text
|
# If download fails, just add as plain text
|
||||||
components.append(platform_message.Plain(text=f'[Image: {content}]'))
|
components.append(platform_message.Plain(text=f'[Image: {content}]'))
|
||||||
@@ -295,17 +297,17 @@ class KookAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
|||||||
'Authorization': f'Bot {self.config["token"]}',
|
'Authorization': f'Bot {self.config["token"]}',
|
||||||
}
|
}
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
session = httpclient.get_session()
|
||||||
async with session.get(base_url, params=params, headers=headers) as response:
|
async with session.get(base_url, params=params, headers=headers) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
data = await response.json()
|
data = await response.json()
|
||||||
if data.get('code') == 0:
|
if data.get('code') == 0:
|
||||||
gateway_url = data['data']['url']
|
gateway_url = data['data']['url']
|
||||||
return gateway_url
|
return gateway_url
|
||||||
else:
|
|
||||||
raise Exception(f'Failed to get gateway URL: {data.get("message")}')
|
|
||||||
else:
|
else:
|
||||||
raise Exception(f'Failed to get gateway URL: HTTP {response.status}')
|
raise Exception(f'Failed to get gateway URL: {data.get("message")}')
|
||||||
|
else:
|
||||||
|
raise Exception(f'Failed to get gateway URL: HTTP {response.status}')
|
||||||
|
|
||||||
async def _get_bot_user_info(self) -> dict:
|
async def _get_bot_user_info(self) -> dict:
|
||||||
"""Get bot's own user information from KOOK API"""
|
"""Get bot's own user information from KOOK API"""
|
||||||
@@ -315,17 +317,17 @@ class KookAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
|||||||
'Authorization': f'Bot {self.config["token"]}',
|
'Authorization': f'Bot {self.config["token"]}',
|
||||||
}
|
}
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
session = httpclient.get_session()
|
||||||
async with session.get(base_url, headers=headers) as response:
|
async with session.get(base_url, headers=headers) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
data = await response.json()
|
data = await response.json()
|
||||||
if data.get('code') == 0:
|
if data.get('code') == 0:
|
||||||
user_info = data['data']
|
user_info = data['data']
|
||||||
return user_info
|
return user_info
|
||||||
else:
|
|
||||||
raise Exception(f'Failed to get bot user info: {data.get("message")}')
|
|
||||||
else:
|
else:
|
||||||
raise Exception(f'Failed to get bot user info: HTTP {response.status}')
|
raise Exception(f'Failed to get bot user info: {data.get("message")}')
|
||||||
|
else:
|
||||||
|
raise Exception(f'Failed to get bot user info: HTTP {response.status}')
|
||||||
|
|
||||||
async def _handle_hello(self, data: dict):
|
async def _handle_hello(self, data: dict):
|
||||||
"""Handle HELLO signal (signal 1)"""
|
"""Handle HELLO signal (signal 1)"""
|
||||||
@@ -510,7 +512,7 @@ class KookAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if not self.http_session:
|
if not self.http_session:
|
||||||
self.http_session = aiohttp.ClientSession()
|
self.http_session = httpclient.get_session()
|
||||||
|
|
||||||
async with self.http_session.post(url, json=payload, headers=headers) as response:
|
async with self.http_session.post(url, json=payload, headers=headers) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
@@ -576,7 +578,7 @@ class KookAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if not self.http_session:
|
if not self.http_session:
|
||||||
self.http_session = aiohttp.ClientSession()
|
self.http_session = httpclient.get_session()
|
||||||
|
|
||||||
async with self.http_session.post(url, json=payload, headers=headers) as response:
|
async with self.http_session.post(url, json=payload, headers=headers) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
@@ -624,7 +626,7 @@ class KookAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Create HTTP session
|
# Create HTTP session
|
||||||
self.http_session = aiohttp.ClientSession()
|
self.http_session = httpclient.get_session()
|
||||||
|
|
||||||
await self.logger.info('Starting KOOK adapter')
|
await self.logger.info('Starting KOOK adapter')
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import tempfile
|
|||||||
import os
|
import os
|
||||||
import mimetypes
|
import mimetypes
|
||||||
|
|
||||||
import aiohttp
|
from langbot.pkg.utils import httpclient
|
||||||
import lark_oapi.ws.exception
|
import lark_oapi.ws.exception
|
||||||
import quart
|
import quart
|
||||||
from lark_oapi.api.im.v1 import *
|
from lark_oapi.api.im.v1 import *
|
||||||
@@ -78,13 +78,13 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
|||||||
return None
|
return None
|
||||||
elif msg.url:
|
elif msg.url:
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
session = httpclient.get_session()
|
||||||
async with session.get(msg.url) as response:
|
async with session.get(msg.url) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
image_bytes = await response.read()
|
image_bytes = await response.read()
|
||||||
else:
|
else:
|
||||||
print(f'Failed to download image from {msg.url}: HTTP {response.status}')
|
print(f'Failed to download image from {msg.url}: HTTP {response.status}')
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f'Failed to download image from {msg.url}: {e}')
|
print(f'Failed to download image from {msg.url}: {e}')
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
@@ -208,10 +208,10 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
|||||||
pass
|
pass
|
||||||
elif msg.url:
|
elif msg.url:
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
session = httpclient.get_session()
|
||||||
async with session.get(msg.url) as resp:
|
async with session.get(msg.url) as resp:
|
||||||
if resp.status == 200:
|
if resp.status == 200:
|
||||||
data = await resp.read()
|
data = await resp.read()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
elif msg.path:
|
elif msg.path:
|
||||||
@@ -575,6 +575,127 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
|||||||
|
|
||||||
|
|
||||||
class LarkEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
class LarkEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||||
|
_processed_thread_quote_cache: typing.ClassVar[dict[str, float]] = {}
|
||||||
|
_processed_thread_quote_cache_max_size: typing.ClassVar[int] = 4096
|
||||||
|
_processed_thread_quote_cache_ttl_seconds: typing.ClassVar[int] = 86400
|
||||||
|
|
||||||
|
@classmethod
def _prune_processed_thread_quote_cache(cls, now: typing.Optional[float] = None) -> None:
    """Evict expired and surplus entries from the processed-thread cache.

    The cache is walked from its front in dict order. First the leading run of
    entries whose timestamp is older than the TTL is dropped; then front entries
    are dropped until the cache holds at most the configured maximum.

    Args:
        now: Reference time in seconds. ``time.time()`` is used when omitted.
    """
    cache = cls._processed_thread_quote_cache
    current = time.time() if now is None else now
    cutoff = current - cls._processed_thread_quote_cache_ttl_seconds

    # Collect the leading stale keys, stopping at the first entry still within TTL.
    stale_keys = []
    for key, stamp in cache.items():
        if stamp >= cutoff:
            break
        stale_keys.append(key)
    for key in stale_keys:
        cache.pop(key, None)

    # Enforce the size cap by discarding from the front.
    while len(cache) > cls._processed_thread_quote_cache_max_size:
        cache.pop(next(iter(cache)), None)
|
||||||
|
|
||||||
|
@classmethod
def _mark_thread_quote_processed(cls, thread_id: str) -> None:
    """Record that the first quoted reply of ``thread_id`` has been handled.

    The entry is stored with the current time and the cache is pruned
    afterwards, so the TTL and size limits also apply to the entry just added.

    Args:
        thread_id: Identifier of the thread (topic) to mark.
    """
    now = time.time()
    cache = cls._processed_thread_quote_cache
    # Re-insert instead of overwriting in place: the pruner evicts from the
    # front of the dict, so a refreshed key must move to the back to keep the
    # oldest-first ordering intact.
    cache.pop(thread_id, None)
    cache[thread_id] = now
    # Prune after inserting; pruning first would let the cache exceed the
    # configured maximum by one entry.
    cls._prune_processed_thread_quote_cache(now)
|
||||||
|
|
||||||
|
@classmethod
def _extract_quote_message_id(cls, message: EventMessage) -> typing.Optional[str]:
    """Decide which message ID, if any, the given message quotes.

    Outcomes:
    - No ``parent_id``, or ``parent_id`` equal to ``message_id``: ``None``.
    - No ``thread_id``: the ``parent_id``.
    - ``thread_id`` not yet in the processed cache: the thread is marked as
      processed and the ``parent_id`` is returned.
    - ``thread_id`` already in the processed cache: ``None``.

    The processed-thread cache is pruned before it is consulted.
    """
    parent_id = getattr(message, 'parent_id', None)
    # A missing parent, or a parent that is the message itself, is not a quote.
    if not parent_id or parent_id == getattr(message, 'message_id', None):
        return None

    thread_id = getattr(message, 'thread_id', None)
    if not thread_id:
        return parent_id

    cls._prune_processed_thread_quote_cache()
    if thread_id in cls._processed_thread_quote_cache:
        return None
    cls._mark_thread_quote_processed(thread_id)
    return parent_id
|
||||||
|
|
||||||
|
@staticmethod
def _build_event_message_from_message_item(message_item: Message) -> typing.Optional[EventMessage]:
    """Wrap a fetched message item into an ``EventMessage``.

    The resulting payload carries the item's ID, type, body content, creation
    time and mentions, plus whichever of ``parent_id``, ``root_id``,
    ``thread_id`` and ``chat_id`` are present and non-empty.

    Returns:
        The constructed ``EventMessage``, or ``None`` when the item has no body
        or the body has no content.
    """
    body = getattr(message_item, 'body', None)
    content = getattr(body, 'content', None) if body else None
    if not content:
        return None

    event_data = {
        'message_id': message_item.message_id,
        'message_type': message_item.msg_type,
        'content': content,
        'create_time': message_item.create_time,
        'mentions': getattr(message_item, 'mentions', []) or [],
    }

    # Carry over thread-related fields only when they hold a value.
    for field_name in ('parent_id', 'root_id', 'thread_id', 'chat_id'):
        value = getattr(message_item, field_name, None)
        if value:
            event_data[field_name] = value

    return EventMessage(event_data)
|
||||||
|
|
||||||
|
@staticmethod
async def _fetch_quoted_message(
    quote_message_id: str,
    api_client: lark_oapi.Client,
) -> typing.Optional[platform_message.MessageChain]:
    """Fetch the quoted message and convert it to a ``MessageChain``.

    Fetching a quote is best-effort: any failure yields ``None`` so the caller
    can still deliver the incoming message without its quote.

    Args:
        quote_message_id: ID of the message being quoted.
        api_client: Client used to call the message-get API.

    Returns:
        The converted chain, or ``None`` when the request raises or is not
        successful, the response has no items, the item has no usable body,
        or converting the item raises.
    """
    try:
        request = GetMessageRequest.builder().message_id(quote_message_id).build()
        response = await api_client.im.v1.message.aget(request)

        if not response.success():
            return None

        items = getattr(response.data, 'items', None)
        if not items:
            return None

        event_message = LarkEventConverter._build_event_message_from_message_item(items[0])
        if event_message is None:
            return None

        return await LarkMessageConverter.target2yiri(event_message, api_client)
    except Exception:
        # A network or SDK error while resolving the quote must not abort
        # conversion of the message that carries it.
        traceback.print_exc()
        return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def yiri2target(
|
async def yiri2target(
|
||||||
event: platform_events.MessageEvent,
|
event: platform_events.MessageEvent,
|
||||||
@@ -587,6 +708,23 @@ class LarkEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
|||||||
) -> platform_events.Event:
|
) -> platform_events.Event:
|
||||||
message_chain = await LarkMessageConverter.target2yiri(event.event.message, api_client)
|
message_chain = await LarkMessageConverter.target2yiri(event.event.message, api_client)
|
||||||
|
|
||||||
|
# Check for quote/reply message
|
||||||
|
quote_message_id = LarkEventConverter._extract_quote_message_id(event.event.message)
|
||||||
|
if quote_message_id:
|
||||||
|
quote_chain = await LarkEventConverter._fetch_quoted_message(quote_message_id, api_client)
|
||||||
|
if quote_chain:
|
||||||
|
# Filter out Source component from quoted chain, keep only content
|
||||||
|
quote_origin = platform_message.MessageChain(
|
||||||
|
[comp for comp in quote_chain if not isinstance(comp, platform_message.Source)]
|
||||||
|
)
|
||||||
|
if quote_origin:
|
||||||
|
message_chain.append(
|
||||||
|
platform_message.Quote(
|
||||||
|
message_id=quote_message_id,
|
||||||
|
origin=quote_origin,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if event.event.message.chat_type == 'p2p':
|
if event.event.message.chat_type == 'p2p':
|
||||||
return platform_events.FriendMessage(
|
return platform_events.FriendMessage(
|
||||||
sender=platform_entities.Friend(
|
sender=platform_entities.Friend(
|
||||||
@@ -770,6 +908,32 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
|||||||
self.request_tenant_access_token(tenant_key)
|
self.request_tenant_access_token(tenant_key)
|
||||||
return self.tenant_access_tokens.get(tenant_key)['token'] if self.tenant_access_tokens.get(tenant_key) else None
|
return self.tenant_access_tokens.get(tenant_key)['token'] if self.tenant_access_tokens.get(tenant_key) else None
|
||||||
|
|
||||||
|
def get_launcher_id(self, event: platform_events.MessageEvent) -> str | None:
    """Return a thread-scoped launcher ID for group messages posted in a thread.

    The thread ID is read from ``event.source_platform_object.event.message``.

    Returns:
        ``"{group_id}_{thread_id}"`` when the event is a ``GroupMessage`` whose
        source message carries a thread ID; ``None`` otherwise, including when
        the source event or message is missing.
    """
    source_event = getattr(event.source_platform_object, 'event', None)
    message = getattr(source_event, 'message', None) if source_event else None
    thread_id = getattr(message, 'thread_id', None) if message else None

    if thread_id and isinstance(event, platform_events.GroupMessage):
        return f'{event.group.id}_{thread_id}'
    return None
|
||||||
|
|
||||||
def build_api_client(self, config):
|
def build_api_client(self, config):
|
||||||
app_id = config['app_id']
|
app_id = config['app_id']
|
||||||
app_secret = config['app_secret']
|
app_secret = config['app_secret']
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import copy
|
|||||||
import threading
|
import threading
|
||||||
|
|
||||||
import quart
|
import quart
|
||||||
import aiohttp
|
from langbot.pkg.utils import httpclient
|
||||||
|
|
||||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||||
from ....core import app
|
from ....core import app
|
||||||
@@ -639,14 +639,14 @@ class GeWeChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
|||||||
|
|
||||||
async def run_async(self):
|
async def run_async(self):
|
||||||
if not self.config['token']:
|
if not self.config['token']:
|
||||||
async with aiohttp.ClientSession() as session:
|
session = httpclient.get_session()
|
||||||
async with session.post(
|
async with session.post(
|
||||||
f'{self.config["gewechat_url"]}/v2/api/tools/getTokenId',
|
f'{self.config["gewechat_url"]}/v2/api/tools/getTokenId',
|
||||||
json={'app_id': self.config['app_id']},
|
json={'app_id': self.config['app_id']},
|
||||||
) as response:
|
) as response:
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
raise Exception(f'获取gewechat token失败: {await response.text()}')
|
raise Exception(f'获取gewechat token失败: {await response.text()}')
|
||||||
self.config['token'] = (await response.json())['data']
|
self.config['token'] = (await response.json())['data']
|
||||||
|
|
||||||
self.bot = gewechat_client.GewechatClient(f'{self.config["gewechat_url"]}/v2/api', self.config['token'])
|
self.bot = gewechat_client.GewechatClient(f'{self.config["gewechat_url"]}/v2/api', self.config['token'])
|
||||||
|
|
||||||
|
|||||||
577
src/langbot/pkg/platform/sources/openclaw_weixin.py
Normal file
577
src/langbot/pkg/platform/sources/openclaw_weixin.py
Normal file
@@ -0,0 +1,577 @@
|
|||||||
|
"""OpenClaw WeChat adapter for LangBot.
|
||||||
|
|
||||||
|
Uses the OpenClaw WeChat HTTP JSON API (long-poll getUpdates + sendMessage)
|
||||||
|
to integrate personal WeChat accounts with LangBot.
|
||||||
|
|
||||||
|
Reference: https://github.com/epiral/weixin-bot
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import traceback
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import pydantic
|
||||||
|
import sqlalchemy
|
||||||
|
|
||||||
|
from langbot.libs.openclaw_weixin_api.client import (
|
||||||
|
DEFAULT_BASE_URL,
|
||||||
|
SESSION_EXPIRED_ERRCODE,
|
||||||
|
OpenClawWeixinClient,
|
||||||
|
)
|
||||||
|
from langbot.libs.openclaw_weixin_api.types import (
|
||||||
|
MessageItem,
|
||||||
|
WeixinMessage,
|
||||||
|
)
|
||||||
|
from langbot.pkg.entity.persistence import bot as persistence_bot
|
||||||
|
|
||||||
|
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||||
|
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
|
||||||
|
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||||
|
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||||
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
|
|
||||||
|
|
||||||
|
class OpenClawWeixinMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
    """Translate between LangBot message chains and OpenClaw WeChat message items."""

    @staticmethod
    async def yiri2target(message_chain: platform_message.MessageChain) -> list[dict]:
        """Flatten a LangBot ``MessageChain`` into OpenClaw text item dicts.

        Plain text is passed through. Images and files are rendered as
        bracketed text placeholders (no CDN upload happens here), and forwarded
        nodes are converted recursively. Other component types are dropped.
        """

        def as_text(text: str) -> dict:
            return {'type': MessageItem.TEXT, 'text_item': {'text': text}}

        converted: list[dict] = []
        for part in message_chain:
            if isinstance(part, platform_message.Plain):
                converted.append(as_text(part.text))
            elif isinstance(part, platform_message.Image):
                # Images become a placeholder; the URL is included when known.
                if part.url:
                    converted.append(as_text(f'[Image: {part.url}]'))
                elif part.base64:
                    converted.append(as_text('[Image]'))
            elif isinstance(part, platform_message.File):
                if part.name:
                    converted.append(as_text(f'[File: {part.name}]'))
            elif isinstance(part, platform_message.Forward):
                for node in part.node_list:
                    if node.message_chain:
                        converted.extend(await OpenClawWeixinMessageConverter.yiri2target(node.message_chain))
        return converted

    @staticmethod
    def _build_quote(ref_msg) -> typing.Optional[platform_message.Quote]:
        """Build a ``Quote`` from a referenced message, or ``None`` if it has no title or text."""
        parts = []
        if ref_msg.title:
            parts.append(ref_msg.title)
        ref_item = ref_msg.message_item
        if ref_item and ref_item.text_item and ref_item.text_item.text:
            parts.append(ref_item.text_item.text)
        if not parts:
            return None
        return platform_message.Quote(
            sender_id='',
            origin=platform_message.MessageChain([platform_message.Plain(text=' | '.join(parts))]),
        )

    @staticmethod
    async def target2yiri(
        msg: WeixinMessage,
    ) -> platform_message.MessageChain:
        """Convert an inbound ``WeixinMessage`` into a LangBot ``MessageChain``.

        Text items become ``Plain`` (preceded by a ``Quote`` when the item
        references another message). Images become ``Image`` when bytes were
        attached to the item beforehand. Voice items use their transcript when
        present. Everything else is represented by an ``Unknown`` placeholder.
        """
        components: list[platform_message.MessageComponent] = []

        for item in msg.item_list or []:
            kind = item.type

            if kind == MessageItem.TEXT and item.text_item and item.text_item.text:
                text = item.text_item.text
                if item.ref_msg:
                    quote = OpenClawWeixinMessageConverter._build_quote(item.ref_msg)
                    if quote is not None:
                        components.append(quote)
                components.append(platform_message.Plain(text=text))

            elif kind == MessageItem.IMAGE and item.image_item:
                raw = getattr(item.image_item, '_downloaded_bytes', None)
                if raw:
                    encoded = base64.b64encode(raw).decode('utf-8')
                    components.append(platform_message.Image(base64=f'data:image/jpeg;base64,{encoded}'))
                else:
                    components.append(platform_message.Unknown(text='[Image]'))

            elif kind == MessageItem.VOICE and item.voice_item:
                # Only the speech-to-text transcript is used for voice messages.
                # TODO: enable after full testing - emit platform_message.Voice
                # (base64 of downloaded bytes, length from playtime) when there
                # is no transcript.
                if item.voice_item.text:
                    components.append(platform_message.Plain(text=item.voice_item.text))
                else:
                    components.append(platform_message.Unknown(text='[Voice]'))

            elif kind == MessageItem.FILE and item.file_item:
                # TODO: enable after full testing - emit platform_message.File
                # (name, size from len, base64 of downloaded bytes).
                components.append(platform_message.Unknown(text=f'[File: {item.file_item.file_name or ""}]'))

            elif kind == MessageItem.VIDEO and item.video_item:
                # TODO: enable after full testing - emit platform_message.File
                # named 'video.mp4' (size from video_size, base64 of downloaded bytes).
                components.append(platform_message.Unknown(text='[Video]'))

            else:
                components.append(platform_message.Unknown(text='[Unknown message type]'))

        return platform_message.MessageChain(components)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenClawWeixinEventConverter(abstract_platform_adapter.AbstractEventConverter):
    """Translate OpenClaw WeChat messages into LangBot events."""

    @staticmethod
    async def yiri2target(event: platform_events.MessageEvent) -> dict:
        """Return the platform object the event was originally built from."""
        return event.source_platform_object

    @staticmethod
    async def target2yiri(msg: WeixinMessage) -> typing.Optional[platform_events.MessageEvent]:
        """Build a ``FriendMessage`` from an inbound ``WeixinMessage``.

        Returns:
            The event, or ``None`` when the message is not of the user type,
            has no sender ID, or converts to an empty message chain.
        """
        if msg.message_type != WeixinMessage.TYPE_USER:
            return None

        sender_id = msg.from_user_id
        if not sender_id:
            return None

        chain = await OpenClawWeixinMessageConverter.target2yiri(msg)
        if not chain:
            return None

        # The sender ID doubles as the nickname.
        sender = platform_entities.Friend(
            id=sender_id,
            nickname=sender_id,
            remark='',
        )
        return platform_events.FriendMessage(
            sender=sender,
            message_chain=chain,
            time=(msg.create_time_ms or 0) / 1000.0,  # milliseconds -> seconds
            source_platform_object=msg,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class OpenClawWeixinAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||||
|
"""LangBot adapter for OpenClaw WeChat (long-poll based)."""
|
||||||
|
|
||||||
|
name: str = 'openclaw-weixin'
|
||||||
|
|
||||||
|
client: OpenClawWeixinClient = pydantic.Field(exclude=True)
|
||||||
|
|
||||||
|
config: dict
|
||||||
|
|
||||||
|
message_converter: OpenClawWeixinMessageConverter = OpenClawWeixinMessageConverter()
|
||||||
|
event_converter: OpenClawWeixinEventConverter = OpenClawWeixinEventConverter()
|
||||||
|
|
||||||
|
# context_token cache: from_user_id -> context_token
|
||||||
|
_context_tokens: dict[str, str] = pydantic.PrivateAttr(default_factory=dict)
|
||||||
|
|
||||||
|
_polling: bool = pydantic.PrivateAttr(default=False)
|
||||||
|
_poll_task: typing.Optional[asyncio.Task] = pydantic.PrivateAttr(default=None)
|
||||||
|
_bot_uuid: typing.Optional[str] = pydantic.PrivateAttr(default=None)
|
||||||
|
|
||||||
|
listeners: typing.Dict[
|
||||||
|
typing.Type[platform_events.Event],
|
||||||
|
typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
|
||||||
|
] = {}
|
||||||
|
|
||||||
|
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger):
    """Create the adapter and its API client from ``config``.

    ``base_url`` falls back to ``DEFAULT_BASE_URL`` and ``token`` to an empty
    string when absent from ``config``.
    """
    api_client = OpenClawWeixinClient(
        base_url=config.get('base_url', DEFAULT_BASE_URL),
        token=config.get('token', ''),
    )
    initial_fields = dict(
        config=config,
        logger=logger,
        client=api_client,
        bot_account_id='',
        listeners={},
        name='openclaw-weixin',
    )
    super().__init__(**initial_fields)
|
||||||
|
|
||||||
|
def set_bot_uuid(self, bot_uuid: str) -> None:
    """Remember the bot's UUID; ``_persist_config`` uses it to locate the bot row."""
    self._bot_uuid = bot_uuid
|
||||||
|
|
||||||
|
async def _persist_config(self) -> None:
    """Write ``self.config`` to the bot's ``adapter_config`` column.

    Does nothing when no bot UUID has been set. Any failure is reported as a
    warning through the adapter logger rather than raised.
    """
    bot_uuid = self._bot_uuid
    if not bot_uuid:
        return

    try:
        execute = self.logger.ap.persistence_mgr.execute_async
        statement = (
            sqlalchemy.update(persistence_bot.Bot)
            .where(persistence_bot.Bot.uuid == bot_uuid)
            .values(adapter_config=self.config)
        )
        await execute(statement)
    except Exception as e:
        await self.logger.warning(f'Failed to persist adapter config: {e}')
|
||||||
|
|
||||||
|
async def _do_login(self) -> None:
    """Run the client's QR-code login and store the resulting credentials.

    The QR code is surfaced through the adapter logger. On success the token,
    base URL and (when provided) account ID are written into ``self.config``,
    which is then persisted.
    """
    event_logger = self.logger

    async def _on_qrcode(qr_base64: str, _qr_url: str):
        await event_logger.info(
            f'Please scan the QR code to login WeChat: {_qr_url}',
            images=[platform_message.Image(base64=qr_base64)],
        )

    result = await self.client.login(on_qrcode=_on_qrcode)

    # client.login() has already updated the client's own token and base_url;
    # mirror them into the adapter config.
    self.config['token'] = result.token
    self.config['base_url'] = result.base_url
    account_id = result.account_id
    if account_id:
        self.config['account_id'] = account_id

    await self.logger.info(f'WeChat login successful! account_id={result.account_id}')

    # Persist so the token is still available after a restart.
    await self._persist_config()
|
||||||
|
|
||||||
|
async def send_message(
    self,
    target_type: str,
    target_id: str,
    message: platform_message.MessageChain,
):
    """Send each component of ``message`` to ``target_id``.

    Components are sent one by one using the context token cached for the
    target (empty string when none is cached). A failure on one component is
    logged and does not stop the remaining components. Forwarded nodes are
    sent recursively.
    """
    token = self._context_tokens.get(target_id, '')

    for part in message:
        try:
            if isinstance(part, platform_message.Plain):
                if part.text:
                    await self.client.send_text(target_id, part.text, token)

            elif isinstance(part, platform_message.Image):
                image_data, _ = await part.get_bytes()
                await self.client.send_image(target_id, image_data, token)

            elif isinstance(part, platform_message.File):
                payload = await self._get_component_bytes(part)
                if payload:
                    await self.client.send_file(target_id, payload, part.name or 'file', token)

            elif isinstance(part, platform_message.Voice):
                payload = await self._get_component_bytes(part)
                if payload:
                    await self.client.send_voice(target_id, payload, part.length or 0, token)

            elif isinstance(part, platform_message.Forward):
                for node in part.node_list:
                    if node.message_chain:
                        await self.send_message(target_type, target_id, node.message_chain)

        except Exception:
            await self.logger.error(f'Failed to send component {type(part).__name__}: {traceback.format_exc()}')
|
||||||
|
|
||||||
|
async def reply_message(
    self,
    message_source: platform_events.MessageEvent,
    message: platform_message.MessageChain,
    quote_origin: bool = False,
):
    """Send ``message`` back to the sender of ``message_source``.

    Nothing is sent when the source platform object is not a ``WeixinMessage``
    or has no sender ID. ``quote_origin`` is accepted but not used.
    """
    origin = message_source.source_platform_object
    if not isinstance(origin, WeixinMessage):
        return

    recipient = origin.from_user_id or ''
    if recipient:
        await self.send_message('friend', recipient, message)
|
||||||
|
|
||||||
|
async def is_muted(self, group_id: int) -> bool:
    """Report whether the bot is muted in ``group_id``; this adapter always answers ``False``."""
    return False
|
||||||
|
|
||||||
|
@staticmethod
async def _get_component_bytes(component: platform_message.MessageComponent) -> typing.Optional[bytes]:
    """Extract raw bytes from a component's ``base64``, ``url`` or ``path`` attribute.

    Sources are tried in that order and only the first one present is used:

    - ``base64``: decoded directly; a leading ``data:<mime>;base64,`` header
      is stripped first.
    - ``url`` (http/https only): downloaded; a non-200 response yields ``None``.
    - ``path``: read from the local filesystem in a worker thread.

    Returns:
        The bytes, or ``None`` when no usable source is present or the download
        did not return HTTP 200.
    """
    b64_val = getattr(component, 'base64', None)
    url_val = getattr(component, 'url', None)
    path_val = getattr(component, 'path', None)

    if b64_val:
        # b64decode silently discards non-alphabet characters, so a data-URI
        # header left in place would be decoded as garbage bytes in front of
        # the real payload.
        if b64_val.startswith('data:') and ',' in b64_val:
            b64_val = b64_val.split(',', 1)[1]
        return base64.b64decode(b64_val)

    if url_val and url_val.startswith(('http://', 'https://')):
        import aiohttp

        async with aiohttp.ClientSession() as session:
            async with session.get(url_val) as resp:
                if resp.status == 200:
                    return await resp.read()
        return None

    if path_val:

        def _read_file() -> bytes:
            with open(path_val, 'rb') as f:
                return f.read()

        # Open and read inside the worker thread so no blocking file I/O runs
        # on the event loop.
        return await asyncio.to_thread(_read_file)

    return None
|
||||||
|
|
||||||
|
def register_listener(
    self,
    event_type: typing.Type[platform_events.Event],
    callback: typing.Callable[
        [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter],
        None,
    ],
):
    """Install ``callback`` as the handler for ``event_type``, replacing any previous one."""
    handlers = self.listeners
    handlers[event_type] = callback
|
||||||
|
|
||||||
|
def unregister_listener(
    self,
    event_type: typing.Type[platform_events.Event],
    callback: typing.Callable[
        [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter],
        None,
    ],
):
    """Remove the handler registered for ``event_type``, if any.

    ``callback`` is not compared against the registered handler; whatever is
    registered for the event type is removed.
    """
    handlers = self.listeners
    if event_type in handlers:
        del handlers[event_type]
|
||||||
|
|
||||||
|
async def run_async(self):
    """Start the adapter and block until the poll loop ends.

    When the config holds no token, the QR-code login runs first; a login
    failure is logged and re-raised. The client is then rebuilt from the
    current config and the long-poll task is started and awaited.
    Cancellation raised while awaiting the task is swallowed.
    """
    initial_base_url = self.config.get('base_url', DEFAULT_BASE_URL)
    initial_token = self.config.get('token', '')

    await self.logger.info('OpenClaw WeChat adapter starting...')

    if not initial_token:
        await self.logger.info('No token configured, starting QR code login...')
        try:
            await self._do_login()
        except Exception as e:
            await self.logger.error(f'QR code login failed: {e}')
            raise

    # The config may have been updated by the login above.
    self.client = OpenClawWeixinClient(
        base_url=self.config.get('base_url', initial_base_url),
        token=self.config.get('token', initial_token),
    )
    self.bot_account_id = self.config.get('account_id', 'openclaw-weixin')
    self._polling = True

    self._poll_task = asyncio.create_task(self._poll_loop())
    await self.logger.info('OpenClaw WeChat adapter running')

    try:
        await self._poll_task
    except asyncio.CancelledError:
        pass
|
||||||
|
|
||||||
|
async def _poll_loop(self):
    """Call getUpdates repeatedly while ``self._polling`` is true.

    Behaviour per iteration:
    - The server-suggested long-poll timeout, when positive, replaces the
      configured one (``poll_timeout``, default 35 seconds).
    - A session-expired error code triggers a re-login; if that fails the
      loop stops.
    - Any other API error or exception is logged and followed by a sleep that
      doubles from 1s up to 10s; a successful poll resets it to 1s.
    - Each returned message is handled individually; a handler failure is
      logged and does not affect the other messages.
    """
    cursor = ''
    wait_seconds = float(self.config.get('poll_timeout', 35))

    retry_delay = 1.0
    retry_cap = 10.0

    while self._polling:
        try:
            resp = await self.client.get_updates(
                get_updates_buf=cursor,
                timeout=wait_seconds + 5,
            )

            server_timeout_ms = resp.longpolling_timeout_ms
            if server_timeout_ms and server_timeout_ms > 0:
                wait_seconds = server_timeout_ms / 1000.0

            ret_failed = resp.ret is not None and resp.ret != 0
            errcode_failed = resp.errcode is not None and resp.errcode != 0

            if ret_failed or errcode_failed:
                expired = resp.errcode == SESSION_EXPIRED_ERRCODE or resp.ret == SESSION_EXPIRED_ERRCODE

                if expired:
                    await self.logger.error('OpenClaw WeChat session expired, attempting re-login...')
                    try:
                        await self._do_login()
                        # Rebuild the client with the fresh credentials.
                        self.client = OpenClawWeixinClient(
                            base_url=self.config.get('base_url', DEFAULT_BASE_URL),
                            token=self.config.get('token', ''),
                        )
                    except Exception:
                        await self.logger.error(f'Re-login failed: {traceback.format_exc()}')
                        break

                    # New session: drop cached context tokens and the poll cursor.
                    self._context_tokens.clear()
                    cursor = ''
                    retry_delay = 1.0
                    continue

                await self.logger.error(
                    f'OpenClaw getUpdates failed: ret={resp.ret} errcode={resp.errcode} errmsg={resp.errmsg}'
                )
                await asyncio.sleep(retry_delay)
                retry_delay = min(retry_delay * 2, retry_cap)
                continue

            retry_delay = 1.0

            if resp.get_updates_buf:
                cursor = resp.get_updates_buf

            for inbound in resp.msgs:
                try:
                    await self._handle_inbound_message(inbound)
                except Exception:
                    await self.logger.error(f'Error handling message: {traceback.format_exc()}')

        except asyncio.CancelledError:
            break
        except Exception:
            await self.logger.error(f'OpenClaw poll error: {traceback.format_exc()}')
            await asyncio.sleep(retry_delay)
            retry_delay = min(retry_delay * 2, retry_cap)
|
||||||
|
|
||||||
|
async def _handle_inbound_message(self, msg: WeixinMessage):
    """Handle one message delivered by getUpdates.

    Steps, in order:
      1. Remember the sender's context token (when both the token and the
         sender id are present) in ``self._context_tokens``.
      2. Fetch CDN media for the message via ``_download_media_items``.
      3. Convert the message to a LangBot event; a ``None`` result means
         there is nothing to dispatch.
      4. Await the listener registered for the event's exact type, passing
         ``(event, self)``. Events without a registered listener are dropped.
    """
    token = msg.context_token
    if token:
        sender_id = msg.from_user_id
        if sender_id:
            self._context_tokens[sender_id] = token

    # Media has to be fetched before conversion so the converter can see
    # the downloaded bytes attached to the items.
    await self._download_media_items(msg)

    event = await OpenClawWeixinEventConverter.target2yiri(msg)
    if event is None:
        return

    event_type = type(event)
    if event_type not in self.listeners:
        return
    await self.listeners[event_type](event, self)
|
||||||
|
|
||||||
|
async def _download_media_items(self, msg: WeixinMessage):
    """Download CDN payloads for the image items of a message.

    For every item of type ``MessageItem.IMAGE`` whose media reference
    carries both ``encrypt_query_param`` and ``aes_key``, the bytes returned
    by ``self.client.download_media`` are attached to the item as
    ``image_item._downloaded_bytes``.

    A failure on one item is logged as a warning (with traceback) and does
    not prevent the remaining items from being processed.
    """
    items = msg.item_list
    if not items:
        return

    for item in items:
        try:
            if not (item.type == MessageItem.IMAGE and item.image_item):
                continue

            media = item.image_item.media
            if not (media and media.encrypt_query_param and media.aes_key):
                continue

            item.image_item._downloaded_bytes = await self.client.download_media(media)

            # TODO: enable after full testing -- FILE, VOICE and VIDEO items
            # are meant to follow the same pattern: when the item's media has
            # both encrypt_query_param and aes_key, download it with
            # self.client.download_media and store the result on
            # file_item / voice_item / video_item as _downloaded_bytes.
        except Exception:
            await self.logger.warning(f'Failed to download CDN media: {traceback.format_exc()}')
|
||||||
|
|
||||||
|
async def kill(self) -> bool:
    """Shut the adapter down.

    Clears the polling flag, cancels the poll task if it is still running
    and waits for it to finish, closes the HTTP client and logs the stop.

    Returns:
        Always ``True``.
    """
    self._polling = False

    task = self._poll_task
    if task and not task.done():
        task.cancel()
        try:
            # Wait for the task to unwind; cancellation itself is expected.
            await task
        except asyncio.CancelledError:
            pass

    await self.client.close()
    await self.logger.info('OpenClaw WeChat adapter stopped')
    return True
|
||||||
57
src/langbot/pkg/platform/sources/openclaw_weixin.yaml
Normal file
57
src/langbot/pkg/platform/sources/openclaw_weixin.yaml
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
apiVersion: v1
|
||||||
|
kind: MessagePlatformAdapter
|
||||||
|
metadata:
|
||||||
|
name: openclaw-weixin
|
||||||
|
label:
|
||||||
|
en_US: OpenClaw WeChat
|
||||||
|
zh_Hans: OpenClaw 微信
|
||||||
|
description:
|
||||||
|
en_US: OpenClaw WeChat adapter, supports personal WeChat via QR code login
|
||||||
|
zh_Hans: OpenClaw 微信适配器,通过扫码登录支持个人微信
|
||||||
|
icon: wechat.png
|
||||||
|
spec:
|
||||||
|
config:
|
||||||
|
- name: base_url
|
||||||
|
label:
|
||||||
|
en_US: API Base URL
|
||||||
|
zh_Hans: API 基础地址
|
||||||
|
description:
|
||||||
|
en_US: The base URL of the OpenClaw WeChat backend API
|
||||||
|
zh_Hans: OpenClaw 微信后端 API 的基础地址
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
default: "https://ilinkai.weixin.qq.com"
|
||||||
|
- name: token
|
||||||
|
label:
|
||||||
|
en_US: Token
|
||||||
|
zh_Hans: 令牌
|
||||||
|
description:
|
||||||
|
en_US: Bearer token obtained after QR code login authorization. Leave empty to trigger QR code login on startup.
|
||||||
|
zh_Hans: 扫码登录授权后获取的 Bearer 令牌。留空则启动时自动触发扫码登录。
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
default: ""
|
||||||
|
- name: account_id
|
||||||
|
label:
|
||||||
|
en_US: Account ID
|
||||||
|
zh_Hans: 账号标识
|
||||||
|
description:
|
||||||
|
en_US: A label for this WeChat account (used for display purposes)
|
||||||
|
zh_Hans: 此微信账号的标识(用于显示)
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
default: "openclaw-weixin"
|
||||||
|
- name: poll_timeout
|
||||||
|
label:
|
||||||
|
en_US: Poll Timeout (seconds)
|
||||||
|
zh_Hans: 轮询超时(秒)
|
||||||
|
description:
|
||||||
|
en_US: Long-poll timeout for getUpdates, the server may hold the request up to this duration
|
||||||
|
zh_Hans: getUpdates 长轮询超时时间,服务端最多持有请求的时长
|
||||||
|
type: integer
|
||||||
|
required: false
|
||||||
|
default: 35
|
||||||
|
execution:
|
||||||
|
python:
|
||||||
|
path: ./openclaw_weixin.py
|
||||||
|
attr: OpenClawWeixinAdapter
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
import telegram
|
import telegram
|
||||||
@@ -9,9 +10,9 @@ import telegramify_markdown
|
|||||||
import typing
|
import typing
|
||||||
import traceback
|
import traceback
|
||||||
import base64
|
import base64
|
||||||
import aiohttp
|
|
||||||
import pydantic
|
import pydantic
|
||||||
|
|
||||||
|
from langbot.pkg.utils import httpclient
|
||||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||||
@@ -33,14 +34,33 @@ class TelegramMessageConverter(abstract_platform_adapter.AbstractMessageConverte
|
|||||||
if component.base64:
|
if component.base64:
|
||||||
photo_bytes = base64.b64decode(component.base64)
|
photo_bytes = base64.b64decode(component.base64)
|
||||||
elif component.url:
|
elif component.url:
|
||||||
async with aiohttp.ClientSession() as session:
|
session = httpclient.get_session()
|
||||||
async with session.get(component.url) as response:
|
async with session.get(component.url) as response:
|
||||||
photo_bytes = await response.read()
|
photo_bytes = await response.read()
|
||||||
elif component.path:
|
elif component.path:
|
||||||
with open(component.path, 'rb') as f:
|
with open(component.path, 'rb') as f:
|
||||||
photo_bytes = f.read()
|
photo_bytes = f.read()
|
||||||
|
|
||||||
components.append({'type': 'photo', 'photo': photo_bytes})
|
components.append({'type': 'photo', 'photo': photo_bytes})
|
||||||
|
elif isinstance(component, platform_message.File):
|
||||||
|
file_bytes = None
|
||||||
|
|
||||||
|
if component.base64:
|
||||||
|
# Strip data URI prefix if present (e.g. "data:application/pdf;base64,...")
|
||||||
|
b64_data = component.base64
|
||||||
|
if ';base64,' in b64_data:
|
||||||
|
b64_data = b64_data.split(';base64,', 1)[1]
|
||||||
|
file_bytes = base64.b64decode(b64_data)
|
||||||
|
elif component.url:
|
||||||
|
session = httpclient.get_session()
|
||||||
|
async with session.get(component.url) as response:
|
||||||
|
file_bytes = await response.read()
|
||||||
|
elif component.path:
|
||||||
|
with open(component.path, 'rb') as f:
|
||||||
|
file_bytes = f.read()
|
||||||
|
|
||||||
|
file_name = getattr(component, 'name', None) or 'file'
|
||||||
|
components.append({'type': 'document', 'document': file_bytes, 'filename': file_name})
|
||||||
elif isinstance(component, platform_message.Forward):
|
elif isinstance(component, platform_message.Forward):
|
||||||
for node in component.node_list:
|
for node in component.node_list:
|
||||||
components.extend(await TelegramMessageConverter.yiri2target(node.message_chain, bot))
|
components.extend(await TelegramMessageConverter.yiri2target(node.message_chain, bot))
|
||||||
@@ -74,10 +94,9 @@ class TelegramMessageConverter(abstract_platform_adapter.AbstractMessageConverte
|
|||||||
file_bytes = None
|
file_bytes = None
|
||||||
file_format = ''
|
file_format = ''
|
||||||
|
|
||||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
async with httpclient.get_session(trust_env=True).get(file.file_path) as response:
|
||||||
async with session.get(file.file_path) as response:
|
file_bytes = await response.read()
|
||||||
file_bytes = await response.read()
|
file_format = 'image/jpeg'
|
||||||
file_format = 'image/jpeg'
|
|
||||||
|
|
||||||
message_components.append(
|
message_components.append(
|
||||||
platform_message.Image(
|
platform_message.Image(
|
||||||
@@ -94,9 +113,8 @@ class TelegramMessageConverter(abstract_platform_adapter.AbstractMessageConverte
|
|||||||
file_bytes = None
|
file_bytes = None
|
||||||
file_format = message.voice.mime_type or 'audio/ogg'
|
file_format = message.voice.mime_type or 'audio/ogg'
|
||||||
|
|
||||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
async with httpclient.get_session(trust_env=True).get(file.file_path) as response:
|
||||||
async with session.get(file.file_path) as response:
|
file_bytes = await response.read()
|
||||||
file_bytes = await response.read()
|
|
||||||
|
|
||||||
message_components.append(
|
message_components.append(
|
||||||
platform_message.Voice(
|
platform_message.Voice(
|
||||||
@@ -105,6 +123,27 @@ class TelegramMessageConverter(abstract_platform_adapter.AbstractMessageConverte
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if message.document:
|
||||||
|
if message.caption:
|
||||||
|
message_components.extend(parse_message_text(message.caption))
|
||||||
|
|
||||||
|
file = await message.document.get_file()
|
||||||
|
file_name = message.document.file_name or 'document'
|
||||||
|
file_size = message.document.file_size or 0
|
||||||
|
file_format = message.document.mime_type or 'application/octet-stream'
|
||||||
|
|
||||||
|
file_bytes = None
|
||||||
|
async with httpclient.get_session(trust_env=True).get(file.file_path) as response:
|
||||||
|
file_bytes = await response.read()
|
||||||
|
|
||||||
|
message_components.append(
|
||||||
|
platform_message.File(
|
||||||
|
name=file_name,
|
||||||
|
size=file_size,
|
||||||
|
base64=f'data:{file_format};base64,{base64.b64encode(file_bytes).decode("utf-8")}',
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return platform_message.MessageChain(message_components)
|
return platform_message.MessageChain(message_components)
|
||||||
|
|
||||||
|
|
||||||
@@ -180,7 +219,10 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
|||||||
application = ApplicationBuilder().token(config['token']).build()
|
application = ApplicationBuilder().token(config['token']).build()
|
||||||
bot = application.bot
|
bot = application.bot
|
||||||
application.add_handler(
|
application.add_handler(
|
||||||
MessageHandler(filters.TEXT | (filters.COMMAND) | filters.PHOTO | filters.VOICE, telegram_callback)
|
MessageHandler(
|
||||||
|
filters.TEXT | (filters.COMMAND) | filters.PHOTO | filters.VOICE | filters.Document.ALL,
|
||||||
|
telegram_callback,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
config=config,
|
config=config,
|
||||||
@@ -194,7 +236,38 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
    """Proactively send a message chain to a Telegram chat.

    Args:
        target_type: Kind of target; not used by this implementation.
        target_id: ``"<chat_id>"`` or ``"<chat_id>#<message_thread_id>"``.
            A chat id made of ASCII digits with an optional single leading
            ``-`` is sent as an ``int``; anything else is passed through
            unchanged as a string. A thread id that is not made of ASCII
            digits is ignored.
        message: The message chain to deliver. It is converted with
            ``TelegramMessageConverter.yiri2target`` and each resulting
            component is sent as its own Telegram message.

    Components of type ``text``, ``photo`` and ``document`` are sent; photo
    and document components without payload, and any other component type,
    are skipped.
    """

    def is_decimal(value: str) -> bool:
        # str.isdigit() alone also accepts characters such as superscript
        # digits that int() rejects, so require ASCII as well.
        return value.isascii() and value.isdigit()

    components = await TelegramMessageConverter.yiri2target(message, self.bot)

    chat_id_str, _, thread_id_str = str(target_id).partition('#')
    # Allow exactly one leading minus sign; lstrip('-') would also accept
    # '--5', which int() cannot parse.
    unsigned = chat_id_str[1:] if chat_id_str.startswith('-') else chat_id_str
    chat_id: int | str = int(chat_id_str) if is_decimal(unsigned) else chat_id_str
    message_thread_id = int(thread_id_str) if is_decimal(thread_id_str) else None

    # A missing 'markdown_card' key means "disabled" instead of raising KeyError.
    use_markdown = self.config.get('markdown_card', False) is True

    base_args = {'chat_id': chat_id}
    if message_thread_id is not None:
        base_args['message_thread_id'] = message_thread_id

    for component in components:
        component_type = component.get('type')
        args = dict(base_args)

        if component_type == 'text':
            text = component.get('text', '')
            if use_markdown:
                text = telegramify_markdown.markdownify(content=text)
                args['parse_mode'] = 'MarkdownV2'
            args['text'] = text
            await self.bot.send_message(**args)
        elif component_type == 'photo':
            photo = component.get('photo')
            if photo is None:
                continue
            args['photo'] = telegram.InputFile(photo)
            await self.bot.send_photo(**args)
        elif component_type == 'document':
            doc = component.get('document')
            if doc is None:
                continue
            filename = component.get('filename', 'file')
            args['document'] = telegram.InputFile(doc, filename=filename)
            await self.bot.send_document(**args)
|
||||||
|
|
||||||
async def reply_message(
|
async def reply_message(
|
||||||
self,
|
self,
|
||||||
@@ -228,6 +301,39 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
|||||||
|
|
||||||
await self.bot.send_message(**args)
|
await self.bot.send_message(**args)
|
||||||
|
|
||||||
|
def _process_markdown(self, text: str) -> str:
    """Return *text* prepared for sending.

    When the ``markdown_card`` config option is enabled the text is passed
    through ``telegramify_markdown.markdownify``; otherwise it is returned
    unchanged. A missing option counts as disabled.
    """
    if not self.config.get('markdown_card', False):
        return text
    return telegramify_markdown.markdownify(content=text)
|
||||||
|
|
||||||
|
def _build_message_args(self, chat_id: int, text: str, message_thread_id: int = None, **extra_args) -> dict:
    """Assemble the keyword arguments for a bot send call.

    Args:
        chat_id: Target chat.
        text: Message text; it is run through ``_process_markdown`` first.
        message_thread_id: Included only when truthy.
        **extra_args: Additional keyword arguments copied into the result.

    Returns:
        A dict with ``chat_id`` and ``text``, the extra arguments, and --
        when applicable -- ``message_thread_id`` and
        ``parse_mode='MarkdownV2'`` (set when ``markdown_card`` is enabled).
    """
    args = dict(chat_id=chat_id, text=self._process_markdown(text))
    args.update(extra_args)

    if message_thread_id:
        args['message_thread_id'] = message_thread_id

    if self.config.get('markdown_card', False):
        args['parse_mode'] = 'MarkdownV2'

    return args
|
||||||
|
|
||||||
|
async def create_message_card(self, message_id, event):
    """Send the initial 'Thinking...' placeholder for a streamed reply.

    The placeholder is recorded in ``self.msg_stream_id[message_id]`` as a
    ``(mode, id)`` pair:

    * private chats: ``('private', draft_id)`` where ``draft_id`` is the
      current time in milliseconds; the placeholder is sent with
      ``bot.send_message_draft``.
    * any other chat type: ``('group', <sent message id>)``; the
      placeholder is a regular message sent with ``bot.send_message``.

    Returns:
        Always ``True``.
    """
    assert isinstance(event.source_platform_object, Update)
    update = event.source_platform_object

    chat = update.effective_chat
    chat_id = chat.id
    chat_type = chat.type
    thread_id = update.message.message_thread_id

    if chat_type != 'private':
        args = self._build_message_args(chat_id, 'Thinking...', thread_id)
        sent = await self.bot.send_message(**args)
        self.msg_stream_id[message_id] = ('group', sent.message_id)
        return True

    draft_id = int(time.time() * 1000)
    # Registered before sending, as in the original ordering.
    self.msg_stream_id[message_id] = ('private', draft_id)
    args = self._build_message_args(chat_id, 'Thinking...', thread_id, draft_id=draft_id)
    await self.bot.send_message_draft(**args)
    return True
|
||||||
|
|
||||||
async def reply_message_chunk(
|
async def reply_message_chunk(
|
||||||
self,
|
self,
|
||||||
message_source: platform_events.MessageEvent,
|
message_source: platform_events.MessageEvent,
|
||||||
@@ -236,59 +342,47 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
|||||||
quote_origin: bool = False,
|
quote_origin: bool = False,
|
||||||
is_final: bool = False,
|
is_final: bool = False,
|
||||||
):
|
):
|
||||||
|
message_id = bot_message.resp_message_id
|
||||||
msg_seq = bot_message.msg_sequence
|
msg_seq = bot_message.msg_sequence
|
||||||
if (msg_seq - 1) % 8 == 0 or is_final:
|
assert isinstance(message_source.source_platform_object, Update)
|
||||||
assert isinstance(message_source.source_platform_object, Update)
|
update = message_source.source_platform_object
|
||||||
components = await TelegramMessageConverter.yiri2target(message, self.bot)
|
chat_id = update.effective_chat.id
|
||||||
args = {}
|
message_thread_id = update.message.message_thread_id
|
||||||
message_id = message_source.source_platform_object.message.id
|
|
||||||
|
|
||||||
component = components[0]
|
if message_id not in self.msg_stream_id:
|
||||||
if message_id not in self.msg_stream_id: # 当消息回复第一次时,发送新消息
|
return
|
||||||
# time.sleep(0.6)
|
|
||||||
if component['type'] == 'text':
|
|
||||||
if self.config['markdown_card'] is True:
|
|
||||||
content = telegramify_markdown.markdownify(
|
|
||||||
content=component['text'],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
content = component['text']
|
|
||||||
args = {
|
|
||||||
'chat_id': message_source.source_platform_object.effective_chat.id,
|
|
||||||
'text': content,
|
|
||||||
}
|
|
||||||
if message_source.source_platform_object.message.message_thread_id:
|
|
||||||
args['message_thread_id'] = message_source.source_platform_object.message.message_thread_id
|
|
||||||
|
|
||||||
if quote_origin:
|
chat_mode, draft_id = self.msg_stream_id[message_id]
|
||||||
args['reply_to_message_id'] = message_source.source_platform_object.message.id
|
components = await TelegramMessageConverter.yiri2target(message, self.bot)
|
||||||
|
|
||||||
if self.config['markdown_card'] is True:
|
if not components or components[0]['type'] != 'text':
|
||||||
args['parse_mode'] = 'MarkdownV2'
|
|
||||||
|
|
||||||
send_msg = await self.bot.send_message(**args)
|
|
||||||
send_msg_id = send_msg.message_id
|
|
||||||
self.msg_stream_id[message_id] = send_msg_id
|
|
||||||
else: # 存在消息的时候直接编辑消息1
|
|
||||||
if component['type'] == 'text':
|
|
||||||
if self.config['markdown_card'] is True:
|
|
||||||
content = telegramify_markdown.markdownify(
|
|
||||||
content=component['text'],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
content = component['text']
|
|
||||||
args = {
|
|
||||||
'message_id': self.msg_stream_id[message_id],
|
|
||||||
'chat_id': message_source.source_platform_object.effective_chat.id,
|
|
||||||
'text': content,
|
|
||||||
}
|
|
||||||
if self.config['markdown_card'] is True:
|
|
||||||
args['parse_mode'] = 'MarkdownV2'
|
|
||||||
|
|
||||||
await self.bot.edit_message_text(**args)
|
|
||||||
if is_final and bot_message.tool_calls is None:
|
if is_final and bot_message.tool_calls is None:
|
||||||
# self.seq = 1 # 消息回复结束之后重置seq
|
self.msg_stream_id.pop(message_id)
|
||||||
self.msg_stream_id.pop(message_id) # 消息回复结束之后删除流式消息id
|
return
|
||||||
|
|
||||||
|
content = components[0]['text']
|
||||||
|
|
||||||
|
if chat_mode == 'private':
|
||||||
|
args = self._build_message_args(chat_id, content, message_thread_id, draft_id=draft_id)
|
||||||
|
await self.bot.send_message_draft(**args)
|
||||||
|
if is_final and bot_message.tool_calls is None:
|
||||||
|
del args['draft_id']
|
||||||
|
await self.bot.send_message(**args)
|
||||||
|
self.msg_stream_id.pop(message_id)
|
||||||
|
else:
|
||||||
|
stream_id = draft_id
|
||||||
|
if (msg_seq - 1) % 8 == 0 or is_final:
|
||||||
|
args = {
|
||||||
|
'message_id': stream_id,
|
||||||
|
'chat_id': chat_id,
|
||||||
|
'text': self._process_markdown(content),
|
||||||
|
}
|
||||||
|
if self.config.get('markdown_card', False):
|
||||||
|
args['parse_mode'] = 'MarkdownV2'
|
||||||
|
await self.bot.edit_message_text(**args)
|
||||||
|
|
||||||
|
if is_final and bot_message.tool_calls is None:
|
||||||
|
self.msg_stream_id.pop(message_id)
|
||||||
|
|
||||||
def get_launcher_id(self, event: platform_events.MessageEvent) -> str | None:
|
def get_launcher_id(self, event: platform_events.MessageEvent) -> str | None:
|
||||||
if not isinstance(event.source_platform_object, Update):
|
if not isinstance(event.source_platform_object, Update):
|
||||||
|
|||||||
@@ -37,16 +37,24 @@ class WebSocketSession:
|
|||||||
id: str
|
id: str
|
||||||
message_lists: dict[str, list[WebSocketMessage]] = {}
|
message_lists: dict[str, list[WebSocketMessage]] = {}
|
||||||
"""消息列表 {pipeline_uuid: [messages]}"""
|
"""消息列表 {pipeline_uuid: [messages]}"""
|
||||||
|
stream_message_indexes: dict[str, dict[str, int]] = {}
|
||||||
|
"""流式消息索引 {pipeline_uuid: {resp_message_id: message_index}}"""
|
||||||
|
|
||||||
def __init__(self, id: str):
|
def __init__(self, id: str):
|
||||||
self.id = id
|
self.id = id
|
||||||
self.message_lists = {}
|
self.message_lists = {}
|
||||||
|
self.stream_message_indexes = {}
|
||||||
|
|
||||||
def get_message_list(self, pipeline_uuid: str) -> list[WebSocketMessage]:
|
def get_message_list(self, pipeline_uuid: str) -> list[WebSocketMessage]:
|
||||||
if pipeline_uuid not in self.message_lists:
|
if pipeline_uuid not in self.message_lists:
|
||||||
self.message_lists[pipeline_uuid] = []
|
self.message_lists[pipeline_uuid] = []
|
||||||
return self.message_lists[pipeline_uuid]
|
return self.message_lists[pipeline_uuid]
|
||||||
|
|
||||||
|
def get_stream_message_indexes(self, pipeline_uuid: str) -> dict[str, int]:
    """Return the ``{resp_message_id: message_index}`` map for a pipeline.

    An empty map is created and stored on first access, so the returned
    dict is always the live object held in ``self.stream_message_indexes``.
    """
    return self.stream_message_indexes.setdefault(pipeline_uuid, {})
|
||||||
|
|
||||||
|
|
||||||
class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||||
"""WebSocket适配器 - 支持双向实时通信"""
|
"""WebSocket适配器 - 支持双向实时通信"""
|
||||||
@@ -89,20 +97,46 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
|||||||
target_id: str,
|
target_id: str,
|
||||||
message: platform_message.MessageChain,
|
message: platform_message.MessageChain,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""发送消息 - 这里用于主动推送消息到前端"""
|
"""发送消息 - 这里用于主动推送消息到前端
|
||||||
message_data = {
|
|
||||||
'type': 'bot_message',
|
|
||||||
'target_type': target_type,
|
|
||||||
'target_id': target_id,
|
|
||||||
'content': str(message),
|
|
||||||
'message_chain': [component.__dict__ for component in message],
|
|
||||||
'timestamp': datetime.now().isoformat(),
|
|
||||||
}
|
|
||||||
|
|
||||||
# 推送到所有相关连接
|
对于 WebSocket 适配器,我们需要将消息广播到正确的 pipeline 连接。
|
||||||
await self.outbound_message_queue.put(message_data)
|
target_id 可能是 launcher_id(如 websocket_xxx)或 pipeline_uuid。
|
||||||
|
我们需要尝试两种方式来确保消息能够送达。
|
||||||
|
"""
|
||||||
|
# 获取当前的 pipeline_uuid
|
||||||
|
pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid
|
||||||
|
session_type = 'group' if target_type == 'group' else 'person'
|
||||||
|
|
||||||
return message_data
|
# 选择会话
|
||||||
|
session = self.websocket_group_session if session_type == 'group' else self.websocket_person_session
|
||||||
|
|
||||||
|
# 生成唯一消息ID
|
||||||
|
msg_id = len(session.get_message_list(pipeline_uuid)) + 1
|
||||||
|
|
||||||
|
message_data = WebSocketMessage(
|
||||||
|
id=msg_id,
|
||||||
|
role='assistant',
|
||||||
|
content=str(message),
|
||||||
|
message_chain=[component.__dict__ for component in message],
|
||||||
|
timestamp=datetime.now().isoformat(),
|
||||||
|
is_final=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存到历史记录
|
||||||
|
session.get_message_list(pipeline_uuid).append(message_data)
|
||||||
|
|
||||||
|
# 直接广播到当前pipeline的连接
|
||||||
|
await ws_connection_manager.broadcast_to_pipeline(
|
||||||
|
pipeline_uuid,
|
||||||
|
{
|
||||||
|
'type': 'response',
|
||||||
|
'session_type': session_type,
|
||||||
|
'data': message_data.model_dump(),
|
||||||
|
},
|
||||||
|
session_type=session_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
return message_data.model_dump()
|
||||||
|
|
||||||
async def reply_message(
|
async def reply_message(
|
||||||
self,
|
self,
|
||||||
@@ -169,10 +203,16 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
|||||||
pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid
|
pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid
|
||||||
session_type = 'group' if isinstance(message_source, platform_events.GroupMessage) else 'person'
|
session_type = 'group' if isinstance(message_source, platform_events.GroupMessage) else 'person'
|
||||||
message_list = session.get_message_list(pipeline_uuid)
|
message_list = session.get_message_list(pipeline_uuid)
|
||||||
|
stream_message_indexes = session.get_stream_message_indexes(pipeline_uuid)
|
||||||
|
|
||||||
# 检查是否是新的流式消息(通过bot_message对象判断)
|
# Streaming messages in LangBot have a stable resp_message_id during the same assistant reply.
|
||||||
# 如果列表为空,或者最后一条消息已经is_final=True,则创建新消息
|
# Use it as the primary key to avoid overwriting an old card from a previous reply.
|
||||||
if not message_list or message_list[-1].is_final:
|
resp_message_id = str(getattr(bot_message, 'resp_message_id', '') or '')
|
||||||
|
existing_index = stream_message_indexes.get(resp_message_id) if resp_message_id else None
|
||||||
|
|
||||||
|
message_is_final = is_final and bot_message.tool_calls is None
|
||||||
|
|
||||||
|
if existing_index is None or existing_index >= len(message_list):
|
||||||
# 创建新消息
|
# 创建新消息
|
||||||
msg_id = len(message_list) + 1
|
msg_id = len(message_list) + 1
|
||||||
message_data = WebSocketMessage(
|
message_data = WebSocketMessage(
|
||||||
@@ -181,27 +221,31 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
|||||||
content=str(message),
|
content=str(message),
|
||||||
message_chain=[component.__dict__ for component in message],
|
message_chain=[component.__dict__ for component in message],
|
||||||
timestamp=datetime.now().isoformat(),
|
timestamp=datetime.now().isoformat(),
|
||||||
is_final=is_final and bot_message.tool_calls is None,
|
is_final=message_is_final,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 只有在is_final时才保存到历史记录
|
# 立即添加到历史记录(即使is_final=False),以便后续块可以更新它
|
||||||
if is_final and bot_message.tool_calls is None:
|
message_list.append(message_data)
|
||||||
message_list.append(message_data)
|
if resp_message_id:
|
||||||
|
stream_message_indexes[resp_message_id] = len(message_list) - 1
|
||||||
else:
|
else:
|
||||||
# 更新最后一条消息
|
# 更新同一条流式消息
|
||||||
msg_id = message_list[-1].id
|
old_message = message_list[existing_index]
|
||||||
|
msg_id = old_message.id
|
||||||
message_data = WebSocketMessage(
|
message_data = WebSocketMessage(
|
||||||
id=msg_id,
|
id=msg_id,
|
||||||
role='assistant',
|
role='assistant',
|
||||||
content=str(message),
|
content=str(message),
|
||||||
message_chain=[component.__dict__ for component in message],
|
message_chain=[component.__dict__ for component in message],
|
||||||
timestamp=message_list[-1].timestamp, # 保持原始时间戳
|
timestamp=old_message.timestamp, # 保持原始时间戳
|
||||||
is_final=is_final and bot_message.tool_calls is None,
|
is_final=message_is_final,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 如果是final,更新历史记录中的最后一条
|
# 更新历史记录中的对应消息
|
||||||
if is_final and bot_message.tool_calls is None:
|
message_list[existing_index] = message_data
|
||||||
message_list[-1] = message_data
|
|
||||||
|
if message_is_final and resp_message_id:
|
||||||
|
stream_message_indexes.pop(resp_message_id, None)
|
||||||
|
|
||||||
# 直接广播到所有该pipeline的连接,包含session_type信息
|
# 直接广播到所有该pipeline的连接,包含session_type信息
|
||||||
await ws_connection_manager.broadcast_to_pipeline(
|
await ws_connection_manager.broadcast_to_pipeline(
|
||||||
@@ -410,6 +454,10 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
|||||||
if session_type == 'person':
|
if session_type == 'person':
|
||||||
if pipeline_uuid in self.websocket_person_session.message_lists:
|
if pipeline_uuid in self.websocket_person_session.message_lists:
|
||||||
self.websocket_person_session.message_lists[pipeline_uuid] = []
|
self.websocket_person_session.message_lists[pipeline_uuid] = []
|
||||||
|
if pipeline_uuid in self.websocket_person_session.stream_message_indexes:
|
||||||
|
self.websocket_person_session.stream_message_indexes[pipeline_uuid] = {}
|
||||||
else:
|
else:
|
||||||
if pipeline_uuid in self.websocket_group_session.message_lists:
|
if pipeline_uuid in self.websocket_group_session.message_lists:
|
||||||
self.websocket_group_session.message_lists[pipeline_uuid] = []
|
self.websocket_group_session.message_lists[pipeline_uuid] = []
|
||||||
|
if pipeline_uuid in self.websocket_group_session.stream_message_indexes:
|
||||||
|
self.websocket_group_session.stream_message_indexes[pipeline_uuid] = {}
|
||||||
|
|||||||
BIN
src/langbot/pkg/platform/sources/wechat.png
Normal file
BIN
src/langbot/pkg/platform/sources/wechat.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 466 KiB |
@@ -148,51 +148,54 @@ class WecomEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
if type(event) is platform_events.FriendMessage:
|
if type(event) is platform_events.FriendMessage:
|
||||||
payload = {
|
return event.source_platform_object
|
||||||
'MsgType': 'text',
|
|
||||||
'Content': '',
|
|
||||||
'FromUserName': event.sender.id,
|
|
||||||
'ToUserName': bot_account_id,
|
|
||||||
'CreateTime': int(datetime.datetime.now().timestamp()),
|
|
||||||
'AgentID': event.sender.nickname,
|
|
||||||
}
|
|
||||||
wecom_event = WecomEvent.from_payload(payload=payload)
|
|
||||||
if not wecom_event:
|
|
||||||
raise ValueError('无法从 message_data 构造 WecomEvent 对象')
|
|
||||||
|
|
||||||
return wecom_event
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def target2yiri(event: WecomEvent):
|
async def target2yiri(event: WecomEvent, bot: WecomClient = None):
|
||||||
"""
|
"""
|
||||||
将 WecomEvent 转换为平台的 FriendMessage 对象。
|
将 WecomEvent 转换为平台的 FriendMessage 对象。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event (WecomEvent): 企业微信事件。
|
event (WecomEvent): 企业微信事件。
|
||||||
|
bot (WecomClient): 企业微信客户端,用于获取用户信息。
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
platform_events.FriendMessage: 转换后的 FriendMessage 对象。
|
platform_events.FriendMessage: 转换后的 FriendMessage 对象。
|
||||||
"""
|
"""
|
||||||
|
# Try to get the user's real name from the WeCom API
|
||||||
|
nickname = str(event.user_id)
|
||||||
|
if bot and event.user_id:
|
||||||
|
try:
|
||||||
|
user_info = await bot.get_user_info(event.user_id)
|
||||||
|
if user_info and user_info.get('name'):
|
||||||
|
nickname = user_info.get('name')
|
||||||
|
except Exception:
|
||||||
|
pass # Fall back to user_id as nickname
|
||||||
|
|
||||||
# 转换消息链
|
# 转换消息链
|
||||||
if event.type == 'text':
|
if event.type == 'text':
|
||||||
yiri_chain = await WecomMessageConverter.target2yiri(event.message, event.message_id)
|
yiri_chain = await WecomMessageConverter.target2yiri(event.message, event.message_id)
|
||||||
friend = platform_entities.Friend(
|
friend = platform_entities.Friend(
|
||||||
id=f'u{event.user_id}',
|
id=f'u{event.user_id}',
|
||||||
nickname=str(event.agent_id),
|
nickname=nickname,
|
||||||
remark='',
|
remark='',
|
||||||
)
|
)
|
||||||
|
|
||||||
return platform_events.FriendMessage(sender=friend, message_chain=yiri_chain, time=event.timestamp)
|
return platform_events.FriendMessage(
|
||||||
|
sender=friend, message_chain=yiri_chain, time=event.timestamp, source_platform_object=event
|
||||||
|
)
|
||||||
elif event.type == 'image':
|
elif event.type == 'image':
|
||||||
friend = platform_entities.Friend(
|
friend = platform_entities.Friend(
|
||||||
id=f'u{event.user_id}',
|
id=f'u{event.user_id}',
|
||||||
nickname=str(event.agent_id),
|
nickname=nickname,
|
||||||
remark='',
|
remark='',
|
||||||
)
|
)
|
||||||
|
|
||||||
yiri_chain = await WecomMessageConverter.target2yiri_image(picurl=event.picurl, message_id=event.message_id)
|
yiri_chain = await WecomMessageConverter.target2yiri_image(picurl=event.picurl, message_id=event.message_id)
|
||||||
|
|
||||||
return platform_events.FriendMessage(sender=friend, message_chain=yiri_chain, time=event.timestamp)
|
return platform_events.FriendMessage(
|
||||||
|
sender=friend, message_chain=yiri_chain, time=event.timestamp, source_platform_object=event
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class WecomAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
class WecomAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||||
@@ -210,7 +213,6 @@ class WecomAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
|||||||
'secret',
|
'secret',
|
||||||
'token',
|
'token',
|
||||||
'EncodingAESKey',
|
'EncodingAESKey',
|
||||||
'contacts_secret',
|
|
||||||
]
|
]
|
||||||
|
|
||||||
missing_keys = [key for key in required_keys if key not in config]
|
missing_keys = [key for key in required_keys if key not in config]
|
||||||
@@ -223,7 +225,7 @@ class WecomAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
|||||||
secret=config['secret'],
|
secret=config['secret'],
|
||||||
token=config['token'],
|
token=config['token'],
|
||||||
EncodingAESKey=config['EncodingAESKey'],
|
EncodingAESKey=config['EncodingAESKey'],
|
||||||
contacts_secret=config['contacts_secret'],
|
contacts_secret=config.get('contacts_secret', ''), # Optional, kept for backward compatibility
|
||||||
logger=logger,
|
logger=logger,
|
||||||
unified_mode=True,
|
unified_mode=True,
|
||||||
api_base_url=config.get('api_base_url', 'https://qyapi.weixin.qq.com/cgi-bin'),
|
api_base_url=config.get('api_base_url', 'https://qyapi.weixin.qq.com/cgi-bin'),
|
||||||
@@ -248,18 +250,17 @@ class WecomAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
|||||||
):
|
):
|
||||||
Wecom_event = await WecomEventConverter.yiri2target(message_source, self.bot_account_id, self.bot)
|
Wecom_event = await WecomEventConverter.yiri2target(message_source, self.bot_account_id, self.bot)
|
||||||
content_list = await WecomMessageConverter.yiri2target(message, self.bot)
|
content_list = await WecomMessageConverter.yiri2target(message, self.bot)
|
||||||
fixed_user_id = Wecom_event.user_id
|
# user_id is the original FromUserName from WecomEvent
|
||||||
# 删掉开头的u
|
user_id = Wecom_event.user_id
|
||||||
fixed_user_id = fixed_user_id[1:]
|
|
||||||
for content in content_list:
|
for content in content_list:
|
||||||
if content['type'] == 'text':
|
if content['type'] == 'text':
|
||||||
await self.bot.send_private_msg(fixed_user_id, Wecom_event.agent_id, content['content'])
|
await self.bot.send_private_msg(user_id, Wecom_event.agent_id, content['content'])
|
||||||
elif content['type'] == 'image':
|
elif content['type'] == 'image':
|
||||||
await self.bot.send_image(fixed_user_id, Wecom_event.agent_id, content['media_id'])
|
await self.bot.send_image(user_id, Wecom_event.agent_id, content['media_id'])
|
||||||
elif content['type'] == 'voice':
|
elif content['type'] == 'voice':
|
||||||
await self.bot.send_voice(fixed_user_id, Wecom_event.agent_id, content['media_id'])
|
await self.bot.send_voice(user_id, Wecom_event.agent_id, content['media_id'])
|
||||||
elif content['type'] == 'file':
|
elif content['type'] == 'file':
|
||||||
await self.bot.send_file(fixed_user_id, Wecom_event.agent_id, content['media_id'])
|
await self.bot.send_file(user_id, Wecom_event.agent_id, content['media_id'])
|
||||||
|
|
||||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||||
content_list = await WecomMessageConverter.yiri2target(message, self.bot)
|
content_list = await WecomMessageConverter.yiri2target(message, self.bot)
|
||||||
@@ -287,7 +288,7 @@ class WecomAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
|||||||
async def on_message(event: WecomEvent):
|
async def on_message(event: WecomEvent):
|
||||||
self.bot_account_id = event.receiver_id
|
self.bot_account_id = event.receiver_id
|
||||||
try:
|
try:
|
||||||
return await callback(await self.event_converter.target2yiri(event), self)
|
return await callback(await self.event_converter.target2yiri(event, self.bot), self)
|
||||||
except Exception:
|
except Exception:
|
||||||
await self.logger.error(f'Error in wecom callback: {traceback.format_exc()}')
|
await self.logger.error(f'Error in wecom callback: {traceback.format_exc()}')
|
||||||
|
|
||||||
|
|||||||
@@ -39,13 +39,6 @@ spec:
|
|||||||
type: string
|
type: string
|
||||||
required: true
|
required: true
|
||||||
default: ""
|
default: ""
|
||||||
- name: contacts_secret
|
|
||||||
label:
|
|
||||||
en_US: Contacts Secret
|
|
||||||
zh_Hans: 通讯录密钥
|
|
||||||
type: string
|
|
||||||
required: true
|
|
||||||
default: ""
|
|
||||||
- name: api_base_url
|
- name: api_base_url
|
||||||
label:
|
label:
|
||||||
en_US: API Base URL
|
en_US: API Base URL
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import langbot_plugin.api.entities.builtin.platform.entities as platform_entitie
|
|||||||
from ..logger import EventLogger
|
from ..logger import EventLogger
|
||||||
from langbot.libs.wecom_ai_bot_api.wecombotevent import WecomBotEvent
|
from langbot.libs.wecom_ai_bot_api.wecombotevent import WecomBotEvent
|
||||||
from langbot.libs.wecom_ai_bot_api.api import WecomBotClient
|
from langbot.libs.wecom_ai_bot_api.api import WecomBotClient
|
||||||
|
from langbot.libs.wecom_ai_bot_api.ws_client import WecomBotWsClient
|
||||||
|
|
||||||
|
|
||||||
class WecomBotMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
class WecomBotMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||||
@@ -23,14 +24,18 @@ class WecomBotMessageConverter(abstract_platform_adapter.AbstractMessageConverte
|
|||||||
return content
|
return content
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def target2yiri(event: WecomBotEvent):
|
async def target2yiri(event: WecomBotEvent, bot_name: str = ''):
|
||||||
yiri_msg_list = []
|
yiri_msg_list = []
|
||||||
if event.type == 'group':
|
if event.type == 'group':
|
||||||
yiri_msg_list.append(platform_message.At(target=event.ai_bot_id))
|
yiri_msg_list.append(platform_message.At(target=event.ai_bot_id))
|
||||||
|
|
||||||
yiri_msg_list.append(platform_message.Source(id=event.message_id, time=datetime.datetime.now()))
|
yiri_msg_list.append(platform_message.Source(id=event.message_id, time=datetime.datetime.now()))
|
||||||
|
|
||||||
if event.content:
|
if event.content:
|
||||||
yiri_msg_list.append(platform_message.Plain(text=event.content))
|
content = event.content
|
||||||
|
if bot_name:
|
||||||
|
content = content.replace(f'@{bot_name}', '').strip()
|
||||||
|
yiri_msg_list.append(platform_message.Plain(text=content))
|
||||||
|
|
||||||
images = []
|
images = []
|
||||||
if event.images:
|
if event.images:
|
||||||
@@ -133,13 +138,15 @@ class WecomBotMessageConverter(abstract_platform_adapter.AbstractMessageConverte
|
|||||||
|
|
||||||
|
|
||||||
class WecomBotEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
class WecomBotEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||||
|
def __init__(self, bot_name: str = ''):
|
||||||
|
self.bot_name = bot_name
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def yiri2target(event: platform_events.MessageEvent):
|
async def yiri2target(event: platform_events.MessageEvent):
|
||||||
return event.source_platform_object
|
return event.source_platform_object
|
||||||
|
|
||||||
@staticmethod
|
async def target2yiri(self, event: WecomBotEvent):
|
||||||
async def target2yiri(event: WecomBotEvent):
|
message_chain = await WecomBotMessageConverter.target2yiri(event, bot_name=self.bot_name)
|
||||||
message_chain = await WecomBotMessageConverter.target2yiri(event)
|
|
||||||
if event.type == 'single':
|
if event.type == 'single':
|
||||||
return platform_events.FriendMessage(
|
return platform_events.FriendMessage(
|
||||||
sender=platform_entities.Friend(
|
sender=platform_entities.Friend(
|
||||||
@@ -176,34 +183,53 @@ class WecomBotEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
|||||||
|
|
||||||
|
|
||||||
class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||||
bot: WecomBotClient
|
bot: typing.Union[WecomBotClient, WecomBotWsClient]
|
||||||
bot_account_id: str
|
bot_account_id: str
|
||||||
message_converter: WecomBotMessageConverter = WecomBotMessageConverter()
|
message_converter: WecomBotMessageConverter = WecomBotMessageConverter()
|
||||||
event_converter: WecomBotEventConverter = WecomBotEventConverter()
|
event_converter: WecomBotEventConverter
|
||||||
config: dict
|
config: dict
|
||||||
bot_uuid: str = None
|
bot_uuid: str = None
|
||||||
|
_ws_mode: bool = False
|
||||||
|
bot_name: str = ''
|
||||||
|
listeners: dict = {}
|
||||||
|
|
||||||
def __init__(self, config: dict, logger: EventLogger):
|
def __init__(self, config: dict, logger: EventLogger):
|
||||||
required_keys = ['Token', 'EncodingAESKey', 'Corpid', 'BotId']
|
enable_webhook = config.get('enable-webhook', False)
|
||||||
missing_keys = [key for key in required_keys if key not in config]
|
bot_name = config.get('robot_name', '')
|
||||||
if missing_keys:
|
|
||||||
raise Exception(f'WecomBot 缺少配置项: {missing_keys}')
|
|
||||||
|
|
||||||
bot = WecomBotClient(
|
if not enable_webhook:
|
||||||
Token=config['Token'],
|
bot = WecomBotWsClient(
|
||||||
EnCodingAESKey=config['EncodingAESKey'],
|
bot_id=config['BotId'],
|
||||||
Corpid=config['Corpid'],
|
secret=config['Secret'],
|
||||||
logger=logger,
|
logger=logger,
|
||||||
unified_mode=True,
|
encoding_aes_key=config.get('EncodingAESKey', ''),
|
||||||
)
|
)
|
||||||
bot_account_id = config['BotId']
|
else:
|
||||||
|
# Webhook callback mode
|
||||||
|
required_keys = ['Token', 'EncodingAESKey', 'Corpid']
|
||||||
|
missing_keys = [key for key in required_keys if key not in config or not config[key]]
|
||||||
|
if missing_keys:
|
||||||
|
raise Exception(f'WecomBot webhook mode missing config: {missing_keys}')
|
||||||
|
|
||||||
|
bot = WecomBotClient(
|
||||||
|
Token=config['Token'],
|
||||||
|
EnCodingAESKey=config['EncodingAESKey'],
|
||||||
|
Corpid=config['Corpid'],
|
||||||
|
logger=logger,
|
||||||
|
unified_mode=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
bot_account_id = config.get('BotId', '')
|
||||||
|
event_converter = WecomBotEventConverter(bot_name=bot_name)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
config=config,
|
config=config,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
bot=bot,
|
bot=bot,
|
||||||
bot_account_id=bot_account_id,
|
bot_account_id=bot_account_id,
|
||||||
|
bot_name=bot_name,
|
||||||
|
event_converter=event_converter,
|
||||||
)
|
)
|
||||||
|
self.listeners = {}
|
||||||
|
|
||||||
async def reply_message(
|
async def reply_message(
|
||||||
self,
|
self,
|
||||||
@@ -212,7 +238,17 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
|||||||
quote_origin: bool = False,
|
quote_origin: bool = False,
|
||||||
):
|
):
|
||||||
content = await self.message_converter.yiri2target(message)
|
content = await self.message_converter.yiri2target(message)
|
||||||
await self.bot.set_message(message_source.source_platform_object.message_id, content)
|
_ws_mode = not self.config.get('enable-webhook', False)
|
||||||
|
|
||||||
|
if _ws_mode:
|
||||||
|
event = message_source.source_platform_object
|
||||||
|
req_id = event.get('req_id', '')
|
||||||
|
if req_id:
|
||||||
|
await self.bot.reply_text(req_id, content)
|
||||||
|
else:
|
||||||
|
await self.bot.set_message(event.message_id, content)
|
||||||
|
else:
|
||||||
|
await self.bot.set_message(message_source.source_platform_object.message_id, content)
|
||||||
|
|
||||||
async def reply_message_chunk(
|
async def reply_message_chunk(
|
||||||
self,
|
self,
|
||||||
@@ -222,31 +258,23 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
|||||||
quote_origin: bool = False,
|
quote_origin: bool = False,
|
||||||
is_final: bool = False,
|
is_final: bool = False,
|
||||||
):
|
):
|
||||||
"""将流水线增量输出写入企业微信 stream 会话。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message_source: 流水线提供的原始消息事件。
|
|
||||||
bot_message: 当前片段对应的模型元信息(未使用)。
|
|
||||||
message: 需要回复的消息链。
|
|
||||||
quote_origin: 是否引用原消息(企业微信暂不支持)。
|
|
||||||
is_final: 标记当前片段是否为最终回复。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: 包含 `stream` 键,标识写入是否成功。
|
|
||||||
|
|
||||||
Example:
|
|
||||||
在流水线 `reply_message_chunk` 调用中自动触发,无需手动调用。
|
|
||||||
"""
|
|
||||||
# 转换为纯文本(智能机器人当前协议仅支持文本流)
|
|
||||||
content = await self.message_converter.yiri2target(message)
|
content = await self.message_converter.yiri2target(message)
|
||||||
msg_id = message_source.source_platform_object.message_id
|
msg_id = message_source.source_platform_object.message_id
|
||||||
|
_ws_mode = not self.config.get('enable-webhook', False)
|
||||||
|
|
||||||
# 将片段推送到 WecomBotClient 中的队列,返回值用于判断是否走降级逻辑
|
if _ws_mode:
|
||||||
success = await self.bot.push_stream_chunk(msg_id, content, is_final=is_final)
|
success = await self.bot.push_stream_chunk(msg_id, content, is_final=is_final)
|
||||||
if not success and is_final:
|
if not success and is_final:
|
||||||
# 未命中流式队列时使用旧有 set_message 兜底
|
event = message_source.source_platform_object
|
||||||
await self.bot.set_message(msg_id, content)
|
req_id = event.get('req_id', '')
|
||||||
return {'stream': success}
|
if req_id:
|
||||||
|
await self.bot.reply_text(req_id, content)
|
||||||
|
return {'stream': success}
|
||||||
|
else:
|
||||||
|
success = await self.bot.push_stream_chunk(msg_id, content, is_final=is_final)
|
||||||
|
if not success and is_final:
|
||||||
|
await self.bot.set_message(msg_id, content)
|
||||||
|
return {'stream': success}
|
||||||
|
|
||||||
async def is_stream_output_supported(self) -> bool:
|
async def is_stream_output_supported(self) -> bool:
|
||||||
"""智能机器人侧默认开启流式能力。
|
"""智能机器人侧默认开启流式能力。
|
||||||
@@ -259,7 +287,21 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
async def send_message(self, target_type, target_id, message):
|
async def send_message(self, target_type, target_id, message):
|
||||||
pass
|
_ws_mode = not self.config.get('enable-webhook', False)
|
||||||
|
if _ws_mode:
|
||||||
|
content = await self.message_converter.yiri2target(message)
|
||||||
|
await self.bot.send_message(target_id, content)
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def on_message(self, event: WecomBotEvent):
|
||||||
|
try:
|
||||||
|
lb_event = await self.event_converter.target2yiri(event)
|
||||||
|
if lb_event:
|
||||||
|
await self.listeners[type(lb_event)](lb_event, self)
|
||||||
|
except Exception:
|
||||||
|
await self.logger.error(f'Error in wecombot callback: {traceback.format_exc()}')
|
||||||
|
print(traceback.format_exc())
|
||||||
|
|
||||||
def register_listener(
|
def register_listener(
|
||||||
self,
|
self,
|
||||||
@@ -268,18 +310,13 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
|||||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||||
],
|
],
|
||||||
):
|
):
|
||||||
async def on_message(event: WecomBotEvent):
|
self.listeners[event_type] = callback
|
||||||
try:
|
|
||||||
return await callback(await self.event_converter.target2yiri(event), self)
|
|
||||||
except Exception:
|
|
||||||
await self.logger.error(f'Error in wecombot callback: {traceback.format_exc()}')
|
|
||||||
print(traceback.format_exc())
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if event_type == platform_events.FriendMessage:
|
if event_type == platform_events.FriendMessage:
|
||||||
self.bot.on_message('single')(on_message)
|
self.bot.on_message('single')(self.on_message)
|
||||||
elif event_type == platform_events.GroupMessage:
|
elif event_type == platform_events.GroupMessage:
|
||||||
self.bot.on_message('group')(on_message)
|
self.bot.on_message('group')(self.on_message)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
|
|
||||||
@@ -288,29 +325,28 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
|||||||
self.bot_uuid = bot_uuid
|
self.bot_uuid = bot_uuid
|
||||||
|
|
||||||
async def handle_unified_webhook(self, bot_uuid: str, path: str, request):
|
async def handle_unified_webhook(self, bot_uuid: str, path: str, request):
|
||||||
"""处理统一 webhook 请求。
|
_ws_mode = not self.config.get('enable-webhook', False)
|
||||||
|
if _ws_mode:
|
||||||
Args:
|
return None
|
||||||
bot_uuid: Bot 的 UUID
|
|
||||||
path: 子路径(如果有的话)
|
|
||||||
request: Quart Request 对象
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
响应数据
|
|
||||||
"""
|
|
||||||
return await self.bot.handle_unified_webhook(request)
|
return await self.bot.handle_unified_webhook(request)
|
||||||
|
|
||||||
async def run_async(self):
|
async def run_async(self):
|
||||||
# 统一 webhook 模式下,不启动独立的 Quart 应用
|
_ws_mode = not self.config.get('enable-webhook', False)
|
||||||
# 保持运行但不启动独立端口
|
if _ws_mode:
|
||||||
|
await self.bot.connect()
|
||||||
|
else:
|
||||||
|
|
||||||
async def keep_alive():
|
async def keep_alive():
|
||||||
while True:
|
while True:
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
await keep_alive()
|
await keep_alive()
|
||||||
|
|
||||||
async def kill(self) -> bool:
|
async def kill(self) -> bool:
|
||||||
|
_ws_mode = not self.config.get('enable-webhook', False)
|
||||||
|
if _ws_mode:
|
||||||
|
await self.bot.disconnect()
|
||||||
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def unregister_listener(
|
async def unregister_listener(
|
||||||
|
|||||||
@@ -11,35 +11,71 @@ metadata:
|
|||||||
icon: wecombot.png
|
icon: wecombot.png
|
||||||
spec:
|
spec:
|
||||||
config:
|
config:
|
||||||
|
- name: BotId
|
||||||
|
label:
|
||||||
|
en_US: BotId
|
||||||
|
zh_Hans: 机器人ID (BotId)
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
default: ""
|
||||||
|
- name: robot_name
|
||||||
|
label:
|
||||||
|
en_US: Robot Name
|
||||||
|
zh_Hans: 机器人名称
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
default: ""
|
||||||
|
- name: enable-webhook
|
||||||
|
label:
|
||||||
|
en_US: Enable Webhook Mode
|
||||||
|
zh_Hans: 启用Webhook模式
|
||||||
|
description:
|
||||||
|
en_US: If enabled, the bot will use webhook mode to receive messages. Otherwise, it will use WS long connection mode
|
||||||
|
zh_Hans: 如果启用,机器人将使用 Webhook 模式接收消息。否则,将使用 WS 长连接模式
|
||||||
|
type: boolean
|
||||||
|
required: true
|
||||||
|
default: false
|
||||||
|
- name: Secret
|
||||||
|
label:
|
||||||
|
en_US: Secret
|
||||||
|
zh_Hans: 机器人密钥 (Secret)
|
||||||
|
description:
|
||||||
|
en_US: Required for WebSocket long connection mode
|
||||||
|
zh_Hans: 使用 WS 长连接模式时必填
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
default: ""
|
||||||
- name: Corpid
|
- name: Corpid
|
||||||
label:
|
label:
|
||||||
en_US: Corpid
|
en_US: Corpid
|
||||||
zh_Hans: 企业ID
|
zh_Hans: 企业ID
|
||||||
|
description:
|
||||||
|
en_US: Required for Webhook mode
|
||||||
|
zh_Hans: 使用 Webhook 模式时必填
|
||||||
type: string
|
type: string
|
||||||
required: true
|
required: false
|
||||||
default: ""
|
default: ""
|
||||||
- name: Token
|
- name: Token
|
||||||
label:
|
label:
|
||||||
en_US: Token
|
en_US: Token
|
||||||
zh_Hans: 令牌 (Token)
|
zh_Hans: 令牌 (Token)
|
||||||
|
description:
|
||||||
|
en_US: Required for Webhook mode
|
||||||
|
zh_Hans: 使用 Webhook 模式时必填
|
||||||
type: string
|
type: string
|
||||||
required: true
|
required: false
|
||||||
default: ""
|
default: ""
|
||||||
- name: EncodingAESKey
|
- name: EncodingAESKey
|
||||||
label:
|
label:
|
||||||
en_US: EncodingAESKey
|
en_US: EncodingAESKey
|
||||||
zh_Hans: 消息加解密密钥 (EncodingAESKey)
|
zh_Hans: 消息加解密密钥 (EncodingAESKey)
|
||||||
type: string
|
description:
|
||||||
required: true
|
en_US: Required for Webhook mode. Optional for WebSocket mode (used for file decryption)
|
||||||
default: ""
|
zh_Hans: 使用 Webhook 模式时必填。WebSocket 模式下可选(用于文件解密)
|
||||||
- name: BotId
|
|
||||||
label:
|
|
||||||
en_US: BotId
|
|
||||||
zh_Hans: 机器人ID
|
|
||||||
type: string
|
type: string
|
||||||
required: false
|
required: false
|
||||||
default: ""
|
default: ""
|
||||||
execution:
|
execution:
|
||||||
python:
|
python:
|
||||||
path: ./wecombot.py
|
path: ./wecombot.py
|
||||||
attr: WecomBotAdapter
|
attr: WecomBotAdapter
|
||||||
|
|||||||
@@ -81,22 +81,33 @@ class WecomEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
|||||||
return event.source_platform_object
|
return event.source_platform_object
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def target2yiri(event: WecomCSEvent):
|
async def target2yiri(event: WecomCSEvent, bot: WecomCSClient = None):
|
||||||
"""
|
"""
|
||||||
将 WecomEvent 转换为平台的 FriendMessage 对象。
|
将 WecomEvent 转换为平台的 FriendMessage 对象。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event (WecomEvent): 企业微信客服事件。
|
event (WecomEvent): 企业微信客服事件。
|
||||||
|
bot (WecomCSClient): 企业微信客服客户端,用于获取用户信息。
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
platform_events.FriendMessage: 转换后的 FriendMessage 对象。
|
platform_events.FriendMessage: 转换后的 FriendMessage 对象。
|
||||||
"""
|
"""
|
||||||
|
# Try to get customer nickname from WeChat API
|
||||||
|
nickname = str(event.user_id)
|
||||||
|
if bot and event.user_id:
|
||||||
|
try:
|
||||||
|
customer_info = await bot.get_customer_info(event.user_id)
|
||||||
|
if customer_info and customer_info.get('nickname'):
|
||||||
|
nickname = customer_info.get('nickname')
|
||||||
|
except Exception:
|
||||||
|
pass # Fall back to user_id as nickname
|
||||||
|
|
||||||
# 转换消息链
|
# 转换消息链
|
||||||
if event.type == 'text':
|
if event.type == 'text':
|
||||||
yiri_chain = await WecomMessageConverter.target2yiri(event.message, event.message_id)
|
yiri_chain = await WecomMessageConverter.target2yiri(event.message, event.message_id)
|
||||||
friend = platform_entities.Friend(
|
friend = platform_entities.Friend(
|
||||||
id=f'u{event.user_id}',
|
id=f'u{event.user_id}',
|
||||||
nickname=str(event.user_id),
|
nickname=nickname,
|
||||||
remark='',
|
remark='',
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -106,7 +117,7 @@ class WecomEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
|||||||
elif event.type == 'image':
|
elif event.type == 'image':
|
||||||
friend = platform_entities.Friend(
|
friend = platform_entities.Friend(
|
||||||
id=f'u{event.user_id}',
|
id=f'u{event.user_id}',
|
||||||
nickname=str(event.user_id),
|
nickname=nickname,
|
||||||
remark='',
|
remark='',
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -187,7 +198,7 @@ class WecomCSAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
|||||||
async def on_message(event: WecomCSEvent):
|
async def on_message(event: WecomCSEvent):
|
||||||
self.bot_account_id = event.receiver_id
|
self.bot_account_id = event.receiver_id
|
||||||
try:
|
try:
|
||||||
return await callback(await self.event_converter.target2yiri(event), self)
|
return await callback(await self.event_converter.target2yiri(event, self.bot), self)
|
||||||
except Exception:
|
except Exception:
|
||||||
await self.logger.error(f'Error in wecomcs callback: {traceback.format_exc()}')
|
await self.logger.error(f'Error in wecomcs callback: {traceback.format_exc()}')
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
|
from langbot.pkg.utils import httpclient
|
||||||
import uuid
|
import uuid
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
@@ -119,23 +121,23 @@ class WebhookPusher:
|
|||||||
dict | None: The response JSON if successful, None otherwise
|
dict | None: The response JSON if successful, None otherwise
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
session = httpclient.get_session()
|
||||||
async with session.post(
|
async with session.post(
|
||||||
url,
|
url,
|
||||||
json=payload,
|
json=payload,
|
||||||
headers={'Content-Type': 'application/json'},
|
headers={'Content-Type': 'application/json'},
|
||||||
timeout=aiohttp.ClientTimeout(total=15),
|
timeout=aiohttp.ClientTimeout(total=15),
|
||||||
) as response:
|
) as response:
|
||||||
if response.status >= 400:
|
if response.status >= 400:
|
||||||
self.logger.warning(f'Webhook {url} returned status {response.status}')
|
self.logger.warning(f'Webhook {url} returned status {response.status}')
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
self.logger.debug(f'Successfully pushed to webhook {url}')
|
||||||
|
try:
|
||||||
|
return await response.json()
|
||||||
|
except Exception as json_error:
|
||||||
|
self.logger.debug(f'Failed to parse JSON response from webhook {url}: {json_error}')
|
||||||
return None
|
return None
|
||||||
else:
|
|
||||||
self.logger.debug(f'Successfully pushed to webhook {url}')
|
|
||||||
try:
|
|
||||||
return await response.json()
|
|
||||||
except Exception as json_error:
|
|
||||||
self.logger.debug(f'Failed to parse JSON response from webhook {url}: {json_error}')
|
|
||||||
return None
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
self.logger.warning(f'Timeout pushing to webhook {url}')
|
self.logger.warning(f'Timeout pushing to webhook {url}')
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import typing
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import httpx
|
import httpx
|
||||||
import traceback
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from async_lru import alru_cache
|
from async_lru import alru_cache
|
||||||
from langbot_plugin.api.entities.builtin.pipeline.query import provider_session
|
from langbot_plugin.api.entities.builtin.pipeline.query import provider_session
|
||||||
@@ -102,12 +101,6 @@ class PluginRuntimeConnector:
|
|||||||
self.handler_task = asyncio.create_task(self.handler.run())
|
self.handler_task = asyncio.create_task(self.handler.run())
|
||||||
_ = await self.handler.ping()
|
_ = await self.handler.ping()
|
||||||
self.ap.logger.info('Connected to plugin runtime.')
|
self.ap.logger.info('Connected to plugin runtime.')
|
||||||
# Sync polymorphic component instances after connection
|
|
||||||
try:
|
|
||||||
await self.sync_polymorphic_component_instances()
|
|
||||||
except Exception as e:
|
|
||||||
traceback.print_exc()
|
|
||||||
self.ap.logger.error(f'Failed to sync polymorphic component instances: {e}')
|
|
||||||
await self.handler_task
|
await self.handler_task
|
||||||
|
|
||||||
task: asyncio.Task | None = None
|
task: asyncio.Task | None = None
|
||||||
@@ -463,30 +456,18 @@ class PluginRuntimeConnector:
|
|||||||
|
|
||||||
yield cmd_ret
|
yield cmd_ret
|
||||||
|
|
||||||
# KnowledgeRetriever methods
|
|
||||||
async def list_knowledge_retrievers(self, bound_plugins: list[str] | None = None) -> list[dict[str, Any]]:
|
|
||||||
"""List all available KnowledgeRetriever components."""
|
|
||||||
if not self.is_enable_plugin:
|
|
||||||
return []
|
|
||||||
|
|
||||||
retrievers_data = await self.handler.list_knowledge_retrievers(include_plugins=bound_plugins)
|
|
||||||
return retrievers_data
|
|
||||||
|
|
||||||
async def retrieve_knowledge(
|
async def retrieve_knowledge(
|
||||||
self,
|
self,
|
||||||
plugin_author: str,
|
plugin_author: str,
|
||||||
plugin_name: str,
|
plugin_name: str,
|
||||||
retriever_name: str,
|
retriever_name: str,
|
||||||
instance_id: str,
|
|
||||||
retrieval_context: dict[str, Any],
|
retrieval_context: dict[str, Any],
|
||||||
) -> list[dict[str, Any]]:
|
) -> dict[str, Any]:
|
||||||
"""Retrieve knowledge using a KnowledgeRetriever instance."""
|
"""Retrieve knowledge using a KnowledgeEngine instance."""
|
||||||
if not self.is_enable_plugin:
|
if not self.is_enable_plugin:
|
||||||
return []
|
return {'results': []}
|
||||||
|
|
||||||
return await self.handler.retrieve_knowledge(
|
return await self.handler.retrieve_knowledge(plugin_author, plugin_name, retriever_name, retrieval_context)
|
||||||
plugin_author, plugin_name, retriever_name, instance_id, retrieval_context
|
|
||||||
)
|
|
||||||
|
|
||||||
def dispose(self):
|
def dispose(self):
|
||||||
# No need to consider the shutdown on Windows
|
# No need to consider the shutdown on Windows
|
||||||
@@ -500,41 +481,84 @@ class PluginRuntimeConnector:
|
|||||||
self.heartbeat_task.cancel()
|
self.heartbeat_task.cancel()
|
||||||
self.heartbeat_task = None
|
self.heartbeat_task = None
|
||||||
|
|
||||||
async def sync_polymorphic_component_instances(self) -> dict[str, Any]:
|
@staticmethod
|
||||||
"""Sync polymorphic component instances with runtime.
|
def _parse_plugin_id(plugin_id: str) -> tuple[str, str]:
|
||||||
|
"""Parse a plugin ID string into (author, name).
|
||||||
|
|
||||||
This collects all external knowledge bases from database and sends to runtime
|
Args:
|
||||||
to ensure instance integrity across restarts.
|
plugin_id: Plugin ID in 'author/name' format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (plugin_author, plugin_name).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If plugin_id is not in the expected 'author/name' format.
|
||||||
|
"""
|
||||||
|
if '/' not in plugin_id:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid plugin_id format: '{plugin_id}'. Expected 'author/name' format (e.g. 'langbot/rag-engine')."
|
||||||
|
)
|
||||||
|
return plugin_id.split('/', 1)
|
||||||
|
|
||||||
|
async def call_rag_ingest(self, plugin_id: str, context_data: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Call plugin to ingest document.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_id: Target plugin ID (author/name).
|
||||||
|
context_data: IngestionContext data.
|
||||||
|
"""
|
||||||
|
plugin_author, plugin_name = self._parse_plugin_id(plugin_id)
|
||||||
|
return await self.handler.rag_ingest_document(plugin_author, plugin_name, context_data)
|
||||||
|
|
||||||
|
async def call_rag_delete_document(self, plugin_id: str, document_id: str, kb_id: str) -> bool:
|
||||||
|
plugin_author, plugin_name = self._parse_plugin_id(plugin_id)
|
||||||
|
return await self.handler.rag_delete_document(plugin_author, plugin_name, document_id, kb_id)
|
||||||
|
|
||||||
|
async def get_rag_creation_schema(self, plugin_id: str) -> dict[str, Any]:
|
||||||
|
plugin_author, plugin_name = self._parse_plugin_id(plugin_id)
|
||||||
|
return await self.handler.get_rag_creation_schema(plugin_author, plugin_name)
|
||||||
|
|
||||||
|
async def get_rag_retrieval_schema(self, plugin_id: str) -> dict[str, Any]:
|
||||||
|
plugin_author, plugin_name = self._parse_plugin_id(plugin_id)
|
||||||
|
return await self.handler.get_rag_retrieval_schema(plugin_author, plugin_name)
|
||||||
|
|
||||||
|
async def rag_on_kb_create(self, plugin_id: str, kb_id: str, config: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Notify plugin about KB creation."""
|
||||||
|
plugin_author, plugin_name = self._parse_plugin_id(plugin_id)
|
||||||
|
return await self.handler.rag_on_kb_create(plugin_author, plugin_name, kb_id, config)
|
||||||
|
|
||||||
|
async def rag_on_kb_delete(self, plugin_id: str, kb_id: str) -> dict[str, Any]:
|
||||||
|
"""Notify plugin about KB deletion."""
|
||||||
|
plugin_author, plugin_name = self._parse_plugin_id(plugin_id)
|
||||||
|
return await self.handler.rag_on_kb_delete(plugin_author, plugin_name, kb_id)
|
||||||
|
|
||||||
|
async def call_rag_retrieve(self, plugin_id: str, retrieval_context: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Call plugin to retrieve knowledge.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_id: Target plugin ID (author/name).
|
||||||
|
retrieval_context: RetrievalContext data.
|
||||||
|
"""
|
||||||
|
plugin_author, plugin_name = self._parse_plugin_id(plugin_id)
|
||||||
|
return await self.handler.retrieve_knowledge(plugin_author, plugin_name, '', retrieval_context)
|
||||||
|
|
||||||
|
async def list_knowledge_engines(self) -> list[dict[str, Any]]:
|
||||||
|
"""List all available Knowledge Engines from plugins.
|
||||||
|
|
||||||
|
Returns a list of Knowledge Engines with their capabilities and configuration schemas.
|
||||||
"""
|
"""
|
||||||
if not self.is_enable_plugin:
|
if not self.is_enable_plugin:
|
||||||
return {}
|
return []
|
||||||
|
|
||||||
# ===== external knowledge bases =====
|
return await self.handler.list_knowledge_engines()
|
||||||
|
|
||||||
external_kbs = await self.ap.external_kb_service.get_external_knowledge_bases()
|
async def list_parsers(self) -> list[dict[str, Any]]:
|
||||||
|
"""List all available parsers from plugins."""
|
||||||
|
if not self.is_enable_plugin:
|
||||||
|
return []
|
||||||
|
return await self.handler.list_parsers()
|
||||||
|
|
||||||
# Build required_instances list
|
async def call_parser(self, plugin_id: str, context_data: dict[str, Any], file_bytes: bytes) -> dict[str, Any]:
|
||||||
required_instances = []
|
"""Call plugin to parse a document."""
|
||||||
for kb in external_kbs:
|
plugin_author, plugin_name = self._parse_plugin_id(plugin_id)
|
||||||
required_instances.append(
|
return await self.handler.parse_document(plugin_author, plugin_name, context_data, file_bytes)
|
||||||
{
|
|
||||||
'instance_id': kb['uuid'],
|
|
||||||
'plugin_author': kb['plugin_author'],
|
|
||||||
'plugin_name': kb['plugin_name'],
|
|
||||||
'component_kind': 'KnowledgeRetriever',
|
|
||||||
'component_name': kb['retriever_name'],
|
|
||||||
'config': kb['retriever_config'],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
self.ap.logger.info(f'Syncing {len(required_instances)} polymorphic component instances to runtime')
|
|
||||||
|
|
||||||
# Send to runtime
|
|
||||||
sync_result = await self.handler.sync_polymorphic_component_instances(required_instances)
|
|
||||||
|
|
||||||
self.ap.logger.info(
|
|
||||||
f'Sync complete: {len(sync_result.get("success_instances", []))} succeeded, '
|
|
||||||
f'{len(sync_result.get("failed_instances", []))} failed'
|
|
||||||
)
|
|
||||||
|
|
||||||
return sync_result
|
|
||||||
|
|||||||
@@ -26,6 +26,20 @@ from ..core import app
|
|||||||
from ..utils import constants
|
from ..utils import constants
|
||||||
|
|
||||||
|
|
||||||
|
def _make_rag_error_response(error: Exception, error_type: str, **extra_context) -> handler.ActionResponse:
|
||||||
|
"""Create a clean error response for RAG operations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error: The caught exception.
|
||||||
|
error_type: A category string like 'EmbeddingError', 'VectorStoreError'.
|
||||||
|
**extra_context: Additional context fields for the error message.
|
||||||
|
"""
|
||||||
|
context_parts = [f'{k}={v}' for k, v in extra_context.items()]
|
||||||
|
context_str = f' [{", ".join(context_parts)}]' if context_parts else ''
|
||||||
|
message = f'[{error_type}/{type(error).__name__}]{context_str} {str(error)}'
|
||||||
|
return handler.ActionResponse.error(message=message)
|
||||||
|
|
||||||
|
|
||||||
class RuntimeConnectionHandler(handler.Handler):
|
class RuntimeConnectionHandler(handler.Handler):
|
||||||
"""Runtime connection handler"""
|
"""Runtime connection handler"""
|
||||||
|
|
||||||
@@ -300,11 +314,11 @@ class RuntimeConnectionHandler(handler.Handler):
|
|||||||
|
|
||||||
@self.action(PluginToRuntimeAction.GET_LLM_MODELS)
|
@self.action(PluginToRuntimeAction.GET_LLM_MODELS)
|
||||||
async def get_llm_models(data: dict[str, Any]) -> handler.ActionResponse:
|
async def get_llm_models(data: dict[str, Any]) -> handler.ActionResponse:
|
||||||
"""Get llm models"""
|
"""Get llm models, returns list of UUID strings"""
|
||||||
llm_models = await self.ap.llm_model_service.get_llm_models(include_secret=False)
|
llm_models = await self.ap.llm_model_service.get_llm_models(include_secret=False)
|
||||||
return handler.ActionResponse.success(
|
return handler.ActionResponse.success(
|
||||||
data={
|
data={
|
||||||
'llm_models': llm_models,
|
'llm_models': [m['uuid'] for m in llm_models],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -323,7 +337,14 @@ class RuntimeConnectionHandler(handler.Handler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
messages_obj = [provider_message.Message.model_validate(message) for message in messages]
|
messages_obj = [provider_message.Message.model_validate(message) for message in messages]
|
||||||
funcs_obj = [resource_tool.LLMTool.model_validate(func) for func in funcs]
|
|
||||||
|
# The func field is excluded during model_dump() in plugin side (marked as exclude=True),
|
||||||
|
# but it's a required field for LLMTool validation. We need to provide a placeholder
|
||||||
|
# function when reconstructing the LLMTool objects from serialized data.
|
||||||
|
async def _placeholder_func(**kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
funcs_obj = [resource_tool.LLMTool.model_validate({**func, 'func': _placeholder_func}) for func in funcs]
|
||||||
|
|
||||||
result = await llm_model.provider.invoke_llm(
|
result = await llm_model.provider.invoke_llm(
|
||||||
query=None,
|
query=None,
|
||||||
@@ -439,7 +460,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@self.action(RuntimeToLangBotAction.GET_CONFIG_FILE)
|
@self.action(PluginToRuntimeAction.GET_CONFIG_FILE)
|
||||||
async def get_config_file(data: dict[str, Any]) -> handler.ActionResponse:
|
async def get_config_file(data: dict[str, Any]) -> handler.ActionResponse:
|
||||||
"""Get a config file by file key"""
|
"""Get a config file by file key"""
|
||||||
file_key = data['file_key']
|
file_key = data['file_key']
|
||||||
@@ -458,6 +479,282 @@ class RuntimeConnectionHandler(handler.Handler):
|
|||||||
message=f'Failed to load config file {file_key}: {e}',
|
message=f'Failed to load config file {file_key}: {e}',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ================= RAG Capability Handlers =================
|
||||||
|
|
||||||
|
@self.action(PluginToRuntimeAction.INVOKE_EMBEDDING)
|
||||||
|
async def invoke_embedding(data: dict[str, Any]) -> handler.ActionResponse:
|
||||||
|
embedding_model_uuid = data['embedding_model_uuid']
|
||||||
|
texts = data['texts']
|
||||||
|
|
||||||
|
embedding_model = await self.ap.model_mgr.get_embedding_model_by_uuid(embedding_model_uuid)
|
||||||
|
if embedding_model is None:
|
||||||
|
return handler.ActionResponse.error(
|
||||||
|
message=f'Embedding model with embedding_model_uuid {embedding_model_uuid} not found',
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
vectors = await embedding_model.provider.invoke_embedding(embedding_model, texts)
|
||||||
|
return handler.ActionResponse.success(data={'vectors': vectors})
|
||||||
|
except Exception as e:
|
||||||
|
return _make_rag_error_response(e, 'EmbeddingError', embedding_model_uuid=embedding_model_uuid)
|
||||||
|
|
||||||
|
@self.action(PluginToRuntimeAction.VECTOR_UPSERT)
|
||||||
|
async def vector_upsert(data: dict[str, Any]) -> handler.ActionResponse:
|
||||||
|
collection_id = data['collection_id']
|
||||||
|
vectors = data['vectors']
|
||||||
|
ids = data['ids']
|
||||||
|
metadata = data.get('metadata')
|
||||||
|
documents = data.get('documents')
|
||||||
|
if len(vectors) != len(ids):
|
||||||
|
return handler.ActionResponse.error(message='vectors and ids must have same length')
|
||||||
|
if metadata and len(metadata) != len(vectors):
|
||||||
|
return handler.ActionResponse.error(message='metadata must match vectors length')
|
||||||
|
if documents and len(documents) != len(vectors):
|
||||||
|
return handler.ActionResponse.error(message='documents must match vectors length')
|
||||||
|
try:
|
||||||
|
await self.ap.rag_runtime_service.vector_upsert(
|
||||||
|
collection_id,
|
||||||
|
vectors,
|
||||||
|
ids,
|
||||||
|
metadata,
|
||||||
|
documents,
|
||||||
|
)
|
||||||
|
return handler.ActionResponse.success(data={})
|
||||||
|
except Exception as e:
|
||||||
|
return _make_rag_error_response(e, 'VectorStoreError', collection_id=collection_id)
|
||||||
|
|
||||||
|
@self.action(PluginToRuntimeAction.VECTOR_SEARCH)
|
||||||
|
async def vector_search(data: dict[str, Any]) -> handler.ActionResponse:
|
||||||
|
collection_id = data['collection_id']
|
||||||
|
query_vector = data['query_vector']
|
||||||
|
top_k = data['top_k']
|
||||||
|
filters = data.get('filters')
|
||||||
|
search_type = data.get('search_type', 'vector')
|
||||||
|
query_text = data.get('query_text', '')
|
||||||
|
vector_weight = data.get('vector_weight')
|
||||||
|
try:
|
||||||
|
results = await self.ap.rag_runtime_service.vector_search(
|
||||||
|
collection_id,
|
||||||
|
query_vector,
|
||||||
|
top_k,
|
||||||
|
filters,
|
||||||
|
search_type,
|
||||||
|
query_text,
|
||||||
|
vector_weight=vector_weight,
|
||||||
|
)
|
||||||
|
return handler.ActionResponse.success(data={'results': results})
|
||||||
|
except Exception as e:
|
||||||
|
return _make_rag_error_response(e, 'VectorStoreError', collection_id=collection_id)
|
||||||
|
|
||||||
|
@self.action(PluginToRuntimeAction.VECTOR_DELETE)
|
||||||
|
async def vector_delete(data: dict[str, Any]) -> handler.ActionResponse:
|
||||||
|
collection_id = data['collection_id']
|
||||||
|
file_ids = data.get('file_ids')
|
||||||
|
filters = data.get('filters')
|
||||||
|
try:
|
||||||
|
count = await self.ap.rag_runtime_service.vector_delete(collection_id, file_ids, filters)
|
||||||
|
return handler.ActionResponse.success(data={'count': count})
|
||||||
|
except Exception as e:
|
||||||
|
return _make_rag_error_response(e, 'VectorStoreError', collection_id=collection_id)
|
||||||
|
|
||||||
|
@self.action(PluginToRuntimeAction.VECTOR_LIST)
|
||||||
|
async def vector_list(data: dict[str, Any]) -> handler.ActionResponse:
|
||||||
|
collection_id = data['collection_id']
|
||||||
|
filters = data.get('filters')
|
||||||
|
limit = data.get('limit', 20)
|
||||||
|
offset = data.get('offset', 0)
|
||||||
|
try:
|
||||||
|
items, total = await self.ap.rag_runtime_service.vector_list(collection_id, filters, limit, offset)
|
||||||
|
return handler.ActionResponse.success(data={'items': items, 'total': total})
|
||||||
|
except Exception as e:
|
||||||
|
return _make_rag_error_response(e, 'VectorStoreError', collection_id=collection_id)
|
||||||
|
|
||||||
|
@self.action(PluginToRuntimeAction.GET_KNOWLEDEGE_FILE_STREAM)
|
||||||
|
async def get_knowledge_file_stream(data: dict[str, Any]) -> handler.ActionResponse:
|
||||||
|
storage_path = data['storage_path']
|
||||||
|
try:
|
||||||
|
content_bytes = await self.ap.rag_runtime_service.get_file_stream(storage_path)
|
||||||
|
file_key = await self.send_file(content_bytes, '')
|
||||||
|
return handler.ActionResponse.success(data={'file_key': file_key})
|
||||||
|
except Exception as e:
|
||||||
|
return _make_rag_error_response(e, 'FileServiceError', storage_path=storage_path)
|
||||||
|
|
||||||
|
@self.action(PluginToRuntimeAction.LIST_PARSERS)
|
||||||
|
async def list_parsers(data: dict[str, Any]) -> handler.ActionResponse:
|
||||||
|
"""Plugin requests host to list available parser plugins."""
|
||||||
|
mime_type = data.get('mime_type')
|
||||||
|
try:
|
||||||
|
parsers = await self.ap.knowledge_service.list_parsers(mime_type)
|
||||||
|
return handler.ActionResponse.success(data={'parsers': parsers})
|
||||||
|
except Exception as e:
|
||||||
|
return _make_rag_error_response(e, 'ParserDiscoveryError', mime_type=mime_type)
|
||||||
|
|
||||||
|
@self.action(PluginToRuntimeAction.INVOKE_PARSER)
|
||||||
|
async def invoke_parser(data: dict[str, Any]) -> handler.ActionResponse:
|
||||||
|
"""Plugin requests host to invoke a parser plugin."""
|
||||||
|
plugin_author = data['plugin_author']
|
||||||
|
plugin_name = data['plugin_name']
|
||||||
|
storage_path = data['storage_path']
|
||||||
|
mime_type = data.get('mime_type', 'application/octet-stream')
|
||||||
|
filename = data.get('filename', '')
|
||||||
|
metadata = data.get('metadata', {})
|
||||||
|
try:
|
||||||
|
# Read file from storage
|
||||||
|
file_bytes = await self.ap.rag_runtime_service.get_file_stream(storage_path)
|
||||||
|
context_data = {
|
||||||
|
'mime_type': mime_type,
|
||||||
|
'filename': filename,
|
||||||
|
'metadata': metadata,
|
||||||
|
}
|
||||||
|
result = await self.ap.plugin_connector.call_parser(
|
||||||
|
f'{plugin_author}/{plugin_name}', context_data, file_bytes
|
||||||
|
)
|
||||||
|
return handler.ActionResponse.success(data=result)
|
||||||
|
except Exception as e:
|
||||||
|
return _make_rag_error_response(e, 'ParserError')
|
||||||
|
|
||||||
|
# ================= Knowledge Base Query APIs =================
|
||||||
|
|
||||||
|
@self.action(PluginToRuntimeAction.LIST_KNOWLEDGE_BASES)
|
||||||
|
async def list_knowledge_bases(data: dict[str, Any]) -> handler.ActionResponse:
|
||||||
|
"""List all knowledge bases available in the LangBot instance (unrestricted)."""
|
||||||
|
knowledge_bases = []
|
||||||
|
for kb_uuid, kb in self.ap.rag_mgr.knowledge_bases.items():
|
||||||
|
knowledge_bases.append(
|
||||||
|
{
|
||||||
|
'uuid': kb.get_uuid(),
|
||||||
|
'name': kb.get_name(),
|
||||||
|
'description': kb.knowledge_base_entity.description or '',
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return handler.ActionResponse.success(data={'knowledge_bases': knowledge_bases})
|
||||||
|
|
||||||
|
@self.action(PluginToRuntimeAction.RETRIEVE_KNOWLEDGE)
|
||||||
|
async def retrieve_knowledge(data: dict[str, Any]) -> handler.ActionResponse:
|
||||||
|
"""Retrieve documents from any knowledge base (unrestricted)."""
|
||||||
|
kb_id = data['kb_id']
|
||||||
|
query_text = data['query_text']
|
||||||
|
top_k = data.get('top_k', 5)
|
||||||
|
filters = data.get('filters', {})
|
||||||
|
|
||||||
|
kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_id)
|
||||||
|
if not kb:
|
||||||
|
return handler.ActionResponse.error(
|
||||||
|
message=f'Knowledge base {kb_id} not found',
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
entries = await kb.retrieve(
|
||||||
|
query_text,
|
||||||
|
settings={
|
||||||
|
'top_k': top_k,
|
||||||
|
'filters': filters,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
results = [entry.model_dump(mode='json') for entry in entries]
|
||||||
|
return handler.ActionResponse.success(data={'results': results})
|
||||||
|
except Exception as e:
|
||||||
|
return _make_rag_error_response(e, 'RetrievalError', kb_id=kb_id)
|
||||||
|
|
||||||
|
@self.action(PluginToRuntimeAction.LIST_PIPELINE_KNOWLEDGE_BASES)
|
||||||
|
async def list_pipeline_knowledge_bases(data: dict[str, Any]) -> handler.ActionResponse:
|
||||||
|
"""List knowledge bases configured for the current query's pipeline."""
|
||||||
|
query_id = data['query_id']
|
||||||
|
|
||||||
|
if query_id not in self.ap.query_pool.cached_queries:
|
||||||
|
return handler.ActionResponse.error(
|
||||||
|
message=f'Query with query_id {query_id} not found',
|
||||||
|
)
|
||||||
|
|
||||||
|
query = self.ap.query_pool.cached_queries[query_id]
|
||||||
|
|
||||||
|
kb_uuids = []
|
||||||
|
if query.pipeline_config:
|
||||||
|
local_agent_config = query.pipeline_config.get('ai', {}).get('local-agent', {})
|
||||||
|
kb_uuids = local_agent_config.get('knowledge-bases', [])
|
||||||
|
# Backward compatibility
|
||||||
|
if not kb_uuids:
|
||||||
|
old_kb_uuid = local_agent_config.get('knowledge-base', '')
|
||||||
|
if old_kb_uuid and old_kb_uuid != '__none__':
|
||||||
|
kb_uuids = [old_kb_uuid]
|
||||||
|
|
||||||
|
knowledge_bases = []
|
||||||
|
for kb_uuid in kb_uuids:
|
||||||
|
kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
|
||||||
|
if kb:
|
||||||
|
knowledge_bases.append(
|
||||||
|
{
|
||||||
|
'uuid': kb.get_uuid(),
|
||||||
|
'name': kb.get_name(),
|
||||||
|
'description': kb.knowledge_base_entity.description or '',
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return handler.ActionResponse.success(data={'knowledge_bases': knowledge_bases})
|
||||||
|
|
||||||
|
@self.action(PluginToRuntimeAction.RETRIEVE_KNOWLEDGE_BASE)
|
||||||
|
async def retrieve_knowledge_base(data: dict[str, Any]) -> handler.ActionResponse:
|
||||||
|
"""Retrieve documents from a knowledge base within the pipeline's scope."""
|
||||||
|
query_id = data['query_id']
|
||||||
|
kb_id = data['kb_id']
|
||||||
|
query_text = data['query_text']
|
||||||
|
top_k = data.get('top_k', 5)
|
||||||
|
filters = data.get('filters', {})
|
||||||
|
|
||||||
|
if query_id not in self.ap.query_pool.cached_queries:
|
||||||
|
return handler.ActionResponse.error(
|
||||||
|
message=f'Query with query_id {query_id} not found',
|
||||||
|
)
|
||||||
|
|
||||||
|
query = self.ap.query_pool.cached_queries[query_id]
|
||||||
|
|
||||||
|
# Validate kb_id is in pipeline's allowed list
|
||||||
|
allowed_kb_uuids = []
|
||||||
|
if query.pipeline_config:
|
||||||
|
local_agent_config = query.pipeline_config.get('ai', {}).get('local-agent', {})
|
||||||
|
allowed_kb_uuids = local_agent_config.get('knowledge-bases', [])
|
||||||
|
if not allowed_kb_uuids:
|
||||||
|
old_kb_uuid = local_agent_config.get('knowledge-base', '')
|
||||||
|
if old_kb_uuid and old_kb_uuid != '__none__':
|
||||||
|
allowed_kb_uuids = [old_kb_uuid]
|
||||||
|
|
||||||
|
if kb_id not in allowed_kb_uuids:
|
||||||
|
return handler.ActionResponse.error(
|
||||||
|
message=f'Knowledge base {kb_id} is not configured for this pipeline',
|
||||||
|
)
|
||||||
|
|
||||||
|
kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_id)
|
||||||
|
if not kb:
|
||||||
|
return handler.ActionResponse.error(
|
||||||
|
message=f'Knowledge base {kb_id} not found',
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
session_name = f'{query.session.launcher_type.value}_{query.session.launcher_id}'
|
||||||
|
entries = await kb.retrieve(
|
||||||
|
query_text,
|
||||||
|
settings={
|
||||||
|
'top_k': top_k,
|
||||||
|
'filters': filters,
|
||||||
|
'session_name': session_name,
|
||||||
|
'bot_uuid': query.bot_uuid or '',
|
||||||
|
'sender_id': str(query.sender_id),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
results = [entry.model_dump(mode='json') for entry in entries]
|
||||||
|
return handler.ActionResponse.success(data={'results': results})
|
||||||
|
except Exception as e:
|
||||||
|
return _make_rag_error_response(e, 'RetrievalError', kb_id=kb_id)
|
||||||
|
|
||||||
|
@self.action(CommonAction.PING)
|
||||||
|
async def ping(data: dict[str, Any]) -> handler.ActionResponse:
|
||||||
|
"""Ping"""
|
||||||
|
return handler.ActionResponse.success(
|
||||||
|
data={
|
||||||
|
'pong': 'pong',
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
async def ping(self) -> dict[str, Any]:
|
async def ping(self) -> dict[str, Any]:
|
||||||
"""Ping the runtime"""
|
"""Ping the runtime"""
|
||||||
return await self.call_action(
|
return await self.call_action(
|
||||||
@@ -717,26 +1014,13 @@ class RuntimeConnectionHandler(handler.Handler):
|
|||||||
async for ret in gen:
|
async for ret in gen:
|
||||||
yield ret
|
yield ret
|
||||||
|
|
||||||
# KnowledgeRetriever methods
|
|
||||||
async def list_knowledge_retrievers(self, include_plugins: list[str] | None = None) -> list[dict[str, Any]]:
|
|
||||||
"""List knowledge retrievers"""
|
|
||||||
result = await self.call_action(
|
|
||||||
LangBotToRuntimeAction.LIST_KNOWLEDGE_RETRIEVERS,
|
|
||||||
{
|
|
||||||
'include_plugins': include_plugins,
|
|
||||||
},
|
|
||||||
timeout=10,
|
|
||||||
)
|
|
||||||
return result['retrievers']
|
|
||||||
|
|
||||||
async def retrieve_knowledge(
|
async def retrieve_knowledge(
|
||||||
self,
|
self,
|
||||||
plugin_author: str,
|
plugin_author: str,
|
||||||
plugin_name: str,
|
plugin_name: str,
|
||||||
retriever_name: str,
|
retriever_name: str,
|
||||||
instance_id: str,
|
|
||||||
retrieval_context: dict[str, Any],
|
retrieval_context: dict[str, Any],
|
||||||
) -> list[dict[str, Any]]:
|
) -> dict[str, Any]:
|
||||||
"""Retrieve knowledge"""
|
"""Retrieve knowledge"""
|
||||||
result = await self.call_action(
|
result = await self.call_action(
|
||||||
LangBotToRuntimeAction.RETRIEVE_KNOWLEDGE,
|
LangBotToRuntimeAction.RETRIEVE_KNOWLEDGE,
|
||||||
@@ -744,22 +1028,10 @@ class RuntimeConnectionHandler(handler.Handler):
|
|||||||
'plugin_author': plugin_author,
|
'plugin_author': plugin_author,
|
||||||
'plugin_name': plugin_name,
|
'plugin_name': plugin_name,
|
||||||
'retriever_name': retriever_name,
|
'retriever_name': retriever_name,
|
||||||
'instance_id': instance_id,
|
|
||||||
'retrieval_context': retrieval_context,
|
'retrieval_context': retrieval_context,
|
||||||
},
|
},
|
||||||
timeout=30,
|
timeout=30,
|
||||||
)
|
)
|
||||||
return result['retrieval_results']
|
|
||||||
|
|
||||||
async def sync_polymorphic_component_instances(self, required_instances: list[dict[str, Any]]) -> dict[str, Any]:
|
|
||||||
"""Sync polymorphic component instances with runtime"""
|
|
||||||
result = await self.call_action(
|
|
||||||
LangBotToRuntimeAction.SYNC_POLYMORPHIC_COMPONENT_INSTANCES,
|
|
||||||
{
|
|
||||||
'required_instances': required_instances,
|
|
||||||
},
|
|
||||||
timeout=30,
|
|
||||||
)
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def get_debug_info(self) -> dict[str, Any]:
|
async def get_debug_info(self) -> dict[str, Any]:
|
||||||
@@ -770,3 +1042,91 @@ class RuntimeConnectionHandler(handler.Handler):
|
|||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
# ================= RAG Capability Callers (LangBot -> Runtime) =================
|
||||||
|
|
||||||
|
async def rag_ingest_document(
|
||||||
|
self, plugin_author: str, plugin_name: str, context_data: dict[str, Any]
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Send INGEST_DOCUMENT action to runtime."""
|
||||||
|
result = await self.call_action(
|
||||||
|
LangBotToRuntimeAction.RAG_INGEST_DOCUMENT,
|
||||||
|
{'plugin_author': plugin_author, 'plugin_name': plugin_name, 'context': context_data},
|
||||||
|
timeout=1200, # Ingestion can be slow for large documents
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def rag_delete_document(self, plugin_author: str, plugin_name: str, document_id: str, kb_id: str) -> bool:
|
||||||
|
result = await self.call_action(
|
||||||
|
LangBotToRuntimeAction.RAG_DELETE_DOCUMENT,
|
||||||
|
{'plugin_author': plugin_author, 'plugin_name': plugin_name, 'document_id': document_id, 'kb_id': kb_id},
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
return result.get('success', False)
|
||||||
|
|
||||||
|
async def rag_on_kb_create(
|
||||||
|
self, plugin_author: str, plugin_name: str, kb_id: str, config: dict[str, Any]
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Notify plugin about KB creation."""
|
||||||
|
result = await self.call_action(
|
||||||
|
LangBotToRuntimeAction.RAG_ON_KB_CREATE,
|
||||||
|
{'plugin_author': plugin_author, 'plugin_name': plugin_name, 'kb_id': kb_id, 'config': config},
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def rag_on_kb_delete(self, plugin_author: str, plugin_name: str, kb_id: str) -> dict[str, Any]:
|
||||||
|
"""Notify plugin about KB deletion."""
|
||||||
|
result = await self.call_action(
|
||||||
|
LangBotToRuntimeAction.RAG_ON_KB_DELETE,
|
||||||
|
{'plugin_author': plugin_author, 'plugin_name': plugin_name, 'kb_id': kb_id},
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def get_rag_creation_schema(self, plugin_author: str, plugin_name: str) -> dict[str, Any]:
|
||||||
|
return await self.call_action(
|
||||||
|
LangBotToRuntimeAction.GET_RAG_CREATION_SETTINGS_SCHEMA,
|
||||||
|
{'plugin_author': plugin_author, 'plugin_name': plugin_name},
|
||||||
|
timeout=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_rag_retrieval_schema(self, plugin_author: str, plugin_name: str) -> dict[str, Any]:
|
||||||
|
return await self.call_action(
|
||||||
|
LangBotToRuntimeAction.GET_RAG_RETRIEVAL_SETTINGS_SCHEMA,
|
||||||
|
{'plugin_author': plugin_author, 'plugin_name': plugin_name},
|
||||||
|
timeout=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def list_knowledge_engines(self) -> list[dict[str, Any]]:
|
||||||
|
"""List all available Knowledge Engines from plugins."""
|
||||||
|
result = await self.call_action(LangBotToRuntimeAction.LIST_KNOWLEDGE_ENGINES, {}, timeout=60)
|
||||||
|
return result.get('engines', [])
|
||||||
|
|
||||||
|
# ================= Parser Capability Callers (LangBot -> Runtime) =================
|
||||||
|
|
||||||
|
async def list_parsers(self) -> list[dict[str, Any]]:
|
||||||
|
"""List all available parsers from plugins."""
|
||||||
|
result = await self.call_action(LangBotToRuntimeAction.LIST_PARSERS, {}, timeout=60)
|
||||||
|
return result.get('parsers', [])
|
||||||
|
|
||||||
|
async def parse_document(
|
||||||
|
self, plugin_author: str, plugin_name: str, context_data: dict[str, Any], file_bytes: bytes
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Send PARSE_DOCUMENT action to runtime.
|
||||||
|
|
||||||
|
Sends file content via chunked FILE_CHUNK transfer, then invokes
|
||||||
|
the PARSE_DOCUMENT action with a file_key reference.
|
||||||
|
"""
|
||||||
|
# Send file to runtime via chunked transfer
|
||||||
|
file_key = await self.send_file(file_bytes, '')
|
||||||
|
|
||||||
|
# Include file_key in context_data for the runtime to read
|
||||||
|
context_data['file_key'] = file_key
|
||||||
|
|
||||||
|
result = await self.call_action(
|
||||||
|
LangBotToRuntimeAction.PARSE_DOCUMENT,
|
||||||
|
{'plugin_author': plugin_author, 'plugin_name': plugin_name, 'context': context_data},
|
||||||
|
timeout=300,
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|||||||
@@ -288,10 +288,10 @@ class AnthropicMessages(requester.ProviderAPIRequester):
|
|||||||
think_started = False
|
think_started = False
|
||||||
think_ended = False
|
think_ended = False
|
||||||
finish_reason = False
|
finish_reason = False
|
||||||
content = ''
|
|
||||||
tool_name = ''
|
tool_name = ''
|
||||||
tool_id = ''
|
tool_id = ''
|
||||||
async for chunk in await self.client.messages.create(**args):
|
async for chunk in await self.client.messages.create(**args):
|
||||||
|
content = ''
|
||||||
tool_call = {'id': None, 'function': {'name': None, 'arguments': None}, 'type': 'function'}
|
tool_call = {'id': None, 'function': {'name': None, 'arguments': None}, 'type': 'function'}
|
||||||
if isinstance(
|
if isinstance(
|
||||||
chunk, anthropic.types.raw_content_block_start_event.RawContentBlockStartEvent
|
chunk, anthropic.types.raw_content_block_start_event.RawContentBlockStartEvent
|
||||||
|
|||||||
@@ -72,6 +72,28 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
|||||||
content = f'<think>\n{thinking_content}\n</think>\n{content}'.strip()
|
content = f'<think>\n{thinking_content}\n</think>\n{content}'.strip()
|
||||||
return content, thinking_content
|
return content, thinking_content
|
||||||
|
|
||||||
|
def _extract_dify_text_output(self, value: typing.Any) -> str:
|
||||||
|
"""Extract text content from Dify output payload."""
|
||||||
|
if value is None:
|
||||||
|
return ''
|
||||||
|
if isinstance(value, dict):
|
||||||
|
content = value.get('content')
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
return json.dumps(value, ensure_ascii=False)
|
||||||
|
if isinstance(value, str):
|
||||||
|
text = value.strip()
|
||||||
|
if not text:
|
||||||
|
return ''
|
||||||
|
try:
|
||||||
|
parsed = json.loads(text)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return value
|
||||||
|
if isinstance(parsed, dict) and isinstance(parsed.get('content'), str):
|
||||||
|
return parsed['content']
|
||||||
|
return value
|
||||||
|
return str(value)
|
||||||
|
|
||||||
async def _preprocess_user_message(self, query: pipeline_query.Query) -> tuple[str, list[dict]]:
|
async def _preprocess_user_message(self, query: pipeline_query.Query) -> tuple[str, list[dict]]:
|
||||||
"""预处理用户消息,提取纯文本,并将图片/文件上传到 Dify 服务
|
"""预处理用户消息,提取纯文本,并将图片/文件上传到 Dify 服务
|
||||||
|
|
||||||
@@ -192,7 +214,8 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
|||||||
if mode == 'workflow':
|
if mode == 'workflow':
|
||||||
if chunk['event'] == 'node_finished':
|
if chunk['event'] == 'node_finished':
|
||||||
if chunk['data']['node_type'] == 'answer':
|
if chunk['data']['node_type'] == 'answer':
|
||||||
content, _ = self._process_thinking_content(chunk['data']['outputs']['answer'])
|
answer = self._extract_dify_text_output(chunk['data']['outputs'].get('answer'))
|
||||||
|
content, _ = self._process_thinking_content(answer)
|
||||||
|
|
||||||
yield provider_message.Message(
|
yield provider_message.Message(
|
||||||
role='assistant',
|
role='assistant',
|
||||||
@@ -405,6 +428,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
|||||||
for f in upload_files
|
for f in upload_files
|
||||||
]
|
]
|
||||||
|
|
||||||
|
mode = 'basic'
|
||||||
basic_mode_pending_chunk = ''
|
basic_mode_pending_chunk = ''
|
||||||
|
|
||||||
inputs = {}
|
inputs = {}
|
||||||
@@ -417,6 +441,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
|||||||
is_final = False
|
is_final = False
|
||||||
think_start = False
|
think_start = False
|
||||||
think_end = False
|
think_end = False
|
||||||
|
yielded_final = False
|
||||||
|
|
||||||
remove_think = self.pipeline_config['output'].get('misc', '').get('remove-think')
|
remove_think = self.pipeline_config['output'].get('misc', '').get('remove-think')
|
||||||
|
|
||||||
@@ -430,11 +455,12 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
|||||||
):
|
):
|
||||||
self.ap.logger.debug('dify-chat-chunk: ' + str(chunk))
|
self.ap.logger.debug('dify-chat-chunk: ' + str(chunk))
|
||||||
|
|
||||||
# if chunk['event'] == 'workflow_started':
|
if chunk['event'] == 'workflow_started':
|
||||||
# mode = 'workflow'
|
mode = 'workflow'
|
||||||
# if mode == 'workflow':
|
elif chunk['event'] in ('node_started', 'node_finished', 'workflow_finished'):
|
||||||
# elif mode == 'basic':
|
# Some Dify deployments may omit workflow_started in streamed chunks.
|
||||||
# 因为都只是返回的 message也没有工具调用什么的,暂时不分类
|
mode = 'workflow'
|
||||||
|
|
||||||
if chunk['event'] == 'message':
|
if chunk['event'] == 'message':
|
||||||
message_idx += 1
|
message_idx += 1
|
||||||
if remove_think:
|
if remove_think:
|
||||||
@@ -457,14 +483,30 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
|||||||
|
|
||||||
if chunk['event'] == 'message_end':
|
if chunk['event'] == 'message_end':
|
||||||
is_final = True
|
is_final = True
|
||||||
|
elif chunk['event'] == 'workflow_finished':
|
||||||
|
is_final = True
|
||||||
|
if chunk['data'].get('error'):
|
||||||
|
raise errors.DifyAPIError(chunk['data']['error'])
|
||||||
|
|
||||||
if is_final or message_idx % 8 == 0:
|
if mode == 'workflow' and chunk['event'] == 'node_finished':
|
||||||
|
if chunk['data'].get('node_type') == 'answer':
|
||||||
|
answer = self._extract_dify_text_output(chunk['data'].get('outputs', {}).get('answer'))
|
||||||
|
if answer:
|
||||||
|
basic_mode_pending_chunk = answer
|
||||||
|
|
||||||
|
if (
|
||||||
|
not yielded_final
|
||||||
|
and (is_final or message_idx % 8 == 0)
|
||||||
|
and (basic_mode_pending_chunk != '' or is_final)
|
||||||
|
):
|
||||||
# content, _ = self._process_thinking_content(basic_mode_pending_chunk)
|
# content, _ = self._process_thinking_content(basic_mode_pending_chunk)
|
||||||
yield provider_message.MessageChunk(
|
yield provider_message.MessageChunk(
|
||||||
role='assistant',
|
role='assistant',
|
||||||
content=basic_mode_pending_chunk,
|
content=basic_mode_pending_chunk,
|
||||||
is_final=is_final,
|
is_final=is_final,
|
||||||
)
|
)
|
||||||
|
if is_final:
|
||||||
|
yielded_final = True
|
||||||
|
|
||||||
if chunk is None:
|
if chunk is None:
|
||||||
raise errors.DifyAPIError('Dify API 没有返回任何响应,请检查网络连接和API配置')
|
raise errors.DifyAPIError('Dify API 没有返回任何响应,请检查网络连接和API配置')
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import json
|
|||||||
import copy
|
import copy
|
||||||
import typing
|
import typing
|
||||||
from .. import runner
|
from .. import runner
|
||||||
|
from ..modelmgr import requester as modelmgr_requester
|
||||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||||
import langbot_plugin.api.entities.builtin.rag.context as rag_context
|
import langbot_plugin.api.entities.builtin.rag.context as rag_context
|
||||||
@@ -26,29 +27,114 @@ Respond in the same language as the user's input.
|
|||||||
|
|
||||||
@runner.runner_class('local-agent')
|
@runner.runner_class('local-agent')
|
||||||
class LocalAgentRunner(runner.RequestRunner):
|
class LocalAgentRunner(runner.RequestRunner):
|
||||||
"""本地Agent请求运行器"""
|
"""Local agent request runner"""
|
||||||
|
|
||||||
class ToolCallTracker:
|
async def _get_model_candidates(
|
||||||
"""工具调用追踪器"""
|
self,
|
||||||
|
query: pipeline_query.Query,
|
||||||
|
) -> list[modelmgr_requester.RuntimeLLMModel]:
|
||||||
|
"""Build ordered list of models to try: primary model + fallback models."""
|
||||||
|
candidates = []
|
||||||
|
|
||||||
def __init__(self):
|
# Primary model
|
||||||
self.active_calls: dict[str, dict] = {}
|
if query.use_llm_model_uuid:
|
||||||
self.completed_calls: list[provider_message.ToolCall] = []
|
try:
|
||||||
|
primary = await self.ap.model_mgr.get_model_by_uuid(query.use_llm_model_uuid)
|
||||||
|
candidates.append(primary)
|
||||||
|
except ValueError:
|
||||||
|
self.ap.logger.warning(f'Primary model {query.use_llm_model_uuid} not found')
|
||||||
|
|
||||||
|
# Fallback models
|
||||||
|
fallback_uuids = (query.variables or {}).get('_fallback_model_uuids', [])
|
||||||
|
for fb_uuid in fallback_uuids:
|
||||||
|
try:
|
||||||
|
fb_model = await self.ap.model_mgr.get_model_by_uuid(fb_uuid)
|
||||||
|
candidates.append(fb_model)
|
||||||
|
except ValueError:
|
||||||
|
self.ap.logger.warning(f'Fallback model {fb_uuid} not found, skipping')
|
||||||
|
|
||||||
|
return candidates
|
||||||
|
|
||||||
|
async def _invoke_with_fallback(
|
||||||
|
self,
|
||||||
|
query: pipeline_query.Query,
|
||||||
|
candidates: list[modelmgr_requester.RuntimeLLMModel],
|
||||||
|
messages: list,
|
||||||
|
funcs: list,
|
||||||
|
remove_think: bool,
|
||||||
|
) -> tuple[provider_message.Message, modelmgr_requester.RuntimeLLMModel]:
|
||||||
|
"""Try non-streaming invocation with sequential fallback. Returns (message, model_used)."""
|
||||||
|
last_error = None
|
||||||
|
for model in candidates:
|
||||||
|
try:
|
||||||
|
msg = await model.provider.invoke_llm(
|
||||||
|
query,
|
||||||
|
model,
|
||||||
|
messages,
|
||||||
|
funcs if model.model_entity.abilities.__contains__('func_call') else [],
|
||||||
|
extra_args=model.model_entity.extra_args,
|
||||||
|
remove_think=remove_think,
|
||||||
|
)
|
||||||
|
return msg, model
|
||||||
|
except Exception as e:
|
||||||
|
last_error = e
|
||||||
|
self.ap.logger.warning(f'Model {model.model_entity.name} failed: {e}, trying next fallback...')
|
||||||
|
raise last_error or RuntimeError('No model candidates available')
|
||||||
|
|
||||||
|
async def _invoke_stream_with_fallback(
|
||||||
|
self,
|
||||||
|
query: pipeline_query.Query,
|
||||||
|
candidates: list[modelmgr_requester.RuntimeLLMModel],
|
||||||
|
messages: list,
|
||||||
|
funcs: list,
|
||||||
|
remove_think: bool,
|
||||||
|
) -> tuple[typing.AsyncGenerator, modelmgr_requester.RuntimeLLMModel]:
|
||||||
|
"""Try streaming invocation with sequential fallback. Returns (stream_generator, model_used).
|
||||||
|
|
||||||
|
Fallback is only possible before any chunks have been yielded to the client.
|
||||||
|
Once streaming starts, the model is committed.
|
||||||
|
"""
|
||||||
|
last_error = None
|
||||||
|
for model in candidates:
|
||||||
|
try:
|
||||||
|
stream = model.provider.invoke_llm_stream(
|
||||||
|
query,
|
||||||
|
model,
|
||||||
|
messages,
|
||||||
|
funcs if model.model_entity.abilities.__contains__('func_call') else [],
|
||||||
|
extra_args=model.model_entity.extra_args,
|
||||||
|
remove_think=remove_think,
|
||||||
|
)
|
||||||
|
# Attempt to get the first chunk to verify the stream works
|
||||||
|
first_chunk = await stream.__anext__()
|
||||||
|
|
||||||
|
async def _chain_stream(first, rest):
|
||||||
|
yield first
|
||||||
|
async for chunk in rest:
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
return _chain_stream(first_chunk, stream), model
|
||||||
|
except StopAsyncIteration:
|
||||||
|
# Empty stream — treat as success (model returned nothing)
|
||||||
|
async def _empty_stream():
|
||||||
|
return
|
||||||
|
yield # make it a generator
|
||||||
|
|
||||||
|
return _empty_stream(), model
|
||||||
|
except Exception as e:
|
||||||
|
last_error = e
|
||||||
|
self.ap.logger.warning(f'Model {model.model_entity.name} stream failed: {e}, trying next fallback...')
|
||||||
|
raise last_error or RuntimeError('No model candidates available')
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self, query: pipeline_query.Query
|
self, query: pipeline_query.Query
|
||||||
) -> typing.AsyncGenerator[provider_message.Message | provider_message.MessageChunk, None]:
|
) -> typing.AsyncGenerator[provider_message.Message | provider_message.MessageChunk, None]:
|
||||||
"""运行请求"""
|
"""Run request"""
|
||||||
pending_tool_calls = []
|
pending_tool_calls = []
|
||||||
|
|
||||||
# Get knowledge bases list (new field)
|
# Get knowledge bases list from query variables (set by PreProcessor,
|
||||||
kb_uuids = query.pipeline_config['ai']['local-agent'].get('knowledge-bases', [])
|
# may have been modified by plugins during PromptPreProcessing)
|
||||||
|
kb_uuids = query.variables.get('_knowledge_base_uuids', [])
|
||||||
# Fallback to old field for backward compatibility
|
|
||||||
if not kb_uuids:
|
|
||||||
old_kb_uuid = query.pipeline_config['ai']['local-agent'].get('knowledge-base', '')
|
|
||||||
if old_kb_uuid and old_kb_uuid != '__none__':
|
|
||||||
kb_uuids = [old_kb_uuid]
|
|
||||||
|
|
||||||
user_message = copy.deepcopy(query.user_message)
|
user_message = copy.deepcopy(query.user_message)
|
||||||
|
|
||||||
@@ -74,15 +160,14 @@ class LocalAgentRunner(runner.RequestRunner):
|
|||||||
self.ap.logger.warning(f'Knowledge base {kb_uuid} not found, skipping')
|
self.ap.logger.warning(f'Knowledge base {kb_uuid} not found, skipping')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Get top_k based on KB type
|
result = await kb.retrieve(
|
||||||
if kb.get_type() == 'internal':
|
user_message_text,
|
||||||
top_k = kb.knowledge_base_entity.top_k
|
settings={
|
||||||
elif kb.get_type() == 'external':
|
'bot_uuid': query.bot_uuid or '',
|
||||||
top_k = 5 # external kb's top_k is managed by plugin config
|
'sender_id': str(query.sender_id),
|
||||||
else:
|
'session_name': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||||
top_k = 5 # default fallback
|
},
|
||||||
|
)
|
||||||
result = await kb.retrieve(user_message_text, top_k)
|
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
all_results.extend(result)
|
all_results.extend(result)
|
||||||
@@ -97,9 +182,9 @@ class LocalAgentRunner(runner.RequestRunner):
|
|||||||
if content.type == 'text' and content.text is not None:
|
if content.type == 'text' and content.text is not None:
|
||||||
texts.append(f'[{idx}] {content.text}')
|
texts.append(f'[{idx}] {content.text}')
|
||||||
idx += 1
|
idx += 1
|
||||||
rag_context = '\n\n'.join(texts)
|
rag_context_text = '\n\n'.join(texts)
|
||||||
final_user_message_text = rag_combined_prompt_template.format(
|
final_user_message_text = rag_combined_prompt_template.format(
|
||||||
rag_context=rag_context, user_message=user_message_text
|
rag_context=rag_context_text, user_message=user_message_text
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@@ -121,51 +206,51 @@ class LocalAgentRunner(runner.RequestRunner):
|
|||||||
|
|
||||||
remove_think = query.pipeline_config['output'].get('misc', '').get('remove-think')
|
remove_think = query.pipeline_config['output'].get('misc', '').get('remove-think')
|
||||||
|
|
||||||
use_llm_model = await self.ap.model_mgr.get_model_by_uuid(query.use_llm_model_uuid)
|
# Build ordered candidate list (primary + fallbacks)
|
||||||
|
candidates = await self._get_model_candidates(query)
|
||||||
|
if not candidates:
|
||||||
|
raise RuntimeError('No LLM model configured for local-agent runner')
|
||||||
|
|
||||||
self.ap.logger.debug(
|
self.ap.logger.debug(
|
||||||
f'localagent req: query={query.query_id} req_messages={req_messages} use_llm_model={query.use_llm_model_uuid}'
|
f'localagent req: query={query.query_id} req_messages={req_messages} '
|
||||||
|
f'candidates={[m.model_entity.name for m in candidates]}'
|
||||||
)
|
)
|
||||||
|
|
||||||
if not is_stream:
|
if not is_stream:
|
||||||
# 非流式输出,直接请求
|
# Non-streaming: invoke with fallback
|
||||||
|
msg, use_llm_model = await self._invoke_with_fallback(
|
||||||
msg = await use_llm_model.provider.invoke_llm(
|
|
||||||
query,
|
query,
|
||||||
use_llm_model,
|
candidates,
|
||||||
req_messages,
|
req_messages,
|
||||||
query.use_funcs,
|
query.use_funcs,
|
||||||
extra_args=use_llm_model.model_entity.extra_args,
|
remove_think,
|
||||||
remove_think=remove_think,
|
|
||||||
)
|
)
|
||||||
yield msg
|
yield msg
|
||||||
final_msg = msg
|
final_msg = msg
|
||||||
else:
|
else:
|
||||||
# 流式输出,需要处理工具调用
|
# Streaming: invoke with fallback
|
||||||
tool_calls_map: dict[str, provider_message.ToolCall] = {}
|
tool_calls_map: dict[str, provider_message.ToolCall] = {}
|
||||||
msg_idx = 0
|
msg_idx = 0
|
||||||
accumulated_content = '' # 从开始累积的所有内容
|
accumulated_content = ''
|
||||||
last_role = 'assistant'
|
last_role = 'assistant'
|
||||||
msg_sequence = 1
|
msg_sequence = 1
|
||||||
async for msg in use_llm_model.provider.invoke_llm_stream(
|
|
||||||
|
stream_src, use_llm_model = await self._invoke_stream_with_fallback(
|
||||||
query,
|
query,
|
||||||
use_llm_model,
|
candidates,
|
||||||
req_messages,
|
req_messages,
|
||||||
query.use_funcs,
|
query.use_funcs,
|
||||||
extra_args=use_llm_model.model_entity.extra_args,
|
remove_think,
|
||||||
remove_think=remove_think,
|
)
|
||||||
):
|
async for msg in stream_src:
|
||||||
msg_idx = msg_idx + 1
|
msg_idx = msg_idx + 1
|
||||||
|
|
||||||
# 记录角色
|
|
||||||
if msg.role:
|
if msg.role:
|
||||||
last_role = msg.role
|
last_role = msg.role
|
||||||
|
|
||||||
# 累积内容
|
|
||||||
if msg.content:
|
if msg.content:
|
||||||
accumulated_content += msg.content
|
accumulated_content += msg.content
|
||||||
|
|
||||||
# 处理工具调用
|
|
||||||
if msg.tool_calls:
|
if msg.tool_calls:
|
||||||
for tool_call in msg.tool_calls:
|
for tool_call in msg.tool_calls:
|
||||||
if tool_call.id not in tool_calls_map:
|
if tool_call.id not in tool_calls_map:
|
||||||
@@ -177,21 +262,18 @@ class LocalAgentRunner(runner.RequestRunner):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
if tool_call.function and tool_call.function.arguments:
|
if tool_call.function and tool_call.function.arguments:
|
||||||
# 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖
|
|
||||||
tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments
|
tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments
|
||||||
# continue
|
|
||||||
# 每8个chunk或最后一个chunk时,输出所有累积的内容
|
|
||||||
if msg_idx % 8 == 0 or msg.is_final:
|
if msg_idx % 8 == 0 or msg.is_final:
|
||||||
msg_sequence += 1
|
msg_sequence += 1
|
||||||
yield provider_message.MessageChunk(
|
yield provider_message.MessageChunk(
|
||||||
role=last_role,
|
role=last_role,
|
||||||
content=accumulated_content, # 输出所有累积内容
|
content=accumulated_content,
|
||||||
tool_calls=list(tool_calls_map.values()) if (tool_calls_map and msg.is_final) else None,
|
tool_calls=list(tool_calls_map.values()) if (tool_calls_map and msg.is_final) else None,
|
||||||
is_final=msg.is_final,
|
is_final=msg.is_final,
|
||||||
msg_sequence=msg_sequence,
|
msg_sequence=msg_sequence,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建最终消息用于后续处理
|
|
||||||
final_msg = provider_message.MessageChunk(
|
final_msg = provider_message.MessageChunk(
|
||||||
role=last_role,
|
role=last_role,
|
||||||
content=accumulated_content,
|
content=accumulated_content,
|
||||||
@@ -206,7 +288,8 @@ class LocalAgentRunner(runner.RequestRunner):
|
|||||||
|
|
||||||
req_messages.append(final_msg)
|
req_messages.append(final_msg)
|
||||||
|
|
||||||
# 持续请求,只要还有待处理的工具调用就继续处理调用
|
# Once a model succeeds, commit to it for the tool call loop
|
||||||
|
# (no fallback mid-conversation — different models may interpret tool results differently)
|
||||||
while pending_tool_calls:
|
while pending_tool_calls:
|
||||||
for tool_call in pending_tool_calls:
|
for tool_call in pending_tool_calls:
|
||||||
try:
|
try:
|
||||||
@@ -247,7 +330,6 @@ class LocalAgentRunner(runner.RequestRunner):
|
|||||||
|
|
||||||
req_messages.append(msg)
|
req_messages.append(msg)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 工具调用出错,添加一个报错信息到 req_messages
|
|
||||||
err_msg = provider_message.Message(role='tool', content=f'err: {e}', tool_call_id=tool_call.id)
|
err_msg = provider_message.Message(role='tool', content=f'err: {e}', tool_call_id=tool_call.id)
|
||||||
|
|
||||||
yield err_msg
|
yield err_msg
|
||||||
@@ -255,39 +337,38 @@ class LocalAgentRunner(runner.RequestRunner):
|
|||||||
req_messages.append(err_msg)
|
req_messages.append(err_msg)
|
||||||
|
|
||||||
self.ap.logger.debug(
|
self.ap.logger.debug(
|
||||||
f'localagent req: query={query.query_id} req_messages={req_messages} use_llm_model={query.use_llm_model_uuid}'
|
f'localagent req: query={query.query_id} req_messages={req_messages} '
|
||||||
|
f'use_llm_model={use_llm_model.model_entity.name}'
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_stream:
|
if is_stream:
|
||||||
tool_calls_map = {}
|
tool_calls_map = {}
|
||||||
msg_idx = 0
|
msg_idx = 0
|
||||||
accumulated_content = '' # 从开始累积的所有内容
|
accumulated_content = ''
|
||||||
last_role = 'assistant'
|
last_role = 'assistant'
|
||||||
msg_sequence = first_end_sequence
|
msg_sequence = first_end_sequence
|
||||||
|
|
||||||
async for msg in use_llm_model.provider.invoke_llm_stream(
|
tool_stream_src = use_llm_model.provider.invoke_llm_stream(
|
||||||
query,
|
query,
|
||||||
use_llm_model,
|
use_llm_model,
|
||||||
req_messages,
|
req_messages,
|
||||||
query.use_funcs,
|
query.use_funcs if use_llm_model.model_entity.abilities.__contains__('func_call') else [],
|
||||||
extra_args=use_llm_model.model_entity.extra_args,
|
extra_args=use_llm_model.model_entity.extra_args,
|
||||||
remove_think=remove_think,
|
remove_think=remove_think,
|
||||||
):
|
)
|
||||||
|
async for msg in tool_stream_src:
|
||||||
msg_idx += 1
|
msg_idx += 1
|
||||||
|
|
||||||
# 记录角色
|
|
||||||
if msg.role:
|
if msg.role:
|
||||||
last_role = msg.role
|
last_role = msg.role
|
||||||
|
|
||||||
# 第一次请求工具调用时的内容
|
# Prepend first-round content on first chunk of tool-call round
|
||||||
if msg_idx == 1:
|
if msg_idx == 1:
|
||||||
accumulated_content = first_content if first_content is not None else accumulated_content
|
accumulated_content = first_content if first_content is not None else accumulated_content
|
||||||
|
|
||||||
# 累积内容
|
|
||||||
if msg.content:
|
if msg.content:
|
||||||
accumulated_content += msg.content
|
accumulated_content += msg.content
|
||||||
|
|
||||||
# 处理工具调用
|
|
||||||
if msg.tool_calls:
|
if msg.tool_calls:
|
||||||
for tool_call in msg.tool_calls:
|
for tool_call in msg.tool_calls:
|
||||||
if tool_call.id not in tool_calls_map:
|
if tool_call.id not in tool_calls_map:
|
||||||
@@ -299,15 +380,13 @@ class LocalAgentRunner(runner.RequestRunner):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
if tool_call.function and tool_call.function.arguments:
|
if tool_call.function and tool_call.function.arguments:
|
||||||
# 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖
|
|
||||||
tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments
|
tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments
|
||||||
|
|
||||||
# 每8个chunk或最后一个chunk时,输出所有累积的内容
|
|
||||||
if msg_idx % 8 == 0 or msg.is_final:
|
if msg_idx % 8 == 0 or msg.is_final:
|
||||||
msg_sequence += 1
|
msg_sequence += 1
|
||||||
yield provider_message.MessageChunk(
|
yield provider_message.MessageChunk(
|
||||||
role=last_role,
|
role=last_role,
|
||||||
content=accumulated_content, # 输出所有累积内容
|
content=accumulated_content,
|
||||||
tool_calls=list(tool_calls_map.values()) if (tool_calls_map and msg.is_final) else None,
|
tool_calls=list(tool_calls_map.values()) if (tool_calls_map and msg.is_final) else None,
|
||||||
is_final=msg.is_final,
|
is_final=msg.is_final,
|
||||||
msg_sequence=msg_sequence,
|
msg_sequence=msg_sequence,
|
||||||
@@ -320,12 +399,12 @@ class LocalAgentRunner(runner.RequestRunner):
|
|||||||
msg_sequence=msg_sequence,
|
msg_sequence=msg_sequence,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# 处理完所有调用,再次请求
|
# Non-streaming: use committed model directly (no fallback in tool loop)
|
||||||
msg = await use_llm_model.provider.invoke_llm(
|
msg = await use_llm_model.provider.invoke_llm(
|
||||||
query,
|
query,
|
||||||
use_llm_model,
|
use_llm_model,
|
||||||
req_messages,
|
req_messages,
|
||||||
query.use_funcs,
|
query.use_funcs if use_llm_model.model_entity.abilities.__contains__('func_call') else [],
|
||||||
extra_args=use_llm_model.model_entity.extra_args,
|
extra_args=use_llm_model.model_entity.extra_args,
|
||||||
remove_think=remove_think,
|
remove_think=remove_think,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import json
|
|||||||
import uuid
|
import uuid
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
|
from langbot.pkg.utils import httpclient
|
||||||
|
|
||||||
from .. import runner
|
from .. import runner
|
||||||
from ...core import app
|
from ...core import app
|
||||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
@@ -217,50 +219,50 @@ class N8nServiceAPIRunner(runner.RequestRunner):
|
|||||||
self.ap.logger.debug('no auth')
|
self.ap.logger.debug('no auth')
|
||||||
|
|
||||||
# 调用webhook
|
# 调用webhook
|
||||||
async with aiohttp.ClientSession() as session:
|
session = httpclient.get_session()
|
||||||
if is_stream:
|
if is_stream:
|
||||||
# 流式请求
|
# 流式请求
|
||||||
async with session.post(
|
async with session.post(
|
||||||
self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout
|
self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout
|
||||||
) as response:
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
error_text = await response.text()
|
||||||
|
self.ap.logger.error(f'n8n webhook call failed: {response.status}, {error_text}')
|
||||||
|
raise Exception(f'n8n webhook call failed: {response.status}, {error_text}')
|
||||||
|
|
||||||
|
# 处理流式响应
|
||||||
|
async for chunk in self._process_stream_response(response):
|
||||||
|
yield chunk
|
||||||
|
else:
|
||||||
|
async with session.post(
|
||||||
|
self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout
|
||||||
|
) as response:
|
||||||
|
try:
|
||||||
|
async for chunk in self._process_stream_response(response):
|
||||||
|
output_content = chunk.content if chunk.is_final else ''
|
||||||
|
except:
|
||||||
|
# 非流式请求(保持原有逻辑)
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
error_text = await response.text()
|
error_text = await response.text()
|
||||||
self.ap.logger.error(f'n8n webhook call failed: {response.status}, {error_text}')
|
self.ap.logger.error(f'n8n webhook call failed: {response.status}, {error_text}')
|
||||||
raise Exception(f'n8n webhook call failed: {response.status}, {error_text}')
|
raise Exception(f'n8n webhook call failed: {response.status}, {error_text}')
|
||||||
|
|
||||||
# 处理流式响应
|
# 解析响应
|
||||||
async for chunk in self._process_stream_response(response):
|
response_data = await response.json()
|
||||||
yield chunk
|
self.ap.logger.debug(f'n8n webhook response: {response_data}')
|
||||||
else:
|
|
||||||
async with session.post(
|
|
||||||
self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout
|
|
||||||
) as response:
|
|
||||||
try:
|
|
||||||
async for chunk in self._process_stream_response(response):
|
|
||||||
output_content = chunk.content if chunk.is_final else ''
|
|
||||||
except:
|
|
||||||
# 非流式请求(保持原有逻辑)
|
|
||||||
if response.status != 200:
|
|
||||||
error_text = await response.text()
|
|
||||||
self.ap.logger.error(f'n8n webhook call failed: {response.status}, {error_text}')
|
|
||||||
raise Exception(f'n8n webhook call failed: {response.status}, {error_text}')
|
|
||||||
|
|
||||||
# 解析响应
|
# 从响应中提取输出
|
||||||
response_data = await response.json()
|
if self.output_key in response_data:
|
||||||
self.ap.logger.debug(f'n8n webhook response: {response_data}')
|
output_content = response_data[self.output_key]
|
||||||
|
else:
|
||||||
|
# 如果没有指定的输出键,则使用整个响应
|
||||||
|
output_content = json.dumps(response_data, ensure_ascii=False)
|
||||||
|
|
||||||
# 从响应中提取输出
|
# 返回消息
|
||||||
if self.output_key in response_data:
|
yield provider_message.Message(
|
||||||
output_content = response_data[self.output_key]
|
role='assistant',
|
||||||
else:
|
content=output_content,
|
||||||
# 如果没有指定的输出键,则使用整个响应
|
)
|
||||||
output_content = json.dumps(response_data, ensure_ascii=False)
|
|
||||||
|
|
||||||
# 返回消息
|
|
||||||
yield provider_message.Message(
|
|
||||||
role='assistant',
|
|
||||||
content=output_content,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.ap.logger.error(f'n8n webhook call exception: {str(e)}')
|
self.ap.logger.error(f'n8n webhook call exception: {str(e)}')
|
||||||
raise N8nAPIError(f'n8n webhook call exception: {str(e)}')
|
raise N8nAPIError(f'n8n webhook call exception: {str(e)}')
|
||||||
|
|||||||
@@ -22,12 +22,12 @@ class KnowledgeBaseInterface(metaclass=abc.ABCMeta):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def retrieve(self, query: str, top_k: int) -> list[rag_context.RetrievalResultEntry]:
|
async def retrieve(self, query: str, settings: dict | None = None) -> list[rag_context.RetrievalResultEntry]:
|
||||||
"""Retrieve relevant documents from the knowledge base
|
"""Retrieve relevant documents from the knowledge base
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: The query string
|
query: The query string
|
||||||
top_k: Number of top results to return
|
settings: Optional per-request retrieval settings overrides
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of retrieve result entries
|
List of retrieve result entries
|
||||||
@@ -45,8 +45,8 @@ class KnowledgeBaseInterface(metaclass=abc.ABCMeta):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def get_type(self) -> str:
|
def get_knowledge_engine_plugin_id(self) -> str:
|
||||||
"""Get the type of knowledge base (internal/external)"""
|
"""Get the Knowledge Engine plugin ID"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
|
|||||||
@@ -1,85 +0,0 @@
|
|||||||
"""External knowledge base implementation"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from langbot.pkg.core import app
|
|
||||||
from langbot.pkg.entity.persistence import rag as persistence_rag
|
|
||||||
from langbot_plugin.api.entities.builtin.rag import context as rag_context
|
|
||||||
from .base import KnowledgeBaseInterface
|
|
||||||
|
|
||||||
|
|
||||||
class ExternalKnowledgeBase(KnowledgeBaseInterface):
|
|
||||||
"""External knowledge base that queries via HTTP API or plugin retriever"""
|
|
||||||
|
|
||||||
external_kb_entity: persistence_rag.ExternalKnowledgeBase
|
|
||||||
|
|
||||||
# Plugin retriever instance ID
|
|
||||||
retriever_instance_id: str | None
|
|
||||||
|
|
||||||
def __init__(self, ap: app.Application, external_kb_entity: persistence_rag.ExternalKnowledgeBase):
|
|
||||||
super().__init__(ap)
|
|
||||||
self.external_kb_entity = external_kb_entity
|
|
||||||
self.retriever_instance_id = None
|
|
||||||
|
|
||||||
async def initialize(self):
|
|
||||||
"""Initialize the external knowledge base"""
|
|
||||||
# Use KB UUID as instance ID
|
|
||||||
# Instance creation is now handled by the unified sync mechanism
|
|
||||||
# when LangBot connects to runtime
|
|
||||||
self.retriever_instance_id = self.external_kb_entity.uuid
|
|
||||||
|
|
||||||
self.ap.logger.info(
|
|
||||||
f'Initialized external KB {self.external_kb_entity.uuid}, instance will be created by sync mechanism'
|
|
||||||
)
|
|
||||||
|
|
||||||
async def retrieve(self, query: str, top_k: int = 5) -> list[rag_context.RetrievalResultEntry]:
|
|
||||||
"""Retrieve documents from external knowledge base via plugin retriever"""
|
|
||||||
if not self.retriever_instance_id:
|
|
||||||
self.ap.logger.error(f'No retriever instance for KB {self.external_kb_entity.uuid}')
|
|
||||||
return []
|
|
||||||
|
|
||||||
try:
|
|
||||||
results = await self.ap.plugin_connector.retrieve_knowledge(
|
|
||||||
self.external_kb_entity.plugin_author,
|
|
||||||
self.external_kb_entity.plugin_name,
|
|
||||||
self.external_kb_entity.retriever_name,
|
|
||||||
self.retriever_instance_id,
|
|
||||||
{'query': query},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert plugin results to RetrievalResultEntry
|
|
||||||
retrieval_entries = []
|
|
||||||
for result in results:
|
|
||||||
retrieval_entries.append(rag_context.RetrievalResultEntry(**result))
|
|
||||||
|
|
||||||
return retrieval_entries
|
|
||||||
except Exception as e:
|
|
||||||
self.ap.logger.error(f'Plugin retriever error: {e}')
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
traceback.print_exc()
|
|
||||||
return []
|
|
||||||
|
|
||||||
def get_uuid(self) -> str:
|
|
||||||
"""Get the UUID of the external knowledge base"""
|
|
||||||
return self.external_kb_entity.uuid
|
|
||||||
|
|
||||||
def get_name(self) -> str:
|
|
||||||
"""Get the name of the external knowledge base"""
|
|
||||||
return self.external_kb_entity.name
|
|
||||||
|
|
||||||
def get_type(self) -> str:
|
|
||||||
"""Get the type of knowledge base"""
|
|
||||||
return 'external'
|
|
||||||
|
|
||||||
async def dispose(self):
|
|
||||||
"""Clean up resources"""
|
|
||||||
# Trigger sync to immediately delete the instance from plugin process
|
|
||||||
# This ensures instance is cleaned up without waiting for next LangBot restart
|
|
||||||
try:
|
|
||||||
await self.ap.plugin_connector.sync_polymorphic_component_instances()
|
|
||||||
self.ap.logger.info(
|
|
||||||
f'Disposed external KB {self.external_kb_entity.uuid}, triggered sync to delete instance'
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
self.ap.logger.error(f'Failed to sync after disposing KB: {e}')
|
|
||||||
@@ -1,18 +1,19 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
import mimetypes
|
||||||
|
import os.path
|
||||||
import traceback
|
import traceback
|
||||||
import uuid
|
import uuid
|
||||||
import zipfile
|
import zipfile
|
||||||
import io
|
import io
|
||||||
from .services import parser, chunker
|
from typing import Any
|
||||||
from langbot.pkg.core import app
|
from langbot.pkg.core import app
|
||||||
from langbot.pkg.rag.knowledge.services.embedder import Embedder
|
|
||||||
from langbot.pkg.rag.knowledge.services.retriever import Retriever
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
|
||||||
|
|
||||||
from langbot.pkg.entity.persistence import rag as persistence_rag
|
from langbot.pkg.entity.persistence import rag as persistence_rag
|
||||||
from langbot.pkg.core import taskmgr
|
from langbot.pkg.core import taskmgr
|
||||||
from langbot_plugin.api.entities.builtin.rag import context as rag_context
|
from langbot_plugin.api.entities.builtin.rag import context as rag_context
|
||||||
from .base import KnowledgeBaseInterface
|
from .base import KnowledgeBaseInterface
|
||||||
from .external import ExternalKnowledgeBase
|
|
||||||
|
|
||||||
|
|
||||||
class RuntimeKnowledgeBase(KnowledgeBaseInterface):
|
class RuntimeKnowledgeBase(KnowledgeBaseInterface):
|
||||||
@@ -20,28 +21,16 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface):
|
|||||||
|
|
||||||
knowledge_base_entity: persistence_rag.KnowledgeBase
|
knowledge_base_entity: persistence_rag.KnowledgeBase
|
||||||
|
|
||||||
parser: parser.FileParser
|
|
||||||
|
|
||||||
chunker: chunker.Chunker
|
|
||||||
|
|
||||||
embedder: Embedder
|
|
||||||
|
|
||||||
retriever: Retriever
|
|
||||||
|
|
||||||
def __init__(self, ap: app.Application, knowledge_base_entity: persistence_rag.KnowledgeBase):
|
def __init__(self, ap: app.Application, knowledge_base_entity: persistence_rag.KnowledgeBase):
|
||||||
super().__init__(ap)
|
super().__init__(ap)
|
||||||
self.knowledge_base_entity = knowledge_base_entity
|
self.knowledge_base_entity = knowledge_base_entity
|
||||||
self.parser = parser.FileParser(ap=self.ap)
|
|
||||||
self.chunker = chunker.Chunker(ap=self.ap)
|
|
||||||
self.embedder = Embedder(ap=self.ap)
|
|
||||||
self.retriever = Retriever(ap=self.ap)
|
|
||||||
# 传递kb_id给retriever
|
|
||||||
self.retriever.kb_id = knowledge_base_entity.uuid
|
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def _store_file_task(self, file: persistence_rag.File, task_context: taskmgr.TaskContext):
|
async def _store_file_task(
|
||||||
|
self, file: persistence_rag.File, task_context: taskmgr.TaskContext, parser_plugin_id: str | None = None
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
# set file status to processing
|
# set file status to processing
|
||||||
await self.ap.persistence_mgr.execute_async(
|
await self.ap.persistence_mgr.execute_async(
|
||||||
@@ -50,31 +39,46 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface):
|
|||||||
.values(status='processing')
|
.values(status='processing')
|
||||||
)
|
)
|
||||||
|
|
||||||
task_context.set_current_action('Parsing file')
|
task_context.set_current_action('Processing file')
|
||||||
# parse file
|
|
||||||
text = await self.parser.parse(file.file_name, file.extension)
|
|
||||||
if not text:
|
|
||||||
raise Exception(f'No text extracted from file {file.file_name}')
|
|
||||||
|
|
||||||
task_context.set_current_action('Chunking file')
|
# Get file size from storage
|
||||||
# chunk file
|
file_size = await self.ap.storage_mgr.storage_provider.size(file.file_name)
|
||||||
chunks_texts = await self.chunker.chunk(text)
|
|
||||||
if not chunks_texts:
|
|
||||||
raise Exception(f'No chunks extracted from file {file.file_name}')
|
|
||||||
|
|
||||||
task_context.set_current_action('Embedding chunks')
|
# Detect MIME type from extension
|
||||||
|
mime_type, _ = mimetypes.guess_type(file.file_name)
|
||||||
|
if mime_type is None:
|
||||||
|
mime_type = 'application/octet-stream'
|
||||||
|
|
||||||
embedding_model = await self.ap.model_mgr.get_embedding_model_by_uuid(
|
# If a parser plugin is specified, call it before ingestion
|
||||||
self.knowledge_base_entity.embedding_model_uuid
|
parsed_content = None
|
||||||
)
|
if parser_plugin_id:
|
||||||
# embed chunks
|
task_context.set_current_action('Parsing file')
|
||||||
await self.embedder.embed_and_store(
|
file_bytes = await self.ap.storage_mgr.storage_provider.load(file.file_name)
|
||||||
kb_id=self.knowledge_base_entity.uuid,
|
parse_context = {
|
||||||
file_id=file.uuid,
|
'mime_type': mime_type,
|
||||||
chunks=chunks_texts,
|
'filename': file.file_name,
|
||||||
embedding_model=embedding_model,
|
'metadata': {},
|
||||||
|
}
|
||||||
|
parsed_content = await self.ap.plugin_connector.call_parser(parser_plugin_id, parse_context, file_bytes)
|
||||||
|
|
||||||
|
# Call plugin to ingest document
|
||||||
|
result = await self._ingest_document(
|
||||||
|
{
|
||||||
|
'document_id': file.uuid,
|
||||||
|
'filename': file.file_name,
|
||||||
|
'extension': file.extension,
|
||||||
|
'file_size': file_size,
|
||||||
|
'mime_type': mime_type,
|
||||||
|
},
|
||||||
|
file.file_name, # storage path
|
||||||
|
parsed_content=parsed_content,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check plugin result status
|
||||||
|
if result.get('status') == 'failed':
|
||||||
|
error_msg = result.get('error_message', 'Plugin ingestion returned failed status')
|
||||||
|
raise Exception(error_msg)
|
||||||
|
|
||||||
# set file status to completed
|
# set file status to completed
|
||||||
await self.ap.persistence_mgr.execute_async(
|
await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.update(persistence_rag.File)
|
sqlalchemy.update(persistence_rag.File)
|
||||||
@@ -97,16 +101,17 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface):
|
|||||||
# delete file from storage
|
# delete file from storage
|
||||||
await self.ap.storage_mgr.storage_provider.delete(file.file_name)
|
await self.ap.storage_mgr.storage_provider.delete(file.file_name)
|
||||||
|
|
||||||
async def store_file(self, file_id: str) -> str:
|
async def store_file(self, file_id: str, parser_plugin_id: str | None = None) -> str:
|
||||||
# pre checking
|
# pre checking
|
||||||
if not await self.ap.storage_mgr.storage_provider.exists(file_id):
|
if not await self.ap.storage_mgr.storage_provider.exists(file_id):
|
||||||
raise Exception(f'File {file_id} not found')
|
raise Exception(f'File {file_id} not found')
|
||||||
|
|
||||||
file_name = file_id
|
file_name = file_id
|
||||||
extension = file_name.split('.')[-1].lower()
|
_, ext = os.path.splitext(file_name)
|
||||||
|
extension = ext.lstrip('.').lower() if ext else ''
|
||||||
|
|
||||||
if extension == 'zip':
|
if extension == 'zip':
|
||||||
return await self._store_zip_file(file_id)
|
return await self._store_zip_file(file_id, parser_plugin_id=parser_plugin_id)
|
||||||
|
|
||||||
file_uuid = str(uuid.uuid4())
|
file_uuid = str(uuid.uuid4())
|
||||||
kb_id = self.knowledge_base_entity.uuid
|
kb_id = self.knowledge_base_entity.uuid
|
||||||
@@ -126,7 +131,7 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface):
|
|||||||
# run background task asynchronously
|
# run background task asynchronously
|
||||||
ctx = taskmgr.TaskContext.new()
|
ctx = taskmgr.TaskContext.new()
|
||||||
wrapper = self.ap.task_mgr.create_user_task(
|
wrapper = self.ap.task_mgr.create_user_task(
|
||||||
self._store_file_task(file_obj, task_context=ctx),
|
self._store_file_task(file_obj, task_context=ctx, parser_plugin_id=parser_plugin_id),
|
||||||
kind='knowledge-operation',
|
kind='knowledge-operation',
|
||||||
name=f'knowledge-store-file-{file_id}',
|
name=f'knowledge-store-file-{file_id}',
|
||||||
label=f'Store file {file_id}',
|
label=f'Store file {file_id}',
|
||||||
@@ -134,7 +139,7 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface):
|
|||||||
)
|
)
|
||||||
return wrapper.id
|
return wrapper.id
|
||||||
|
|
||||||
async def _store_zip_file(self, zip_file_id: str) -> str:
|
async def _store_zip_file(self, zip_file_id: str, parser_plugin_id: str | None = None) -> str:
|
||||||
"""Handle ZIP file by extracting each document and storing them separately."""
|
"""Handle ZIP file by extracting each document and storing them separately."""
|
||||||
self.ap.logger.info(f'Processing ZIP file: {zip_file_id}')
|
self.ap.logger.info(f'Processing ZIP file: {zip_file_id}')
|
||||||
|
|
||||||
@@ -150,7 +155,8 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface):
|
|||||||
if file_info.is_dir() or file_info.filename.startswith('.'):
|
if file_info.is_dir() or file_info.filename.startswith('.'):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
file_extension = file_info.filename.split('.')[-1].lower()
|
_, file_ext = os.path.splitext(file_info.filename)
|
||||||
|
file_extension = file_ext.lstrip('.').lower()
|
||||||
if file_extension not in supported_extensions:
|
if file_extension not in supported_extensions:
|
||||||
self.ap.logger.debug(f'Skipping unsupported file in ZIP: {file_info.filename}')
|
self.ap.logger.debug(f'Skipping unsupported file in ZIP: {file_info.filename}')
|
||||||
continue
|
continue
|
||||||
@@ -159,18 +165,18 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface):
|
|||||||
file_content = zip_ref.read(file_info.filename)
|
file_content = zip_ref.read(file_info.filename)
|
||||||
|
|
||||||
base_name = file_info.filename.replace('/', '_').replace('\\', '_')
|
base_name = file_info.filename.replace('/', '_').replace('\\', '_')
|
||||||
extension = base_name.split('.')[-1]
|
file_stem, file_ext = os.path.splitext(base_name)
|
||||||
file_name = base_name.split('.')[0]
|
extension = file_ext.lstrip('.')
|
||||||
|
|
||||||
if file_name.startswith('__MACOSX'):
|
if file_stem.startswith('__MACOSX'):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
extracted_file_id = file_name + '_' + str(uuid.uuid4())[:8] + '.' + extension
|
extracted_file_id = file_stem + '_' + str(uuid.uuid4())[:8] + '.' + extension
|
||||||
# save file to storage
|
# save file to storage
|
||||||
|
|
||||||
await self.ap.storage_mgr.storage_provider.save(extracted_file_id, file_content)
|
await self.ap.storage_mgr.storage_provider.save(extracted_file_id, file_content)
|
||||||
|
|
||||||
task_id = await self.store_file(extracted_file_id)
|
task_id = await self.store_file(extracted_file_id, parser_plugin_id=parser_plugin_id)
|
||||||
stored_file_tasks.append(task_id)
|
stored_file_tasks.append(task_id)
|
||||||
|
|
||||||
self.ap.logger.info(
|
self.ap.logger.info(
|
||||||
@@ -189,21 +195,28 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface):
|
|||||||
|
|
||||||
return stored_file_tasks[0] if stored_file_tasks else ''
|
return stored_file_tasks[0] if stored_file_tasks else ''
|
||||||
|
|
||||||
async def retrieve(self, query: str, top_k: int) -> list[rag_context.RetrievalResultEntry]:
|
async def retrieve(self, query: str, settings: dict | None = None) -> list[rag_context.RetrievalResultEntry]:
|
||||||
embedding_model = await self.ap.model_mgr.get_embedding_model_by_uuid(
|
# Merge stored retrieval_settings with per-request overrides
|
||||||
self.knowledge_base_entity.embedding_model_uuid
|
stored = self.knowledge_base_entity.retrieval_settings or {}
|
||||||
)
|
merged = {**stored, **(settings or {})}
|
||||||
return await self.retriever.retrieve(self.knowledge_base_entity.uuid, query, embedding_model, top_k)
|
if 'top_k' not in merged:
|
||||||
|
merged['top_k'] = 5 # fallback default
|
||||||
|
|
||||||
|
response = await self._retrieve(query, merged)
|
||||||
|
|
||||||
|
results_data = response.get('results', [])
|
||||||
|
entries = []
|
||||||
|
for r in results_data:
|
||||||
|
if isinstance(r, dict):
|
||||||
|
entries.append(rag_context.RetrievalResultEntry(**r))
|
||||||
|
elif isinstance(r, rag_context.RetrievalResultEntry):
|
||||||
|
entries.append(r)
|
||||||
|
return entries
|
||||||
|
|
||||||
async def delete_file(self, file_id: str):
|
async def delete_file(self, file_id: str):
|
||||||
# delete vector
|
await self._delete_document(file_id)
|
||||||
await self.ap.vector_db_mgr.vector_db.delete_by_file_id(self.knowledge_base_entity.uuid, file_id)
|
|
||||||
|
|
||||||
# delete chunk
|
|
||||||
await self.ap.persistence_mgr.execute_async(
|
|
||||||
sqlalchemy.delete(persistence_rag.Chunk).where(persistence_rag.Chunk.file_id == file_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# Also cleanup DB record
|
||||||
await self.ap.persistence_mgr.execute_async(
|
await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.delete(persistence_rag.File).where(persistence_rag.File.uuid == file_id)
|
sqlalchemy.delete(persistence_rag.File).where(persistence_rag.File.uuid == file_id)
|
||||||
)
|
)
|
||||||
@@ -216,32 +229,295 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface):
|
|||||||
"""Get the name of the knowledge base"""
|
"""Get the name of the knowledge base"""
|
||||||
return self.knowledge_base_entity.name
|
return self.knowledge_base_entity.name
|
||||||
|
|
||||||
def get_type(self) -> str:
|
def get_knowledge_engine_plugin_id(self) -> str:
|
||||||
"""Get the type of knowledge base"""
|
"""Get the Knowledge Engine plugin ID"""
|
||||||
return 'internal'
|
return self.knowledge_base_entity.knowledge_engine_plugin_id or ''
|
||||||
|
|
||||||
async def dispose(self):
|
async def dispose(self):
|
||||||
await self.ap.vector_db_mgr.vector_db.delete_collection(self.knowledge_base_entity.uuid)
|
"""Dispose the knowledge base, notifying the plugin to cleanup."""
|
||||||
|
await self._on_kb_delete()
|
||||||
|
|
||||||
|
# ========== Plugin Communication Methods ==========
|
||||||
|
|
||||||
|
async def _on_kb_create(self) -> None:
|
||||||
|
"""Notify plugin about KB creation."""
|
||||||
|
plugin_id = self.knowledge_base_entity.knowledge_engine_plugin_id
|
||||||
|
if not plugin_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
config = self.knowledge_base_entity.creation_settings or {}
|
||||||
|
self.ap.logger.info(
|
||||||
|
f'Calling RAG plugin {plugin_id}: on_knowledge_base_create(kb_id={self.knowledge_base_entity.uuid})'
|
||||||
|
)
|
||||||
|
await self.ap.plugin_connector.rag_on_kb_create(plugin_id, self.knowledge_base_entity.uuid, config)
|
||||||
|
except Exception as e:
|
||||||
|
self.ap.logger.error(f'Failed to notify plugin {plugin_id} on KB create: {e}')
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _on_kb_delete(self) -> None:
|
||||||
|
"""Notify plugin about KB deletion."""
|
||||||
|
plugin_id = self.knowledge_base_entity.knowledge_engine_plugin_id
|
||||||
|
if not plugin_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.ap.logger.info(
|
||||||
|
f'Calling RAG plugin {plugin_id}: on_knowledge_base_delete(kb_id={self.knowledge_base_entity.uuid})'
|
||||||
|
)
|
||||||
|
await self.ap.plugin_connector.rag_on_kb_delete(plugin_id, self.knowledge_base_entity.uuid)
|
||||||
|
except Exception as e:
|
||||||
|
self.ap.logger.error(f'Failed to notify plugin {plugin_id} on KB delete: {e}')
|
||||||
|
|
||||||
|
async def _ingest_document(
|
||||||
|
self,
|
||||||
|
file_metadata: dict[str, Any],
|
||||||
|
storage_path: str,
|
||||||
|
parsed_content: dict[str, Any] | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Call plugin to ingest document."""
|
||||||
|
kb = self.knowledge_base_entity
|
||||||
|
plugin_id = kb.knowledge_engine_plugin_id
|
||||||
|
if not plugin_id:
|
||||||
|
self.ap.logger.error(f'No RAG plugin ID configured for KB {kb.uuid}. Ingestion failed.')
|
||||||
|
raise ValueError('RAG Plugin ID required')
|
||||||
|
|
||||||
|
self.ap.logger.info(f'Calling RAG plugin {plugin_id}: ingest(doc={file_metadata.get("filename")})')
|
||||||
|
|
||||||
|
# Inject knowledge_base_id into file metadata as required by SDK schema
|
||||||
|
file_metadata['knowledge_base_id'] = kb.uuid
|
||||||
|
|
||||||
|
context_data = {
|
||||||
|
'file_object': {
|
||||||
|
'metadata': file_metadata,
|
||||||
|
'storage_path': storage_path,
|
||||||
|
},
|
||||||
|
'knowledge_base_id': kb.uuid,
|
||||||
|
'collection_id': kb.collection_id or kb.uuid,
|
||||||
|
'creation_settings': kb.creation_settings or {},
|
||||||
|
'parsed_content': parsed_content,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await self.ap.plugin_connector.call_rag_ingest(plugin_id, context_data)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
self.ap.logger.error(f'Plugin ingestion failed: {e}')
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _retrieve(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
settings: dict[str, Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Call plugin to retrieve documents.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no RAG plugin is configured for this KB.
|
||||||
|
Exception: If the plugin retrieval call fails.
|
||||||
|
"""
|
||||||
|
kb = self.knowledge_base_entity
|
||||||
|
plugin_id = kb.knowledge_engine_plugin_id
|
||||||
|
if not plugin_id:
|
||||||
|
raise ValueError(f'No RAG plugin ID configured for KB {kb.uuid}. Retrieval failed.')
|
||||||
|
|
||||||
|
# Session context (e.g. session_name) stays in retrieval_settings
|
||||||
|
# for plugins that need it. Do NOT move them into filters, as filters
|
||||||
|
# are passed directly to vector_search by some plugins (e.g. LangRAG)
|
||||||
|
# and would cause empty results when the metadata field doesn't exist.
|
||||||
|
filters = settings.pop('filters', {})
|
||||||
|
|
||||||
|
retrieval_context = {
|
||||||
|
'query': query,
|
||||||
|
'knowledge_base_id': kb.uuid,
|
||||||
|
'collection_id': kb.collection_id or kb.uuid,
|
||||||
|
'retrieval_settings': settings,
|
||||||
|
'creation_settings': kb.creation_settings or {},
|
||||||
|
'filters': filters,
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await self.ap.plugin_connector.call_rag_retrieve(
|
||||||
|
plugin_id,
|
||||||
|
retrieval_context,
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def _delete_document(self, document_id: str) -> bool:
|
||||||
|
"""Call plugin to delete document."""
|
||||||
|
kb = self.knowledge_base_entity
|
||||||
|
plugin_id = kb.knowledge_engine_plugin_id
|
||||||
|
if not plugin_id:
|
||||||
|
return False
|
||||||
|
|
||||||
|
self.ap.logger.info(f'Calling RAG plugin {plugin_id}: delete_document(doc_id={document_id})')
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await self.ap.plugin_connector.call_rag_delete_document(plugin_id, document_id, kb.uuid)
|
||||||
|
except Exception as e:
|
||||||
|
self.ap.logger.error(f'Plugin document deletion failed: {e}')
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class RAGManager:
|
class RAGManager:
|
||||||
ap: app.Application
|
ap: app.Application
|
||||||
|
|
||||||
knowledge_bases: list[KnowledgeBaseInterface]
|
knowledge_bases: dict[str, KnowledgeBaseInterface]
|
||||||
|
|
||||||
def __init__(self, ap: app.Application):
|
def __init__(self, ap: app.Application):
|
||||||
self.ap = ap
|
self.ap = ap
|
||||||
self.knowledge_bases = []
|
self.knowledge_bases = {}
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
await self.load_knowledge_bases_from_db()
|
await self.load_knowledge_bases_from_db()
|
||||||
|
|
||||||
|
async def get_all_knowledge_base_details(self) -> list[dict]:
|
||||||
|
"""Get all knowledge bases with enriched Knowledge Engine details."""
|
||||||
|
# 1. Get raw KBs from DB
|
||||||
|
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_rag.KnowledgeBase))
|
||||||
|
knowledge_bases = result.all()
|
||||||
|
|
||||||
|
# 2. Get all available Knowledge Engines for enrichment
|
||||||
|
engine_map = {}
|
||||||
|
if self.ap.plugin_connector.is_enable_plugin:
|
||||||
|
try:
|
||||||
|
engines = await self.ap.plugin_connector.list_knowledge_engines()
|
||||||
|
engine_map = {e['plugin_id']: e for e in engines}
|
||||||
|
except Exception as e:
|
||||||
|
self.ap.logger.warning(f'Failed to list Knowledge Engines: {e}')
|
||||||
|
|
||||||
|
# 3. Serialize and enrich
|
||||||
|
kb_list = []
|
||||||
|
for kb in knowledge_bases:
|
||||||
|
kb_dict = self.ap.persistence_mgr.serialize_model(persistence_rag.KnowledgeBase, kb)
|
||||||
|
self._enrich_kb_dict(kb_dict, engine_map)
|
||||||
|
kb_list.append(kb_dict)
|
||||||
|
|
||||||
|
return kb_list
|
||||||
|
|
||||||
|
async def get_knowledge_base_details(self, kb_uuid: str) -> dict | None:
|
||||||
|
"""Get specific knowledge base with enriched Knowledge Engine details."""
|
||||||
|
result = await self.ap.persistence_mgr.execute_async(
|
||||||
|
sqlalchemy.select(persistence_rag.KnowledgeBase).where(persistence_rag.KnowledgeBase.uuid == kb_uuid)
|
||||||
|
)
|
||||||
|
kb = result.first()
|
||||||
|
if not kb:
|
||||||
|
return None
|
||||||
|
|
||||||
|
kb_dict = self.ap.persistence_mgr.serialize_model(persistence_rag.KnowledgeBase, kb)
|
||||||
|
|
||||||
|
# Fetch engines
|
||||||
|
engine_map = {}
|
||||||
|
if self.ap.plugin_connector.is_enable_plugin:
|
||||||
|
try:
|
||||||
|
engines = await self.ap.plugin_connector.list_knowledge_engines()
|
||||||
|
engine_map = {e['plugin_id']: e for e in engines}
|
||||||
|
except Exception as e:
|
||||||
|
self.ap.logger.warning(f'Failed to list Knowledge Engines: {e}')
|
||||||
|
|
||||||
|
self._enrich_kb_dict(kb_dict, engine_map)
|
||||||
|
return kb_dict
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _to_i18n_name(name) -> dict:
|
||||||
|
"""Ensure name is always an I18nObject-compatible dict.
|
||||||
|
|
||||||
|
If *name* is already a dict (with ``en_US`` / ``zh_Hans`` keys) it is
|
||||||
|
returned as-is. A plain string is wrapped into an I18nObject so the
|
||||||
|
frontend ``extractI18nObject`` helper never receives an unexpected type.
|
||||||
|
"""
|
||||||
|
if isinstance(name, dict):
|
||||||
|
return name
|
||||||
|
return {'en_US': str(name), 'zh_Hans': str(name)}
|
||||||
|
|
||||||
|
def _enrich_kb_dict(self, kb_dict: dict, engine_map: dict) -> None:
|
||||||
|
"""Helper to inject engine info into KB dict."""
|
||||||
|
plugin_id = kb_dict.get('knowledge_engine_plugin_id')
|
||||||
|
|
||||||
|
# Default fallback structure — name must be I18nObject for frontend compatibility
|
||||||
|
fallback_name = self._to_i18n_name(plugin_id or 'Internal (Legacy)')
|
||||||
|
fallback_info = {
|
||||||
|
'plugin_id': plugin_id,
|
||||||
|
'name': fallback_name,
|
||||||
|
'capabilities': [],
|
||||||
|
}
|
||||||
|
|
||||||
|
if not plugin_id:
|
||||||
|
kb_dict['knowledge_engine'] = fallback_info
|
||||||
|
return
|
||||||
|
|
||||||
|
engine_info = engine_map.get(plugin_id)
|
||||||
|
if engine_info:
|
||||||
|
kb_dict['knowledge_engine'] = {
|
||||||
|
'plugin_id': plugin_id,
|
||||||
|
'name': self._to_i18n_name(engine_info.get('name', plugin_id)),
|
||||||
|
'capabilities': engine_info.get('capabilities', []),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
kb_dict['knowledge_engine'] = fallback_info
|
||||||
|
|
||||||
|
async def create_knowledge_base(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
knowledge_engine_plugin_id: str,
|
||||||
|
creation_settings: dict,
|
||||||
|
retrieval_settings: dict | None = None,
|
||||||
|
description: str = '',
|
||||||
|
) -> persistence_rag.KnowledgeBase:
|
||||||
|
"""Create a new knowledge base using a RAG plugin."""
|
||||||
|
# Validate that the Knowledge Engine plugin exists
|
||||||
|
if self.ap.plugin_connector.is_enable_plugin:
|
||||||
|
try:
|
||||||
|
engines = await self.ap.plugin_connector.list_knowledge_engines()
|
||||||
|
engine_ids = [e.get('plugin_id') for e in engines]
|
||||||
|
if knowledge_engine_plugin_id not in engine_ids:
|
||||||
|
raise ValueError(f'Knowledge Engine plugin {knowledge_engine_plugin_id} not found')
|
||||||
|
except ValueError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self.ap.logger.warning(f'Failed to validate Knowledge Engine plugin existence: {e}')
|
||||||
|
|
||||||
|
kb_uuid = str(uuid.uuid4())
|
||||||
|
# Use UUID as collection ID by default for isolation
|
||||||
|
collection_id = kb_uuid
|
||||||
|
|
||||||
|
kb_data = {
|
||||||
|
'uuid': kb_uuid,
|
||||||
|
'name': name,
|
||||||
|
'description': description,
|
||||||
|
'knowledge_engine_plugin_id': knowledge_engine_plugin_id,
|
||||||
|
'collection_id': collection_id,
|
||||||
|
'creation_settings': creation_settings,
|
||||||
|
'retrieval_settings': retrieval_settings or {},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create Entity
|
||||||
|
kb = persistence_rag.KnowledgeBase(**kb_data)
|
||||||
|
|
||||||
|
# Persist
|
||||||
|
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.KnowledgeBase).values(kb_data))
|
||||||
|
|
||||||
|
# Load into Runtime
|
||||||
|
runtime_kb = await self.load_knowledge_base(kb)
|
||||||
|
|
||||||
|
# Notify Plugin — rollback DB record and runtime entry on failure
|
||||||
|
try:
|
||||||
|
await runtime_kb._on_kb_create()
|
||||||
|
except Exception:
|
||||||
|
self.knowledge_bases.pop(kb_uuid, None)
|
||||||
|
await self.ap.persistence_mgr.execute_async(
|
||||||
|
sqlalchemy.delete(persistence_rag.KnowledgeBase).where(persistence_rag.KnowledgeBase.uuid == kb_uuid)
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
self.ap.logger.info(f'Created new Knowledge Base {name} ({kb_uuid}) using plugin {knowledge_engine_plugin_id}')
|
||||||
|
return kb
|
||||||
|
|
||||||
async def load_knowledge_bases_from_db(self):
|
async def load_knowledge_bases_from_db(self):
|
||||||
self.ap.logger.info('Loading knowledge bases from db...')
|
self.ap.logger.info('Loading knowledge bases from db...')
|
||||||
|
|
||||||
self.knowledge_bases = []
|
self.knowledge_bases = {}
|
||||||
|
|
||||||
# Load internal knowledge bases
|
# Load knowledge bases
|
||||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_rag.KnowledgeBase))
|
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_rag.KnowledgeBase))
|
||||||
knowledge_bases = result.all()
|
knowledge_bases = result.all()
|
||||||
|
|
||||||
@@ -253,86 +529,37 @@ class RAGManager:
|
|||||||
f'Error loading knowledge base {knowledge_base.uuid}: {e}\n{traceback.format_exc()}'
|
f'Error loading knowledge base {knowledge_base.uuid}: {e}\n{traceback.format_exc()}'
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load external knowledge bases
|
|
||||||
external_result = await self.ap.persistence_mgr.execute_async(
|
|
||||||
sqlalchemy.select(persistence_rag.ExternalKnowledgeBase)
|
|
||||||
)
|
|
||||||
external_kbs = external_result.all()
|
|
||||||
|
|
||||||
for external_kb in external_kbs:
|
|
||||||
try:
|
|
||||||
# Don't trigger sync during batch loading - will sync once after LangBot connects to runtime
|
|
||||||
await self.load_external_knowledge_base(external_kb, trigger_sync=False)
|
|
||||||
except Exception as e:
|
|
||||||
self.ap.logger.error(
|
|
||||||
f'Error loading external knowledge base {external_kb.uuid}: {e}\n{traceback.format_exc()}'
|
|
||||||
)
|
|
||||||
|
|
||||||
async def load_knowledge_base(
|
async def load_knowledge_base(
|
||||||
self,
|
self,
|
||||||
knowledge_base_entity: persistence_rag.KnowledgeBase | sqlalchemy.Row | dict,
|
knowledge_base_entity: persistence_rag.KnowledgeBase | sqlalchemy.Row | dict,
|
||||||
) -> RuntimeKnowledgeBase:
|
) -> RuntimeKnowledgeBase:
|
||||||
if isinstance(knowledge_base_entity, sqlalchemy.Row):
|
if isinstance(knowledge_base_entity, sqlalchemy.Row):
|
||||||
|
# Safe access to _mapping for SQLAlchemy 1.4+
|
||||||
knowledge_base_entity = persistence_rag.KnowledgeBase(**knowledge_base_entity._mapping)
|
knowledge_base_entity = persistence_rag.KnowledgeBase(**knowledge_base_entity._mapping)
|
||||||
elif isinstance(knowledge_base_entity, dict):
|
elif isinstance(knowledge_base_entity, dict):
|
||||||
knowledge_base_entity = persistence_rag.KnowledgeBase(**knowledge_base_entity)
|
# Filter out non-database fields (like knowledge_engine which is computed)
|
||||||
|
filtered_dict = {
|
||||||
|
k: v for k, v in knowledge_base_entity.items() if k in persistence_rag.KnowledgeBase.ALL_DB_FIELDS
|
||||||
|
}
|
||||||
|
knowledge_base_entity = persistence_rag.KnowledgeBase(**filtered_dict)
|
||||||
|
|
||||||
runtime_knowledge_base = RuntimeKnowledgeBase(ap=self.ap, knowledge_base_entity=knowledge_base_entity)
|
runtime_knowledge_base = RuntimeKnowledgeBase(ap=self.ap, knowledge_base_entity=knowledge_base_entity)
|
||||||
|
|
||||||
await runtime_knowledge_base.initialize()
|
await runtime_knowledge_base.initialize()
|
||||||
|
|
||||||
self.knowledge_bases.append(runtime_knowledge_base)
|
self.knowledge_bases[runtime_knowledge_base.get_uuid()] = runtime_knowledge_base
|
||||||
|
|
||||||
return runtime_knowledge_base
|
return runtime_knowledge_base
|
||||||
|
|
||||||
async def load_external_knowledge_base(
|
|
||||||
self,
|
|
||||||
external_kb_entity: persistence_rag.ExternalKnowledgeBase | sqlalchemy.Row | dict,
|
|
||||||
trigger_sync: bool = True,
|
|
||||||
) -> ExternalKnowledgeBase:
|
|
||||||
"""Load external knowledge base into runtime
|
|
||||||
|
|
||||||
Args:
|
|
||||||
external_kb_entity: External KB entity to load
|
|
||||||
trigger_sync: Whether to trigger sync after loading (default True for manual creation, False for batch loading)
|
|
||||||
"""
|
|
||||||
if isinstance(external_kb_entity, sqlalchemy.Row):
|
|
||||||
external_kb_entity = persistence_rag.ExternalKnowledgeBase(**external_kb_entity._mapping)
|
|
||||||
elif isinstance(external_kb_entity, dict):
|
|
||||||
external_kb_entity = persistence_rag.ExternalKnowledgeBase(**external_kb_entity)
|
|
||||||
|
|
||||||
external_kb = ExternalKnowledgeBase(ap=self.ap, external_kb_entity=external_kb_entity)
|
|
||||||
|
|
||||||
await external_kb.initialize()
|
|
||||||
|
|
||||||
self.knowledge_bases.append(external_kb)
|
|
||||||
|
|
||||||
# Trigger sync to create the instance immediately (for manual creation)
|
|
||||||
# Skip sync during batch loading from DB to avoid multiple sync calls
|
|
||||||
if trigger_sync:
|
|
||||||
try:
|
|
||||||
await self.ap.plugin_connector.sync_polymorphic_component_instances()
|
|
||||||
self.ap.logger.info(f'Triggered sync after loading external KB {external_kb_entity.uuid}')
|
|
||||||
except Exception as e:
|
|
||||||
self.ap.logger.error(f'Failed to sync after loading external KB: {e}')
|
|
||||||
|
|
||||||
return external_kb
|
|
||||||
|
|
||||||
async def get_knowledge_base_by_uuid(self, kb_uuid: str) -> KnowledgeBaseInterface | None:
|
async def get_knowledge_base_by_uuid(self, kb_uuid: str) -> KnowledgeBaseInterface | None:
|
||||||
for kb in self.knowledge_bases:
|
return self.knowledge_bases.get(kb_uuid)
|
||||||
if kb.get_uuid() == kb_uuid:
|
|
||||||
return kb
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def remove_knowledge_base_from_runtime(self, kb_uuid: str):
|
async def remove_knowledge_base_from_runtime(self, kb_uuid: str):
|
||||||
for kb in self.knowledge_bases:
|
self.knowledge_bases.pop(kb_uuid, None)
|
||||||
if kb.get_uuid() == kb_uuid:
|
|
||||||
self.knowledge_bases.remove(kb)
|
|
||||||
return
|
|
||||||
|
|
||||||
async def delete_knowledge_base(self, kb_uuid: str):
|
async def delete_knowledge_base(self, kb_uuid: str):
|
||||||
for kb in self.knowledge_bases:
|
kb = self.knowledge_bases.pop(kb_uuid, None)
|
||||||
if kb.get_uuid() == kb_uuid:
|
if kb is not None:
|
||||||
await kb.dispose()
|
await kb.dispose()
|
||||||
self.knowledge_bases.remove(kb)
|
else:
|
||||||
return
|
self.ap.logger.warning(f'Knowledge base {kb_uuid} not found in runtime, skipping plugin notification')
|
||||||
|
|||||||
@@ -1,15 +0,0 @@
|
|||||||
# 封装异步操作
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
|
|
||||||
class BaseService:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def _run_sync(self, func, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
在单独的线程中运行同步函数。
|
|
||||||
如果第一个参数是 session,则在 to_thread 中获取新的 session。
|
|
||||||
"""
|
|
||||||
|
|
||||||
return await asyncio.to_thread(func, *args, **kwargs)
|
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import List
|
|
||||||
from langbot.pkg.rag.knowledge.services import base_service
|
|
||||||
from langbot.pkg.core import app
|
|
||||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
|
||||||
|
|
||||||
|
|
||||||
class Chunker(base_service.BaseService):
|
|
||||||
"""
|
|
||||||
A class for splitting long texts into smaller, overlapping chunks.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, ap: app.Application, chunk_size: int = 500, chunk_overlap: int = 50):
|
|
||||||
self.ap = ap
|
|
||||||
self.chunk_size = chunk_size
|
|
||||||
self.chunk_overlap = chunk_overlap
|
|
||||||
if self.chunk_overlap >= self.chunk_size:
|
|
||||||
self.ap.logger.warning(
|
|
||||||
'Chunk overlap is greater than or equal to chunk size. This may lead to empty or malformed chunks.'
|
|
||||||
)
|
|
||||||
|
|
||||||
def _split_text_sync(self, text: str) -> List[str]:
|
|
||||||
"""
|
|
||||||
Synchronously splits a long text into chunks with specified overlap.
|
|
||||||
This is a CPU-bound operation, intended to be run in a separate thread.
|
|
||||||
"""
|
|
||||||
if not text:
|
|
||||||
return []
|
|
||||||
|
|
||||||
text_splitter = RecursiveCharacterTextSplitter(
|
|
||||||
chunk_size=self.chunk_size,
|
|
||||||
chunk_overlap=self.chunk_overlap,
|
|
||||||
length_function=len,
|
|
||||||
is_separator_regex=False,
|
|
||||||
)
|
|
||||||
return text_splitter.split_text(text)
|
|
||||||
|
|
||||||
async def chunk(self, text: str) -> List[str]:
|
|
||||||
"""
|
|
||||||
Asynchronously chunks a given text into smaller pieces.
|
|
||||||
"""
|
|
||||||
self.ap.logger.info(f'Chunking text (length: {len(text)})...')
|
|
||||||
# Run the synchronous splitting logic in a separate thread
|
|
||||||
chunks = await self._run_sync(self._split_text_sync, text)
|
|
||||||
self.ap.logger.info(f'Text chunked into {len(chunks)} pieces.')
|
|
||||||
self.ap.logger.debug(f'Chunks: {json.dumps(chunks, indent=4, ensure_ascii=False)}')
|
|
||||||
return chunks
|
|
||||||
@@ -1,55 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
import uuid
|
|
||||||
from typing import List
|
|
||||||
from langbot.pkg.rag.knowledge.services.base_service import BaseService
|
|
||||||
from langbot.pkg.entity.persistence import rag as persistence_rag
|
|
||||||
from langbot.pkg.core import app
|
|
||||||
from langbot.pkg.provider.modelmgr.requester import RuntimeEmbeddingModel
|
|
||||||
import sqlalchemy
|
|
||||||
|
|
||||||
|
|
||||||
class Embedder(BaseService):
|
|
||||||
def __init__(self, ap: app.Application) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.ap = ap
|
|
||||||
|
|
||||||
async def embed_and_store(
|
|
||||||
self, kb_id: str, file_id: str, chunks: List[str], embedding_model: RuntimeEmbeddingModel
|
|
||||||
) -> list[persistence_rag.Chunk]:
|
|
||||||
# save chunk to db
|
|
||||||
chunk_entities: list[persistence_rag.Chunk] = []
|
|
||||||
chunk_ids: list[str] = []
|
|
||||||
|
|
||||||
for chunk_text in chunks:
|
|
||||||
chunk_uuid = str(uuid.uuid4())
|
|
||||||
chunk_ids.append(chunk_uuid)
|
|
||||||
chunk_entity = persistence_rag.Chunk(uuid=chunk_uuid, file_id=file_id, text=chunk_text)
|
|
||||||
chunk_entities.append(chunk_entity)
|
|
||||||
|
|
||||||
chunk_dicts = [
|
|
||||||
self.ap.persistence_mgr.serialize_model(persistence_rag.Chunk, chunk) for chunk in chunk_entities
|
|
||||||
]
|
|
||||||
|
|
||||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.Chunk).values(chunk_dicts))
|
|
||||||
|
|
||||||
# get embeddings (batch size limit: 64 for OpenAI)
|
|
||||||
MAX_BATCH_SIZE = 64
|
|
||||||
embeddings_list: list[list[float]] = []
|
|
||||||
|
|
||||||
for i in range(0, len(chunks), MAX_BATCH_SIZE):
|
|
||||||
batch = chunks[i : i + MAX_BATCH_SIZE]
|
|
||||||
batch_embeddings = await embedding_model.provider.invoke_embedding(
|
|
||||||
model=embedding_model,
|
|
||||||
input_text=batch,
|
|
||||||
extra_args={}, # TODO: add extra args
|
|
||||||
knowledge_base_id=kb_id,
|
|
||||||
call_type='embedding',
|
|
||||||
)
|
|
||||||
embeddings_list.extend(batch_embeddings)
|
|
||||||
|
|
||||||
# save embeddings to vdb
|
|
||||||
await self.ap.vector_db_mgr.vector_db.add_embeddings(kb_id, chunk_ids, embeddings_list, chunk_dicts)
|
|
||||||
|
|
||||||
self.ap.logger.info(f'Successfully saved {len(chunk_entities)} embeddings to Knowledge Base.')
|
|
||||||
|
|
||||||
return chunk_entities
|
|
||||||
@@ -1,291 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import PyPDF2
|
|
||||||
import io
|
|
||||||
from docx import Document
|
|
||||||
import chardet
|
|
||||||
from typing import Union, Callable, Any
|
|
||||||
import markdown
|
|
||||||
from bs4 import BeautifulSoup
|
|
||||||
import re
|
|
||||||
import asyncio # Import asyncio for async operations
|
|
||||||
from langbot.pkg.core import app
|
|
||||||
|
|
||||||
|
|
||||||
class FileParser:
    """
    Extract plain text from stored documents of several formats.

    Parsers are looked up by extension as ``_parse_<extension>``. Live parsers
    exist for TXT, PDF, DOCX, Markdown (``md``) and HTML. ``doc`` is recognised
    only so that a clear "convert to .docx" error can be reported. The XLSX and
    CSV parsers are kept below as commented-out code, so those extensions are
    currently reported as unsupported.

    File bytes are loaded through ``ap.storage_mgr.storage_provider``; the
    format-specific parsing of PDF, DOCX, Markdown and HTML is handed to
    ``asyncio.to_thread`` so it does not block the event loop.
    """

    def __init__(self, ap: app.Application):
        self.ap = ap

    async def _run_sync(self, sync_func: Callable, *args: Any, **kwargs: Any) -> Any:
        """
        Run a synchronous callable in a worker thread and return its result.

        Any exception is logged and re-raised unchanged.
        """
        try:
            return await asyncio.to_thread(sync_func, *args, **kwargs)
        except Exception as e:
            # Not every callable (e.g. functools.partial) carries __name__.
            func_name = getattr(sync_func, '__name__', repr(sync_func))
            self.ap.logger.error(f'Error running synchronous function {func_name}: {e}')
            raise

    async def parse(self, file_name: str, extension: str) -> Union[str, None]:
        """
        Parses the file based on its extension and returns the extracted text content.
        This is the main asynchronous entry point for parsing.

        Args:
            file_name (str): The name of the file to be parsed, get from ap.storage_mgr
            extension (str): The file extension selecting the parser, e.g. ``pdf``.
                Matching is case-insensitive and a leading dot is ignored.

        Returns:
            Union[str, None]: The extracted text content as a single string, or None if
            the extension is unsupported or parsing fails (the error is logged).
        """
        # Accept both 'pdf' and '.PDF' style extensions.
        file_extension = extension.strip().lstrip('.').lower()
        parser_method = getattr(self, f'_parse_{file_extension}', None)

        if parser_method is None:
            self.ap.logger.error(f'Unsupported file format: {file_extension} for file {file_name}')
            return None

        try:
            return await parser_method(file_name)
        except Exception as e:
            self.ap.logger.error(f'Failed to parse {file_extension} file {file_name}: {e}')
            return None

    # --- Helper for reading files with encoding detection ---
    @staticmethod
    def _decode_with_fallback(data: bytes, encoding: Union[str, None]) -> str:
        """
        Decode ``data`` with ``encoding``, dropping undecodable bytes.

        Falls back to UTF-8 when ``encoding`` is empty or names a codec that
        Python does not provide (``bytes.decode`` raises ``LookupError`` then).
        """
        try:
            return data.decode(encoding or 'utf-8', errors='ignore')
        except LookupError:
            return data.decode('utf-8', errors='ignore')

    async def _read_file_content(self, file_name: str) -> str:
        """
        Load a stored file and decode it to text using chardet's detected encoding.

        UTF-8 is used when no encoding is detected or the detected one is
        unknown to Python; undecodable bytes are ignored.
        """
        file_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)

        detected = chardet.detect(file_bytes)

        return self._decode_with_fallback(file_bytes, detected['encoding'])

    # --- Specific Parser Methods ---

    async def _parse_txt(self, file_name: str) -> str:
        """Parses a TXT file and returns its content."""
        self.ap.logger.info(f'Parsing TXT file: {file_name}')
        return await self._read_file_content(file_name)

    async def _parse_pdf(self, file_name: str) -> str:
        """Parses a PDF file and returns its text content, one page per line group."""
        self.ap.logger.info(f'Parsing PDF file: {file_name}')

        pdf_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)

        def _parse_pdf_sync():
            pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
            # Pages for which no text is extracted are skipped.
            page_texts = (page.extract_text() for page in pdf_reader.pages)
            return '\n'.join(text for text in page_texts if text)

        return await self._run_sync(_parse_pdf_sync)

    async def _parse_docx(self, file_name: str) -> str:
        """Parses a DOCX file and returns the text of its non-blank paragraphs."""
        self.ap.logger.info(f'Parsing DOCX file: {file_name}')

        docx_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)

        def _parse_docx_sync():
            doc = Document(io.BytesIO(docx_bytes))
            text_content = [paragraph.text for paragraph in doc.paragraphs if paragraph.text.strip()]
            return '\n'.join(text_content)

        return await self._run_sync(_parse_docx_sync)

    async def _parse_doc(self, file_name: str) -> str:
        """Handles .doc files, explicitly stating lack of direct support."""
        self.ap.logger.warning(f'Direct .doc parsing is not supported for {file_name}. Please convert to .docx first.')
        raise NotImplementedError('Direct .doc parsing not supported. Please convert to .docx first.')

    # NOTE: the XLSX and CSV parsers below are disabled; they are kept verbatim.

    # async def _parse_xlsx(self, file_name: str) -> str:
    #     """Parses an XLSX file, returning text from all sheets."""
    #     self.ap.logger.info(f'Parsing XLSX file: {file_name}')

    #     xlsx_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)

    #     def _parse_xlsx_sync():
    #         excel_file = pd.ExcelFile(io.BytesIO(xlsx_bytes))
    #         all_sheet_content = []
    #         for sheet_name in excel_file.sheet_names:
    #             df = pd.read_excel(io.BytesIO(xlsx_bytes), sheet_name=sheet_name)
    #             sheet_text = f'--- Sheet: {sheet_name} ---\n{df.to_string(index=False)}\n'
    #             all_sheet_content.append(sheet_text)
    #         return '\n'.join(all_sheet_content)

    #     return await self._run_sync(_parse_xlsx_sync)

    # async def _parse_csv(self, file_name: str) -> str:
    #     """Parses a CSV file and returns its content as a string."""
    #     self.ap.logger.info(f'Parsing CSV file: {file_name}')

    #     csv_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)

    #     def _parse_csv_sync():
    #         # pd.read_csv can often detect encoding, but explicit detection is safer
    #         # raw_data = self._read_file_content(
    #         #     file_name, mode='rb'
    #         # )  # Note: this will need to be await outside this sync function
    #         # _ = raw_data
    #         # For simplicity, we'll let pandas handle encoding internally after a raw read.
    #         # A more robust solution might pass encoding directly to pd.read_csv after detection.
    #         detected = chardet.detect(io.BytesIO(csv_bytes))
    #         encoding = detected['encoding'] or 'utf-8'
    #         df = pd.read_csv(io.BytesIO(csv_bytes), encoding=encoding)
    #         return df.to_string(index=False)

    #     return await self._run_sync(_parse_csv_sync)

    async def _parse_md(self, file_name: str) -> str:
        """Parses a Markdown file, converting it to structured plain text."""
        self.ap.logger.info(f'Parsing Markdown file: {file_name}')

        md_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)

        def _parse_markdown_sync():
            md_content = md_bytes.decode('utf-8', errors='ignore')
            # Render to HTML first, then walk the top-level elements.
            html_content = markdown.markdown(
                md_content, extensions=['extra', 'codehilite', 'tables', 'toc', 'fenced_code']
            )
            soup = BeautifulSoup(html_content, 'html.parser')
            text_parts = []
            for element in soup.children:
                if element.name in ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']:
                    # Re-emit headings with as many '#' as the heading level.
                    level = int(element.name[1])
                    text_parts.append('#' * level + ' ' + element.get_text().strip())
                elif element.name == 'p':
                    text = element.get_text().strip()
                    if text:
                        text_parts.append(text)
                elif element.name in ['ul', 'ol']:
                    for li in element.find_all('li'):
                        text_parts.append(f'* {li.get_text().strip()}')
                elif element.name == 'pre':
                    code_block = element.get_text().strip()
                    if code_block:
                        text_parts.append(f'```\n{code_block}\n```')
                elif element.name == 'table':
                    table_str = self._extract_table_to_markdown_sync(element)  # Call sync helper
                    if table_str:
                        text_parts.append(table_str)
                elif element.name:
                    # Any other tag: flatten to its text. Bare strings have no name and are skipped.
                    text = element.get_text(separator=' ', strip=True)
                    if text:
                        text_parts.append(text)
            # Collapse runs of blank/whitespace-only lines into a single blank line.
            cleaned_text = re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_parts))
            return cleaned_text.strip()

        return await self._run_sync(_parse_markdown_sync)

    async def _parse_html(self, file_name: str) -> str:
        """Parses an HTML file, extracting structured plain text."""
        self.ap.logger.info(f'Parsing HTML file: {file_name}')

        html_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)

        def _parse_html_sync():
            html_content = html_bytes.decode('utf-8', errors='ignore')
            soup = BeautifulSoup(html_content, 'html.parser')
            # Script and style contents are not document text.
            for script_or_style in soup(['script', 'style']):
                script_or_style.decompose()
            text_parts = []
            # Walk <body> when present, otherwise the document's top-level nodes.
            for element in soup.body.children if soup.body else soup.children:
                if element.name in ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']:
                    level = int(element.name[1])
                    text_parts.append('#' * level + ' ' + element.get_text().strip())
                elif element.name == 'p':
                    text = element.get_text().strip()
                    if text:
                        text_parts.append(text)
                elif element.name in ['ul', 'ol']:
                    for li in element.find_all('li'):
                        text = li.get_text().strip()
                        if text:
                            text_parts.append(f'* {text}')
                elif element.name == 'table':
                    table_str = self._extract_table_to_markdown_sync(element)  # Call sync helper
                    if table_str:
                        text_parts.append(table_str)
                elif element.name:
                    text = element.get_text(separator=' ', strip=True)
                    if text:
                        text_parts.append(text)
            # Collapse runs of blank/whitespace-only lines into a single blank line.
            cleaned_text = re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_parts))
            return cleaned_text.strip()

        return await self._run_sync(_parse_html_sync)

    def _add_toc_items_sync(self, toc_list: list, text_content: list, level: int):
        """
        Recursively adds TOC items to text_content (synchronous helper).

        Each item is either an object with a ``title`` attribute or a
        ``(chapter, subchapters)`` tuple; nested entries are indented one
        step further per level. ``text_content`` is appended to in place.
        """
        indent = ' ' * level
        for item in toc_list:
            if isinstance(item, tuple):
                chapter, subchapters = item
                text_content.append(f'{indent}- {chapter.title}')
                self._add_toc_items_sync(subchapters, text_content, level + 1)
            else:
                text_content.append(f'{indent}- {item.title}')

    def _extract_table_to_markdown_sync(self, table_element: BeautifulSoup) -> str:
        """
        Helper to convert a BeautifulSoup table element into a Markdown table string (synchronous).

        All ``th`` cells form the header row; every ``tr`` that contains at
        least one ``td`` becomes a data row. Returns ``''`` for an empty table.
        """
        headers = [th.get_text().strip() for th in table_element.find_all('th')]
        rows = []
        for tr in table_element.find_all('tr'):
            cells = [td.get_text().strip() for td in tr.find_all('td')]
            if cells:
                rows.append(cells)

        if not headers and not rows:
            return ''

        table_lines = []
        if headers:
            table_lines.append(' | '.join(headers))
            table_lines.append(' | '.join(['---'] * len(headers)))

        for row_cells in rows:
            # Short rows are padded to the header width; longer rows are left as-is.
            padded_cells = row_cells + [''] * (len(headers) - len(row_cells)) if headers else row_cells
            table_lines.append(' | '.join(padded_cells))

        return '\n'.join(table_lines)
|
|
||||||
@@ -1,53 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from . import base_service
|
|
||||||
from ....core import app
|
|
||||||
from ....provider.modelmgr.requester import RuntimeEmbeddingModel
|
|
||||||
from langbot_plugin.api.entities.builtin.rag import context as rag_context
|
|
||||||
from langbot_plugin.api.entities.builtin.provider.message import ContentElement
|
|
||||||
|
|
||||||
|
|
||||||
class Retriever(base_service.BaseService):
    """Embed a query and fetch its nearest stored chunks from a knowledge base's vector store."""

    def __init__(self, ap: app.Application):
        super().__init__()
        self.ap = ap

    async def retrieve(
        self, kb_id: str, query: str, embedding_model: RuntimeEmbeddingModel, k: int = 5
    ) -> list[rag_context.RetrievalResultEntry]:
        """Return up to ``k`` retrieval entries for ``query`` in knowledge base ``kb_id``.

        The query is embedded with ``embedding_model`` and the first returned
        embedding is used for the vector search. An empty list is returned
        when the search yields no ids.
        """
        self.ap.logger.info(
            f"Retrieving for query: '{query[:10]}' with k={k} using {embedding_model.model_entity.uuid}"
        )

        # The provider is called with a one-element batch, so index 0 is the query's embedding.
        embeddings = await embedding_model.provider.invoke_embedding(
            model=embedding_model,
            input_text=[query],
            extra_args={},  # TODO: add extra args
            knowledge_base_id=kb_id,
            query_text=query,
            call_type='retrieve',
        )

        search_response = await self.ap.vector_db_mgr.vector_db.search(kb_id, embeddings[0], k)

        # 'ids' shape mirrors the Chroma-style response contract for compatibility:
        # each field is a list with one inner list for the single query.
        hit_ids = search_response.get('ids', [[]])[0]
        hit_distances = search_response.get('distances', [[]])[0]
        hit_metadatas = search_response.get('metadatas', [[]])[0]

        if not hit_ids:
            self.ap.logger.info('No relevant chunks found in vector database.')
            return []

        return [
            rag_context.RetrievalResultEntry(
                id=hit_id,
                content=[ContentElement.from_text(hit_metadatas[pos].get('text', ''))],
                metadata=hit_metadatas[pos],
                distance=hit_distances[pos],
            )
            for pos, hit_id in enumerate(hit_ids)
        ]
|
|
||||||
1
src/langbot/pkg/rag/service/__init__.py
Normal file
1
src/langbot/pkg/rag/service/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .runtime import RAGRuntimeService as RAGRuntimeService
|
||||||
116
src/langbot/pkg/rag/service/runtime.py
Normal file
116
src/langbot/pkg/rag/service/runtime.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import posixpath
|
||||||
|
from typing import Any
|
||||||
|
from langbot.pkg.core import app
|
||||||
|
|
||||||
|
|
||||||
|
class RAGRuntimeService:
    """Service to handle RAG-related requests from plugins (Runtime).

    This service acts as the bridge between plugin RPC requests and
    LangBot's infrastructure (embedding models, vector databases, file storage).
    """

    def __init__(self, ap: app.Application):
        self.ap = ap

    async def vector_upsert(
        self,
        collection_id: str,
        vectors: list[list[float]],
        ids: list[str],
        metadata: list[dict[str, Any]] | None = None,
        documents: list[str] | None = None,
    ) -> None:
        """Handle VECTOR_UPSERT action.

        Args:
            collection_id: The collection to write to.
            vectors: The vectors to store.
            ids: The ids passed alongside ``vectors``.
            metadata: Per-vector metadata; when missing or empty, one empty
                dict per vector is sent instead.
            documents: Optional documents forwarded unchanged.
        """
        metadatas = metadata if metadata else [{} for _ in vectors]
        await self.ap.vector_db_mgr.upsert(
            collection_name=collection_id,
            vectors=vectors,
            ids=ids,
            metadata=metadatas,
            documents=documents,
        )

    async def vector_search(
        self,
        collection_id: str,
        query_vector: list[float],
        top_k: int,
        filters: dict[str, Any] | None = None,
        search_type: str = 'vector',
        query_text: str = '',
        vector_weight: float | None = None,
    ) -> list[dict[str, Any]]:
        """Handle VECTOR_SEARCH action.

        All arguments are forwarded to the vector database manager's
        ``search`` (``top_k`` as ``limit``, ``filters`` as ``filter``) and its
        result is returned unchanged.
        """
        return await self.ap.vector_db_mgr.search(
            collection_name=collection_id,
            query_vector=query_vector,
            limit=top_k,
            filter=filters,
            search_type=search_type,
            query_text=query_text,
            vector_weight=vector_weight,
        )

    async def vector_delete(
        self, collection_id: str, file_ids: list[str] | None = None, filters: dict[str, Any] | None = None
    ) -> int:
        """Handle VECTOR_DELETE action.

        Deletes vectors associated with the given file IDs from the collection.
        Each file_id corresponds to a document whose vectors will be removed.

        Args:
            collection_id: The collection to delete from.
            file_ids: File IDs whose associated vectors should be deleted.
                Each file_id maps to a set of vectors stored with that file_id
                in their metadata.
            filters: Metadata filter for deletion. Only used when ``file_ids``
                is missing or empty; it is handed to the vector database
                manager's ``delete_by_filter``.

        Returns:
            For ``file_ids``: the number of file IDs requested (not the number
            of vectors removed). For ``filters``: whatever ``delete_by_filter``
            returns. ``0`` when neither is given.
        """
        count = 0
        if file_ids:
            await self.ap.vector_db_mgr.delete_by_file_id(collection_name=collection_id, file_ids=file_ids)
            count = len(file_ids)
        elif filters:
            count = await self.ap.vector_db_mgr.delete_by_filter(collection_name=collection_id, filter=filters)
        return count

    async def vector_list(
        self,
        collection_id: str,
        filters: dict[str, Any] | None = None,
        limit: int = 20,
        offset: int = 0,
    ) -> tuple[list[dict[str, Any]], int]:
        """Handle VECTOR_LIST action.

        Args:
            collection_id: The collection to list from.
            filters: Optional metadata filters.
            limit: Maximum number of items to return.
            offset: Number of items to skip.

        Returns:
            Tuple of (items, total).
        """
        return await self.ap.vector_db_mgr.list_by_filter(
            collection_name=collection_id,
            filter=filters,
            limit=limit,
            offset=offset,
        )

    async def get_file_stream(self, storage_path: str) -> bytes:
        """Handle GET_KNOWLEDEGE_FILE_STREAM action.

        Uses the storage manager abstraction to load file content,
        regardless of the underlying storage provider.

        Args:
            storage_path: Relative storage key of the file.

        Returns:
            The file content, or ``b''`` when the provider returns nothing.

        Raises:
            ValueError: If ``storage_path`` is absolute, carries a Windows
                drive/UNC prefix, or contains a ``..`` component under either
                ``/`` or ``\\`` as separator.
        """
        import ntpath  # stdlib; imported locally, only needed for the drive-prefix check

        # Validate storage_path to prevent path traversal.
        normalized = posixpath.normpath(storage_path)
        # The same path with backslashes read as separators, so that
        # '..\\..\\x' cannot slip past a '/'-only check.
        backslash_aware = posixpath.normpath(storage_path.replace('\\', '/'))
        if (
            normalized.startswith('/')
            or '..' in normalized.split('/')
            or backslash_aware.startswith('/')
            or '..' in backslash_aware.split('/')
            or ntpath.splitdrive(storage_path)[0]
        ):
            raise ValueError('Invalid storage path')
        content_bytes = await self.ap.storage_mgr.storage_provider.load(normalized)
        return content_bytes or b''
|
||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from ..core import app
|
from ..core import app
|
||||||
from . import provider
|
from . import provider
|
||||||
from .providers import localstorage, s3storage
|
from .providers import localstorage
|
||||||
|
|
||||||
|
|
||||||
class StorageMgr:
|
class StorageMgr:
|
||||||
@@ -21,6 +21,8 @@ class StorageMgr:
|
|||||||
storage_type = storage_config.get('use', 'local')
|
storage_type = storage_config.get('use', 'local')
|
||||||
|
|
||||||
if storage_type == 's3':
|
if storage_type == 's3':
|
||||||
|
from .providers import s3storage
|
||||||
|
|
||||||
self.storage_provider = s3storage.S3StorageProvider(self.ap)
|
self.storage_provider = s3storage.S3StorageProvider(self.ap)
|
||||||
self.ap.logger.info('Initialized S3 storage backend.')
|
self.ap.logger.info('Initialized S3 storage backend.')
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -43,6 +43,13 @@ class StorageProvider(abc.ABC):
|
|||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
async def size(self, key: str) -> int:
    """Return the size of the object stored under ``key``.

    Abstract hook with no behaviour of its own; each storage provider
    supplies the actual lookup.
    """
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def delete_dir_recursive(
|
async def delete_dir_recursive(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -47,6 +47,12 @@ class LocalStorageProvider(provider.StorageProvider):
|
|||||||
):
|
):
|
||||||
os.remove(os.path.join(LOCAL_STORAGE_PATH, f'{key}'))
|
os.remove(os.path.join(LOCAL_STORAGE_PATH, f'{key}'))
|
||||||
|
|
||||||
|
async def size(self, key: str) -> int:
    """Return the size in bytes of the file stored under ``key`` in the local storage directory."""
    target_path = os.path.join(LOCAL_STORAGE_PATH, f'{key}')
    return os.path.getsize(target_path)
|
||||||
|
|
||||||
async def delete_dir_recursive(
|
async def delete_dir_recursive(
|
||||||
self,
|
self,
|
||||||
dir_path: str,
|
dir_path: str,
|
||||||
|
|||||||
@@ -117,6 +117,21 @@ class S3StorageProvider(provider.StorageProvider):
|
|||||||
self.ap.logger.error(f'Failed to delete from S3: {e}')
|
self.ap.logger.error(f'Failed to delete from S3: {e}')
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def size(self, key: str) -> int:
    """Get object size from S3 without downloading it.

    Reads ``ContentLength`` from a ``head_object`` call on the configured
    bucket. Any failure is logged and re-raised unchanged.
    """
    try:
        head = self.s3_client.head_object(Bucket=self.bucket_name, Key=key)
        content_length = head['ContentLength']
    except Exception as e:
        self.ap.logger.error(f'Failed to get size from S3: {e}')
        raise
    return content_length
|
||||||
|
|
||||||
async def delete_dir_recursive(
|
async def delete_dir_recursive(
|
||||||
self,
|
self,
|
||||||
dir_path: str,
|
dir_path: str,
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ class TelemetryManager:
|
|||||||
except Exception:
|
except Exception:
|
||||||
sanitized['query_id'] = str(sanitized.get('query_id', ''))
|
sanitized['query_id'] = str(sanitized.get('query_id', ''))
|
||||||
|
|
||||||
for sfield in ('adapter', 'runner', 'model_name', 'version', 'error', 'timestamp'):
|
for sfield in ('adapter', 'runner', 'runner_category', 'model_name', 'version', 'error', 'timestamp'):
|
||||||
v = sanitized.get(sfield)
|
v = sanitized.get(sfield)
|
||||||
sanitized[sfield] = '' if v is None else str(v)
|
sanitized[sfield] = '' if v is None else str(v)
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import langbot
|
|||||||
|
|
||||||
semantic_version = f'v{langbot.__version__}'
|
semantic_version = f'v{langbot.__version__}'
|
||||||
|
|
||||||
required_database_version = 19
|
required_database_version = 24
|
||||||
"""Tag the version of the database schema, used to check if the database needs to be migrated"""
|
"""Tag the version of the database schema, used to check if the database needs to be migrated"""
|
||||||
|
|
||||||
debug_mode = False
|
debug_mode = False
|
||||||
|
|||||||
43
src/langbot/pkg/utils/httpclient.py
Normal file
43
src/langbot/pkg/utils/httpclient.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
"""Shared aiohttp.ClientSession to avoid repeated SSL context creation.
|
||||||
|
|
||||||
|
Each call to `aiohttp.ClientSession()` creates a new `TCPConnector` which in turn
|
||||||
|
creates a new `ssl.SSLContext` and loads all system root certificates. This is
|
||||||
|
extremely expensive in both CPU and memory (~270MB total allocations observed via
|
||||||
|
memray profiling).
|
||||||
|
|
||||||
|
This module provides a shared session pool so that all HTTP client code in LangBot
|
||||||
|
reuses the same underlying SSL context and connection pool.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
_sessions: dict[str, aiohttp.ClientSession] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_session(*, trust_env: bool = False) -> aiohttp.ClientSession:
    """Return the shared aiohttp.ClientSession for the given ``trust_env`` setting.

    One session is cached per ``trust_env`` value. A new one is created and
    cached when none exists yet or the cached one has been closed.

    Args:
        trust_env: Whether to trust environment variables for proxy settings.

    Returns:
        A shared aiohttp.ClientSession instance.
    """
    key = f'trust_env={trust_env}'

    cached = _sessions.get(key)
    if cached is not None and not cached.closed:
        return cached

    fresh = aiohttp.ClientSession(trust_env=trust_env)
    _sessions[key] = fresh
    return fresh
|
||||||
|
|
||||||
|
|
||||||
|
async def close_all():
    """Close all shared sessions. Call on application shutdown.

    The registry is snapshotted and emptied before the first ``await``:
    iterating ``_sessions`` directly across an ``await`` raises
    ``RuntimeError`` if another task registers a session in the meantime.
    A session registered while this runs stays in the registry and is not
    closed by this call.
    """
    pending = list(_sessions.values())
    _sessions.clear()
    for session in pending:
        if not session.closed:
            await session.close()
|
||||||
@@ -5,6 +5,8 @@ from urllib.parse import urlparse, parse_qs
|
|||||||
import ssl
|
import ssl
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
|
from langbot.pkg.utils import httpclient
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
@@ -47,53 +49,54 @@ async def get_gewechat_image_base64(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
session = httpclient.get_session()
|
||||||
# 获取图片下载链接
|
# 获取图片下载链接
|
||||||
try:
|
try:
|
||||||
async with session.post(
|
async with session.post(
|
||||||
f'{gewechat_url}/v2/api/message/downloadImage',
|
f'{gewechat_url}/v2/api/message/downloadImage',
|
||||||
headers=headers,
|
headers=headers,
|
||||||
json={'appId': app_id, 'type': image_type, 'xml': xml_content},
|
json={'appId': app_id, 'type': image_type, 'xml': xml_content},
|
||||||
) as response:
|
timeout=timeout,
|
||||||
if response.status != 200:
|
) as response:
|
||||||
# print(response)
|
if response.status != 200:
|
||||||
raise Exception(f'获取gewechat图片下载失败: {await response.text()}')
|
# print(response)
|
||||||
|
raise Exception(f'获取gewechat图片下载失败: {await response.text()}')
|
||||||
|
|
||||||
resp_data = await response.json()
|
resp_data = await response.json()
|
||||||
if resp_data.get('ret') != 200:
|
if resp_data.get('ret') != 200:
|
||||||
raise Exception(f'获取gewechat图片下载链接失败: {resp_data}')
|
raise Exception(f'获取gewechat图片下载链接失败: {resp_data}')
|
||||||
|
|
||||||
file_url = resp_data['data']['fileUrl']
|
file_url = resp_data['data']['fileUrl']
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
raise Exception('获取图片下载链接超时')
|
raise Exception('获取图片下载链接超时')
|
||||||
except aiohttp.ClientError as e:
|
except aiohttp.ClientError as e:
|
||||||
raise Exception(f'获取图片下载链接网络错误: {str(e)}')
|
raise Exception(f'获取图片下载链接网络错误: {str(e)}')
|
||||||
|
|
||||||
# 解析原始URL并替换端口
|
# 解析原始URL并替换端口
|
||||||
base_url = gewechat_file_url
|
base_url = gewechat_file_url
|
||||||
download_url = f'{base_url}/download/{file_url}'
|
download_url = f'{base_url}/download/{file_url}'
|
||||||
|
|
||||||
# 下载图片
|
# 下载图片
|
||||||
try:
|
try:
|
||||||
async with session.get(download_url) as img_response:
|
async with session.get(download_url) as img_response:
|
||||||
if img_response.status != 200:
|
if img_response.status != 200:
|
||||||
raise Exception(f'下载图片失败: {await img_response.text()}, URL: {download_url}')
|
raise Exception(f'下载图片失败: {await img_response.text()}, URL: {download_url}')
|
||||||
|
|
||||||
image_data = await img_response.read()
|
image_data = await img_response.read()
|
||||||
|
|
||||||
content_type = img_response.headers.get('Content-Type', '')
|
content_type = img_response.headers.get('Content-Type', '')
|
||||||
if content_type:
|
if content_type:
|
||||||
image_format = content_type.split('/')[-1]
|
image_format = content_type.split('/')[-1]
|
||||||
else:
|
else:
|
||||||
image_format = file_url.split('.')[-1]
|
image_format = file_url.split('.')[-1]
|
||||||
|
|
||||||
base64_str = base64.b64encode(image_data).decode('utf-8')
|
base64_str = base64.b64encode(image_data).decode('utf-8')
|
||||||
|
|
||||||
return base64_str, image_format
|
return base64_str, image_format
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
raise Exception(f'下载图片超时, URL: {download_url}')
|
raise Exception(f'下载图片超时, URL: {download_url}')
|
||||||
except aiohttp.ClientError as e:
|
except aiohttp.ClientError as e:
|
||||||
raise Exception(f'下载图片网络错误: {str(e)}, URL: {download_url}')
|
raise Exception(f'下载图片网络错误: {str(e)}, URL: {download_url}')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f'获取图片失败: {str(e)}') from e
|
raise Exception(f'获取图片失败: {str(e)}') from e
|
||||||
|
|
||||||
@@ -104,24 +107,24 @@ async def get_wecom_image_base64(pic_url: str) -> tuple[str, str]:
|
|||||||
:param pic_url: 企业微信图片URL
|
:param pic_url: 企业微信图片URL
|
||||||
:return: (base64_str, image_format)
|
:return: (base64_str, image_format)
|
||||||
"""
|
"""
|
||||||
async with aiohttp.ClientSession() as session:
|
session = httpclient.get_session()
|
||||||
async with session.get(pic_url) as response:
|
async with session.get(pic_url) as response:
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
raise Exception(f'Failed to download image: {response.status}')
|
raise Exception(f'Failed to download image: {response.status}')
|
||||||
|
|
||||||
# 读取图片数据
|
# 读取图片数据
|
||||||
image_data = await response.read()
|
image_data = await response.read()
|
||||||
|
|
||||||
# 获取图片格式
|
# 获取图片格式
|
||||||
content_type = response.headers.get('Content-Type', '')
|
content_type = response.headers.get('Content-Type', '')
|
||||||
image_format = content_type.split('/')[-1] # 例如 'image/jpeg' -> 'jpeg'
|
image_format = content_type.split('/')[-1] # 例如 'image/jpeg' -> 'jpeg'
|
||||||
|
|
||||||
# 转换为 base64
|
# 转换为 base64
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
image_base64 = base64.b64encode(image_data).decode('utf-8')
|
image_base64 = base64.b64encode(image_data).decode('utf-8')
|
||||||
|
|
||||||
return image_base64, image_format
|
return image_base64, image_format
|
||||||
|
|
||||||
|
|
||||||
async def get_qq_official_image_base64(pic_url: str, content_type: str) -> tuple[str, str]:
|
async def get_qq_official_image_base64(pic_url: str, content_type: str) -> tuple[str, str]:
|
||||||
@@ -152,21 +155,19 @@ async def get_qq_image_bytes(image_url: str, query: dict = {}) -> tuple[bytes, s
|
|||||||
ssl_context = ssl.create_default_context()
|
ssl_context = ssl.create_default_context()
|
||||||
ssl_context.check_hostname = False
|
ssl_context.check_hostname = False
|
||||||
ssl_context.verify_mode = ssl.CERT_NONE
|
ssl_context.verify_mode = ssl.CERT_NONE
|
||||||
async with aiohttp.ClientSession(trust_env=False) as session:
|
session = httpclient.get_session()
|
||||||
async with session.get(
|
async with session.get(image_url, params=query, ssl=ssl_context, timeout=aiohttp.ClientTimeout(total=30.0)) as resp:
|
||||||
image_url, params=query, ssl=ssl_context, timeout=aiohttp.ClientTimeout(total=30.0)
|
resp.raise_for_status()
|
||||||
) as resp:
|
file_bytes = await resp.read()
|
||||||
resp.raise_for_status()
|
content_type = resp.headers.get('Content-Type')
|
||||||
file_bytes = await resp.read()
|
if not content_type:
|
||||||
content_type = resp.headers.get('Content-Type')
|
image_format = 'jpeg'
|
||||||
if not content_type:
|
elif not content_type.startswith('image/'):
|
||||||
image_format = 'jpeg'
|
pil_img = PIL.Image.open(io.BytesIO(file_bytes))
|
||||||
elif not content_type.startswith('image/'):
|
image_format = pil_img.format.lower()
|
||||||
pil_img = PIL.Image.open(io.BytesIO(file_bytes))
|
else:
|
||||||
image_format = pil_img.format.lower()
|
image_format = content_type.split('/')[-1]
|
||||||
else:
|
return file_bytes, image_format
|
||||||
image_format = content_type.split('/')[-1]
|
|
||||||
return file_bytes, image_format
|
|
||||||
|
|
||||||
|
|
||||||
async def qq_image_url_to_base64(image_url: str) -> typing.Tuple[str, str]:
|
async def qq_image_url_to_base64(image_url: str) -> typing.Tuple[str, str]:
|
||||||
@@ -204,11 +205,11 @@ async def extract_b64_and_format(image_base64_data: str) -> typing.Tuple[str, st
|
|||||||
async def get_slack_image_to_base64(pic_url: str, bot_token: str):
    """Download a Slack-hosted file and return it as a ``data:`` URI.

    Args:
        pic_url: URL of the file to download.
        bot_token: Token sent in a ``Bearer`` Authorization header.

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``. The MIME type
        comes from the response's ``Content-Type`` header and falls back to
        ``application/octet-stream``.

    Raises:
        Whatever the HTTP client raises for network failures or for an HTTP
        error status (``aiohttp.ClientResponseError``, assuming the shared
        session is an aiohttp session like the one it replaced).
    """
    headers = {'Authorization': f'Bearer {bot_token}'}
    session = httpclient.get_session()
    async with session.get(pic_url, headers=headers) as resp:
        # Without this check a 4xx/5xx error page would be base64-encoded
        # and handed to the caller as if it were the image.
        resp.raise_for_status()
        mime_type = resp.headers.get('Content-Type', 'application/octet-stream')
        file_bytes = await resp.read()
    base64_str = base64.b64encode(file_bytes).decode('utf-8')
    return f'data:{mime_type};base64,{base64_str}'
|
||||||
|
|||||||
105
src/langbot/pkg/utils/runner.py
Normal file
105
src/langbot/pkg/utils/runner.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
|
||||||
|
class RunnerCategory:
    """String constants naming where a pipeline runner is hosted."""

    LOCAL = 'local'
    CLOUD = 'cloud'
    UNKNOWN = 'unknown'


# Host suffixes of known hosted runner services.
CLOUD_DOMAINS = [
    '.n8n.cloud',
    '.n8n.io',
    'api.dify.ai',
    'cloud.dify.ai',
    '.coze.com',
    '.coze.cn',
    'cloud.langflow.ai',
    '.langflow.org',
]

# Legacy host prefixes for loopback / RFC 1918 addresses.  get_runner_category()
# no longer prefix-matches against this list (a prefix test also accepts names
# such as '10.example.com'); the name is kept in case other modules import it.
LOCAL_PATTERNS = [
    'localhost',
    '127.0.0.1',
    '0.0.0.0',
    '192.168.',
    '10.',
    '172.16.',
    '172.17.',
    '172.18.',
    '172.19.',
    '172.20.',
    '172.21.',
    '172.22.',
    '172.23.',
    '172.24.',
    '172.25.',
    '172.26.',
    '172.27.',
    '172.28.',
    '172.29.',
    '172.30.',
    '172.31.',
]


def _is_local_host(host: str) -> bool:
    """Return True when *host* is a localhost name or a non-public IP address."""
    import ipaddress

    if host in ('localhost', 'localhost.localdomain') or host.endswith('.localhost'):
        return True
    try:
        address = ipaddress.ip_address(host)
    except ValueError:
        # Not an IP literal: an ordinary DNS name is never treated as local.
        return False
    return address.is_loopback or address.is_private or address.is_unspecified


def get_runner_category(runner_name: str, runner_url: str) -> str:
    """Classify a runner as local, cloud or unknown from its URL.

    Args:
        runner_name: Runner identifier. Accepted for interface symmetry with
            the other helpers; it does not influence the result.
        runner_url: Endpoint URL of the runner. A missing scheme is tolerated
            (``localhost:5678/hook``).

    Returns:
        ``RunnerCategory.UNKNOWN`` when the URL is empty, not a string, or has
        no parsable host; ``RunnerCategory.LOCAL`` for localhost names and
        loopback/private/unspecified IP addresses (IPv4 or IPv6);
        ``RunnerCategory.CLOUD`` for every other host.
    """
    if not runner_url or not isinstance(runner_url, str):
        return RunnerCategory.UNKNOWN

    try:
        host = urlparse(runner_url).hostname
        if not host and '://' not in runner_url:
            # 'localhost:5678/x' parses 'localhost' as the scheme; re-parse it
            # as a network location so the host is recognised.
            host = urlparse(f'//{runner_url}').hostname
    except ValueError:
        return RunnerCategory.UNKNOWN

    if not host:
        return RunnerCategory.UNKNOWN
    host = host.lower().rstrip('.')

    if _is_local_host(host):
        return RunnerCategory.LOCAL

    # Known hosted services are cloud by definition.
    if any(host.endswith(domain) for domain in CLOUD_DOMAINS):
        return RunnerCategory.CLOUD

    # Any other publicly addressable host is also reported as cloud, matching
    # the previous fall-through behaviour.
    return RunnerCategory.CLOUD
|
||||||
|
|
||||||
|
|
||||||
|
def get_runner_info(runner_name: str, runner_url: str) -> dict:
    """Bundle a runner's name, URL and hosting category into one dict.

    Returns:
        A dict with the keys ``'name'``, ``'url'`` and ``'category'``; the
        category is whatever ``get_runner_category`` reports for the pair.
    """
    category = get_runner_category(runner_name, runner_url)
    return dict(name=runner_name, url=runner_url, category=category)
|
||||||
|
|
||||||
|
|
||||||
|
def is_cloud_runner(runner_name: str, runner_url: str) -> bool:
    """Return True when the runner is categorised as ``RunnerCategory.CLOUD``."""
    category = get_runner_category(runner_name, runner_url)
    return category == RunnerCategory.CLOUD
|
||||||
|
|
||||||
|
|
||||||
|
def is_local_runner(runner_name: str, runner_url: str) -> bool:
    """Return True when the runner is categorised as ``RunnerCategory.LOCAL``."""
    category = get_runner_category(runner_name, runner_url)
    return category == RunnerCategory.LOCAL
|
||||||
|
|
||||||
|
|
||||||
|
def extract_runner_url(runner_name: str, runner, pipeline_config: dict | None) -> str | None:
    """Look up the endpoint URL configured for an external runner.

    Args:
        runner_name: Runner identifier, e.g. ``'dify-service-api'``.
        runner: Runner instance. Only checked for being truthy and for having
            a ``pipeline_config`` attribute; the configuration itself is read
            from *pipeline_config*.
        pipeline_config: Pipeline configuration; the URL is read from
            ``pipeline_config['ai'][runner_name][<url key>]``.

    Returns:
        The configured URL, or ``None`` when the runner is missing, the runner
        name has no known URL key, or the value is not configured. Sections
        that are present but set to ``None`` are treated as empty instead of
        raising ``AttributeError``.
    """
    if not runner or not hasattr(runner, 'pipeline_config'):
        return None

    # Each supported runner stores its endpoint under a different key.
    url_keys = {
        'dify-service-api': 'base-url',
        'n8n-service-api': 'webhook-url',
        'coze-api': 'api-base',
        'langflow-api': 'base-url',
    }
    url_key = url_keys.get(runner_name)
    if url_key is None:
        return None

    # `or {}` also covers keys that exist with an explicit null value.
    ai_config = (pipeline_config or {}).get('ai') or {}
    runner_section = ai_config.get(runner_name) or {}
    return runner_section.get(url_key)
|
||||||
|
|
||||||
|
|
||||||
|
def get_runner_category_from_runner(runner_name: str, runner, pipeline_config: dict | None) -> str:
    """Categorise a runner instance by the URL found in its pipeline config.

    Resolves the URL with ``extract_runner_url`` and passes it (possibly
    ``None``) to ``get_runner_category``.
    """
    url = extract_runner_url(runner_name, runner, pipeline_config)
    category = get_runner_category(runner_name, url)
    return category
|
||||||
77
src/langbot/pkg/vector/filter_utils.py
Normal file
77
src/langbot/pkg/vector/filter_utils.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
"""Shared utilities for metadata filter handling across VDB backends.
|
||||||
|
|
||||||
|
Canonical filter format (Chroma-style ``where`` syntax):
|
||||||
|
|
||||||
|
{"file_id": "abc"} # implicit $eq
|
||||||
|
{"file_id": {"$eq": "abc"}} # explicit $eq
|
||||||
|
{"created_at": {"$gte": 1700000000}} # comparison
|
||||||
|
{"file_type": {"$in": ["pdf", "docx"]}} # in-list
|
||||||
|
|
||||||
|
Multiple top-level keys are AND-ed. Supported operators:
|
||||||
|
``$eq``, ``$ne``, ``$gt``, ``$gte``, ``$lt``, ``$lte``, ``$in``, ``$nin``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
SUPPORTED_OPS = frozenset({'$eq', '$ne', '$gt', '$gte', '$lt', '$lte', '$in', '$nin'})

logger = logging.getLogger(__name__)


def normalize_filter(
    raw: dict[str, Any] | None,
) -> list[tuple[str, str, Any]]:
    """Parse a canonical filter dict into ``[(field, op, value)]`` triples.

    A bare value is an implicit ``$eq``; a dict value contributes one triple
    per operator it contains.

    Args:
        raw: Filter in the canonical (Chroma ``where``-style) format, or
            ``None``.

    Returns:
        The triples in the iteration order of *raw*; an empty list when *raw*
        is ``None`` or empty.

    Raises:
        ValueError: On an unsupported operator, an empty operator dict, or a
            ``$in`` / ``$nin`` operand that is not a list, tuple or set.
    """
    if not raw:
        return []

    triples: list[tuple[str, str, Any]] = []
    for field, condition in raw.items():
        if not isinstance(condition, dict):
            # Bare value -> implicit $eq
            triples.append((field, '$eq', condition))
            continue

        if not condition:
            # An empty operator dict would otherwise vanish from the result
            # and silently widen the filter (dangerous for delete-by-filter).
            raise ValueError(f'Empty filter condition for field: {field}')

        for op, value in condition.items():
            if op not in SUPPORTED_OPS:
                raise ValueError(f'Unsupported filter operator: {op}')
            if op in ('$in', '$nin') and not isinstance(value, (list, tuple, set, frozenset)):
                raise ValueError(f'Filter operator {op} requires a list of values for field: {field}')
            triples.append((field, op, value))
    return triples
|
||||||
|
|
||||||
|
|
||||||
|
def strip_unsupported_fields(
    triples: list[tuple[str, str, Any]],
    supported_fields: set[str],
    field_aliases: dict[str, str] | None = None,
) -> list[tuple[str, str, Any]]:
    """Keep only the filter triples the backend schema can evaluate.

    Args:
        triples: ``(field, op, value)`` triples to check.
        supported_fields: Field names the backend stores.
        field_aliases: Optional mapping from caller-facing names to backend
            names (e.g. ``{'uuid': 'chunk_uuid'}``); it is applied before the
            support check and the rewritten name is what gets returned.

    Returns:
        The surviving triples, in input order, with aliases resolved. Every
        dropped field is reported through a WARNING log record.
    """
    alias_map = field_aliases if field_aliases else {}
    retained: list[tuple[str, str, Any]] = []
    for name, operator, operand in triples:
        backend_name = alias_map.get(name, name)
        if backend_name not in supported_fields:
            logger.warning(
                'Filter field %r is not supported by this backend and will be ignored (supported: %s)',
                name,
                ', '.join(sorted(supported_fields)),
            )
            continue
        retained.append((backend_name, operator, operand))
    return retained
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from ..core import app
|
from ..core import app
|
||||||
from .vdb import VectorDatabase
|
from .vdb import VectorDatabase, SearchType
|
||||||
from .vdbs.chroma import ChromaVectorDatabase
|
from .vdbs.chroma import ChromaVectorDatabase
|
||||||
from .vdbs.qdrant import QdrantVectorDatabase
|
from .vdbs.qdrant import QdrantVectorDatabase
|
||||||
from .vdbs.seekdb import SeekDBVectorDatabase
|
from .vdbs.seekdb import SeekDBVectorDatabase
|
||||||
@@ -65,3 +65,111 @@ class VectorDBManager:
|
|||||||
else:
|
else:
|
||||||
self.vector_db = ChromaVectorDatabase(self.ap)
|
self.vector_db = ChromaVectorDatabase(self.ap)
|
||||||
self.ap.logger.warning('No vector database backend configured, defaulting to Chroma.')
|
self.ap.logger.warning('No vector database backend configured, defaulting to Chroma.')
|
||||||
|
|
||||||
|
def get_supported_search_types(self) -> list[str]:
    """List the search-type names the active backend supports.

    Falls back to vector search only when no backend has been set up.
    """
    backend = self.vector_db
    if backend is None:
        return [SearchType.VECTOR.value]
    names = []
    for search_type in backend.supported_search_types():
        names.append(search_type.value)
    return names
|
||||||
|
|
||||||
|
async def upsert(
    self,
    collection_name: str,
    vectors: list[list[float]],
    ids: list[str],
    metadata: list[dict] | None = None,
    documents: list[str] | None = None,
):
    """Proxy: write vectors to the backend through ``add_embeddings``.

    When *metadata* is missing or empty, one empty dict per vector is sent
    instead. *documents* is forwarded unchanged.
    """
    metadatas = metadata if metadata else [{} for _ in vectors]
    await self.vector_db.add_embeddings(
        collection=collection_name,
        ids=ids,
        embeddings_list=vectors,
        metadatas=metadatas,
        documents=documents,
    )
|
||||||
|
|
||||||
|
async def search(
    self,
    collection_name: str,
    query_vector: list[float],
    limit: int,
    filter: dict | None = None,
    search_type: str = 'vector',
    query_text: str = '',
    vector_weight: float | None = None,
) -> list[dict]:
    """Proxy: search the backend and flatten its result into row dicts.

    The backend answers in Chroma's column-major batch format::

        {'ids': [['id1']], 'distances': [[0.1]], 'metadatas': [[{}]]}

    Args:
        collection_name: Collection to search.
        query_vector: Query embedding.
        limit: Maximum number of results (passed to the backend as ``k``).
        filter: Optional metadata filter, forwarded unchanged.
        search_type: Search mode name, forwarded unchanged.
        query_text: Raw query text, forwarded unchanged.
        vector_weight: Hybrid-mode vector weight, forwarded unchanged.

    Returns:
        One dict per hit with the keys ``'id'``, ``'distance'`` and
        ``'metadata'``. A distance or metadata entry that is missing or
        ``None`` is replaced by ``0.0`` / ``{}`` so callers always receive a
        float and a dict.
    """
    results = await self.vector_db.search(
        collection=collection_name,
        query_embedding=query_vector,
        k=limit,
        search_type=search_type,
        query_text=query_text,
        filter=filter,
        vector_weight=vector_weight,
    )

    if not results or 'ids' not in results or not results['ids']:
        return []

    def first_batch(column):
        # One inner list per query is returned and a single query was issued;
        # already-flat columns pass through, absent columns become empty.
        if column and isinstance(column[0], list):
            return column[0]
        return column or []

    ids = first_batch(results['ids'])
    distances = first_batch(results.get('distances'))
    metadatas = first_batch(results.get('metadatas'))

    parsed_results = []
    for index, id_val in enumerate(ids):
        distance = distances[index] if index < len(distances) else None
        metadata = metadatas[index] if index < len(metadatas) else None
        parsed_results.append(
            {
                'id': id_val,
                'distance': 0.0 if distance is None else distance,
                'metadata': {} if metadata is None else metadata,
            }
        )
    return parsed_results
|
||||||
|
|
||||||
|
async def delete_by_file_id(self, collection_name: str, file_ids: list[str]):
    """Proxy: remove every vector belonging to the given file IDs.

    The backend's ``delete_by_file_id`` is awaited once per ID, in order.
    """
    pending = iter(file_ids)
    for current_id in pending:
        await self.vector_db.delete_by_file_id(collection_name, current_id)
|
||||||
|
|
||||||
|
async def delete_collection(self, collection_name: str):
    """Proxy: drop the whole collection on the backend."""
    backend = self.vector_db
    await backend.delete_collection(collection_name)
|
||||||
|
|
||||||
|
async def delete_by_filter(self, collection_name: str, filter: dict) -> int:
    """Proxy: delete the vectors matching a metadata filter.

    Returns:
        The count reported by the backend (best-effort; some backends
        return 0).
    """
    deleted = await self.vector_db.delete_by_filter(collection_name, filter)
    return deleted
|
||||||
|
|
||||||
|
async def list_by_filter(
    self,
    collection_name: str,
    filter: dict | None = None,
    limit: int = 20,
    offset: int = 0,
) -> tuple[list[dict], int]:
    """Proxy: page through the vectors matching a metadata filter.

    Returns:
        The backend's ``(items, total)`` tuple, unchanged.
    """
    page = await self.vector_db.list_by_filter(collection_name, filter, limit, offset)
    return page
|
||||||
|
|||||||
@@ -1,10 +1,28 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import abc
|
import abc
|
||||||
|
import enum
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class SearchType(str, enum.Enum):
    """Search modes a vector database backend may offer.

    Members are also ``str`` instances, so they compare equal to their plain
    string values.
    """

    # Similarity search on the query embedding.
    VECTOR = 'vector'
    # Search on the raw query text.
    FULL_TEXT = 'full_text'
    # Combination of the vector and full-text modes.
    HYBRID = 'hybrid'
|
||||||
|
|
||||||
|
|
||||||
class VectorDatabase(abc.ABC):
|
class VectorDatabase(abc.ABC):
|
||||||
|
@classmethod
def supported_search_types(cls) -> list[SearchType]:
    """Report which search types this backend can serve.

    The base implementation advertises vector search only; backends with
    full-text or hybrid support override it.
    """
    default_types = [SearchType.VECTOR]
    return default_types
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def add_embeddings(
|
async def add_embeddings(
|
||||||
self,
|
self,
|
||||||
@@ -12,14 +30,50 @@ class VectorDatabase(abc.ABC):
|
|||||||
ids: list[str],
|
ids: list[str],
|
||||||
embeddings_list: list[list[float]],
|
embeddings_list: list[list[float]],
|
||||||
metadatas: list[dict[str, Any]],
|
metadatas: list[dict[str, Any]],
|
||||||
documents: list[str],
|
documents: list[str] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add vector data to the specified collection."""
|
"""Add vector data to the specified collection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
collection: Collection name.
|
||||||
|
ids: Unique IDs for each vector.
|
||||||
|
embeddings_list: List of embedding vectors.
|
||||||
|
metadatas: List of metadata dicts.
|
||||||
|
documents: Optional raw text documents. Required for full-text
|
||||||
|
and hybrid search in backends that support them.
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
async def search(
    self,
    collection: str,
    query_embedding: np.ndarray,
    k: int = 5,
    search_type: str = 'vector',
    query_text: str = '',
    filter: dict[str, Any] | None = None,
    vector_weight: float | None = None,
) -> Dict[str, Any]:
    """Find the entries of a collection that best match a query.

    Args:
        collection: Collection name.
        query_embedding: Query vector for similarity search.
        k: Number of results to return.
        search_type: One of 'vector', 'full_text', 'hybrid'.
        query_text: Raw query text; used by full_text and hybrid search.
        filter: Optional metadata filter in Chroma-style ``where`` syntax.
            Top-level keys are AND-ed. Supported operators: ``$eq``,
            ``$ne``, ``$gt``, ``$gte``, ``$lt``, ``$lte``, ``$in``,
            ``$nin``. Examples::

                {"file_id": "abc"}
                {"created_at": {"$gte": 1700000000}}
                {"file_type": {"$in": ["pdf", "docx"]}}
        vector_weight: Weight of the vector ranking in hybrid mode
            (0.0–1.0). ``None`` keeps equal weights (backward compatible).
    """
    pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
@@ -27,6 +81,42 @@ class VectorDatabase(abc.ABC):
|
|||||||
"""Delete vectors from the specified collection by file_id."""
|
"""Delete vectors from the specified collection by file_id."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
async def delete_by_filter(self, collection: str, filter: dict[str, Any]) -> int:
    """Remove every vector whose metadata matches *filter*.

    Args:
        collection: Collection name.
        filter: Metadata filter dict in the canonical format described on
            ``search``.

    Returns:
        Number of deleted vectors. This is best-effort: a backend that
        cannot report an exact count may return 0.
    """
    pass
|
||||||
|
|
||||||
|
async def list_by_filter(
    self,
    collection: str,
    filter: dict[str, Any] | None = None,
    limit: int = 20,
    offset: int = 0,
) -> tuple[list[dict[str, Any]], int]:
    """Page through the vectors whose metadata matches *filter*.

    This default implementation reports nothing; backends override it.

    Args:
        collection: Collection name.
        filter: Optional metadata filter dict in canonical format.
        limit: Maximum number of items to return.
        offset: Number of items to skip.

    Returns:
        ``(items, total)``: *items* is a list of dicts with the keys 'id',
        'document' and 'metadata'; *total* is the best-effort number of
        matching vectors, or -1 when unknown.
    """
    no_items: list[dict[str, Any]] = []
    unknown_total = -1
    return no_items, unknown_total
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def get_or_create_collection(self, collection: str):
|
async def get_or_create_collection(self, collection: str):
|
||||||
"""Get or create collection."""
|
"""Get or create collection."""
|
||||||
|
|||||||
@@ -2,11 +2,14 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from chromadb import PersistentClient
|
from chromadb import PersistentClient
|
||||||
from langbot.pkg.vector.vdb import VectorDatabase
|
from langbot.pkg.vector.vdb import VectorDatabase, SearchType
|
||||||
from langbot.pkg.core import app
|
from langbot.pkg.core import app
|
||||||
import chromadb
|
import chromadb
|
||||||
import chromadb.errors
|
import chromadb.errors
|
||||||
|
|
||||||
|
# RRF smoothing constant (standard value from the literature)
|
||||||
|
_RRF_K = 60
|
||||||
|
|
||||||
|
|
||||||
class ChromaVectorDatabase(VectorDatabase):
|
class ChromaVectorDatabase(VectorDatabase):
|
||||||
def __init__(self, ap: app.Application, base_path: str = './data/chroma'):
|
def __init__(self, ap: app.Application, base_path: str = './data/chroma'):
|
||||||
@@ -14,6 +17,10 @@ class ChromaVectorDatabase(VectorDatabase):
|
|||||||
self.client = PersistentClient(path=base_path)
|
self.client = PersistentClient(path=base_path)
|
||||||
self._collections = {}
|
self._collections = {}
|
||||||
|
|
||||||
|
@classmethod
def supported_search_types(cls) -> list[SearchType]:
    """Advertise vector, full-text and hybrid search for the Chroma backend."""
    modes = [SearchType.VECTOR, SearchType.FULL_TEXT, SearchType.HYBRID]
    return modes
|
||||||
|
|
||||||
async def get_or_create_collection(self, collection: str) -> chromadb.Collection:
|
async def get_or_create_collection(self, collection: str) -> chromadb.Collection:
|
||||||
if collection not in self._collections:
|
if collection not in self._collections:
|
||||||
self._collections[collection] = await asyncio.to_thread(
|
self._collections[collection] = await asyncio.to_thread(
|
||||||
@@ -28,27 +35,247 @@ class ChromaVectorDatabase(VectorDatabase):
|
|||||||
ids: list[str],
|
ids: list[str],
|
||||||
embeddings_list: list[list[float]],
|
embeddings_list: list[list[float]],
|
||||||
metadatas: list[dict[str, Any]],
|
metadatas: list[dict[str, Any]],
|
||||||
|
documents: list[str] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
col = await self.get_or_create_collection(collection)
|
col = await self.get_or_create_collection(collection)
|
||||||
await asyncio.to_thread(col.add, embeddings=embeddings_list, ids=ids, metadatas=metadatas)
|
kwargs: dict[str, Any] = dict(embeddings=embeddings_list, ids=ids, metadatas=metadatas)
|
||||||
self.ap.logger.info(f"Added {len(ids)} embeddings to Chroma collection '{collection}'.")
|
if documents is not None:
|
||||||
|
kwargs['documents'] = documents
|
||||||
|
await asyncio.to_thread(col.upsert, **kwargs)
|
||||||
|
self.ap.logger.info(f"Upserted {len(ids)} embeddings to Chroma collection '{collection}'.")
|
||||||
|
|
||||||
async def search(
    self,
    collection: str,
    query_embedding: list[float],
    k: int = 5,
    search_type: str = 'vector',
    query_text: str = '',
    filter: dict[str, Any] | None = None,
    vector_weight: float | None = None,
) -> dict[str, Any]:
    """Search a Chroma collection in the requested mode.

    ``full_text`` and ``hybrid`` are dispatched to their dedicated helpers;
    every other *search_type* value is served by plain vector search.
    """
    col = await self.get_or_create_collection(collection)

    if search_type == SearchType.FULL_TEXT:
        return await self._full_text_search(col, collection, k, query_text, filter)
    if search_type == SearchType.HYBRID:
        return await self._hybrid_search(
            col, collection, query_embedding, k, query_text, filter, vector_weight=vector_weight
        )
    # Default: vector search
    return await self._vector_search(col, collection, query_embedding, k, filter)

async def _vector_search(
    self,
    col: chromadb.Collection,
    collection: str,
    query_embedding: list[float],
    k: int,
    filter: dict[str, Any] | None,
) -> dict[str, Any]:
    """Run an embedding similarity query and return Chroma's raw result.

    *filter*, when non-empty, is passed to Chroma as the ``where`` clause.
    *collection* is only used in the log message.
    """
    request: dict[str, Any] = {
        'query_embeddings': query_embedding,
        'n_results': k,
        'include': ['metadatas', 'distances', 'documents'],
    }
    if filter:
        request['where'] = filter
    results = await asyncio.to_thread(col.query, **request)
    hit_count = len(results.get('ids', [[]])[0])
    self.ap.logger.info(f"Chroma vector search in '{collection}' returned {hit_count} results.")
    return results
|
||||||
|
|
||||||
|
async def _full_text_search(
    self,
    col: chromadb.Collection,
    collection: str,
    k: int,
    query_text: str,
    filter: dict[str, Any] | None,
) -> dict[str, Any]:
    """Return up to *k* entries whose document text contains *query_text*.

    The result uses the same column-major shape as a Chroma query. All
    distances are 0.0: ``$contains`` is a boolean filter and yields no
    relevance score, so use hybrid mode or a downstream reranker when a
    ranking is needed. An empty *query_text* yields an empty result without
    querying Chroma.
    """
    if not query_text:
        return {'ids': [[]], 'metadatas': [[]], 'distances': [[]], 'documents': [[]]}

    request: dict[str, Any] = {
        'where_document': {'$contains': query_text},
        'include': ['metadatas', 'documents'],
        'limit': k,
    }
    if filter:
        request['where'] = filter
    found = await asyncio.to_thread(col.get, **request)

    # col.get returns flat lists; wrap each column in an outer list below.
    hit_ids = found.get('ids', [])
    hit_count = len(hit_ids)
    hit_metas = found.get('metadatas', []) or [None] * hit_count
    hit_docs = found.get('documents', []) or [None] * hit_count
    zero_distances = [0.0] * hit_count

    self.ap.logger.info(f"Chroma full-text search in '{collection}' returned {hit_count} results.")
    return {
        'ids': [hit_ids],
        'metadatas': [hit_metas],
        'distances': [zero_distances],
        'documents': [hit_docs],
    }
|
||||||
|
|
||||||
|
async def _hybrid_search(
    self,
    col: chromadb.Collection,
    collection: str,
    query_embedding: list[float],
    k: int,
    query_text: str,
    filter: dict[str, Any] | None,
    vector_weight: float | None = None,
) -> dict[str, Any]:
    """Fuse vector and full-text results with Reciprocal Rank Fusion.

    Args:
        col: Chroma collection to query.
        collection: Collection name, used in log messages.
        query_embedding: Query vector for the vector leg.
        k: Number of hits requested from each leg and returned after fusion.
        query_text: Text for the full-text leg; when empty the call degrades
            to a pure vector search.
        filter: Optional metadata filter applied to both legs.
        vector_weight: Weight of the vector ranking, clamped to 0.0–1.0; the
            text ranking gets the remainder. ``None`` weighs both equally.

    Returns:
        A Chroma-style column-major dict. ``distances`` are min-max
        normalised RRF scores (0.0 = best fused hit, 1.0 = worst); they are
        all 0.0 when every fused score is identical.
    """
    # Fall back to pure vector search when no text is provided
    if not query_text:
        return await self._vector_search(col, collection, query_embedding, k, filter)

    # Run both legs concurrently
    vector_results, text_results = await asyncio.gather(
        self._vector_search(col, collection, query_embedding, k, filter),
        self._full_text_search(col, collection, k, query_text, filter),
    )
    vector_ids = vector_results.get('ids', [[]])[0]
    text_ids = text_results.get('ids', [[]])[0]

    if not vector_ids and not text_ids:
        return {'ids': [[]], 'metadatas': [[]], 'distances': [[]], 'documents': [[]]}

    # RRF fusion
    weights = None
    if vector_weight is not None:
        # A weight outside 0..1 would give one ranking a negative weight and
        # invert its contribution, so clamp to the documented range.
        vector_weight = min(1.0, max(0.0, vector_weight))
        weights = [vector_weight, 1.0 - vector_weight]
    self.ap.logger.info(
        f"Chroma hybrid fusion config in '{collection}': "
        f'vector_weight={vector_weight}, weights={weights or [1.0, 1.0]}, '
        f'vector_hits={len(vector_ids)}, text_hits={len(text_ids)}'
    )
    fused = self._rrf_fuse([vector_ids, text_ids], k, weights=weights)
    if not fused:
        return {'ids': [[]], 'metadatas': [[]], 'distances': [[]], 'documents': [[]]}

    fused_ids = [doc_id for doc_id, _ in fused]

    # Fetch full metadata and documents for the fused hits
    fetched = await asyncio.to_thread(col.get, ids=fused_ids, include=['metadatas', 'documents'])

    # col.get returns rows in arbitrary order; index them by id so they can
    # be re-ordered to the fused ranking.  Column fallbacks are resolved once
    # here rather than on every loop iteration.
    fetched_ids = fetched.get('ids', [])
    fetched_metas = fetched.get('metadatas') or [None] * len(fetched_ids)
    fetched_docs = fetched.get('documents') or [None] * len(fetched_ids)
    fetched_map: dict[str, tuple] = {
        fid: (meta, doc) for fid, meta, doc in zip(fetched_ids, fetched_metas, fetched_docs)
    }

    # Raw RRF scores are tiny (about 0.016-0.033 with the smoothing constant
    # of 60), so a naive ``1 - score`` would squeeze every distance into a
    # narrow band.  Min-max scaling spreads them over 0..1 (0.0 = best).
    max_score = fused[0][1]
    min_score = fused[-1][1]
    score_range = max_score - min_score

    ordered_ids = []
    ordered_metas = []
    ordered_docs = []
    ordered_dists = []
    for doc_id, score in fused:
        if doc_id not in fetched_map:
            continue
        meta, doc = fetched_map[doc_id]
        ordered_ids.append(doc_id)
        ordered_metas.append(meta)
        ordered_docs.append(doc)
        if score_range > 0:
            ordered_dists.append(1.0 - (score - min_score) / score_range)
        else:
            ordered_dists.append(0.0)

    self.ap.logger.info(
        f"Chroma hybrid search in '{collection}' returned {len(ordered_ids)} results "
        f'(vector={len(vector_ids)}, text={len(text_ids)}).'
    )
    return {
        'ids': [ordered_ids],
        'metadatas': [ordered_metas],
        'distances': [ordered_dists],
        'documents': [ordered_docs],
    }
|
||||||
|
|
||||||
|
@staticmethod
def _rrf_fuse(result_lists: list[list[str]], k: int, weights: list[float] | None = None) -> list[tuple[str, float]]:
    """Merge several ranked ID lists with Reciprocal Rank Fusion.

    Each list adds ``weight / (_RRF_K + rank + 1)`` to the score of every ID
    it contains, where ``rank`` is the ID's zero-based position.

    Args:
        result_lists: Ranked ID lists from the different search methods.
        k: Number of results to return.
        weights: Per-list weights. ``None`` means equal weight (1.0 each).

    Returns:
        ``(doc_id, rrf_score)`` pairs sorted by descending score, cut to the
        first *k* entries.
    """
    list_weights = [1.0] * len(result_lists) if weights is None else weights
    totals: dict[str, float] = {}
    for position, ranking in enumerate(result_lists):
        weight = list_weights[position]
        rank = 0
        for doc_id in ranking:
            totals[doc_id] = totals.get(doc_id, 0.0) + weight / (_RRF_K + rank + 1)
            rank += 1
    ranked = list(totals.items())
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    return ranked[:k]
|
||||||
|
|
||||||
async def delete_by_file_id(self, collection: str, file_id: str) -> None:
    """Delete every entry whose ``file_id`` metadata equals *file_id*."""
    target = await self.get_or_create_collection(collection)
    criteria = {'file_id': file_id}
    await asyncio.to_thread(target.delete, where=criteria)
    self.ap.logger.info(f"Deleted embeddings from Chroma collection '{collection}' with file_id: {file_id}")
|
||||||
|
|
||||||
|
async def delete_by_filter(self, collection: str, filter: dict[str, Any]) -> int:
    """Delete the entries matching *filter*, passed to Chroma as ``where``.

    Returns:
        Always 0, because Chroma's delete does not return a count.
    """
    target = await self.get_or_create_collection(collection)
    await asyncio.to_thread(target.delete, where=filter)
    self.ap.logger.info(f"Deleted embeddings from Chroma collection '{collection}' by filter")
    deleted_count = 0
    return deleted_count
|
||||||
|
|
||||||
|
async def list_by_filter(
    self,
    collection: str,
    filter: dict[str, Any] | None = None,
    limit: int = 20,
    offset: int = 0,
) -> tuple[list[dict[str, Any]], int]:
    """List one page of a Chroma collection, optionally filtered.

    Returns:
        ``(items, total)``. Each item has the keys 'id', 'document' and
        'metadata'. *total* is the size of the whole collection when no
        filter is given and -1 otherwise, because a filtered count is not
        available.
    """
    col = await self.get_or_create_collection(collection)

    request: dict[str, Any] = {
        'include': ['metadatas', 'documents'],
        'limit': limit,
        'offset': offset,
    }
    if filter:
        request['where'] = filter
    page = await asyncio.to_thread(col.get, **request)

    row_ids = page.get('ids', [])
    row_metas = page.get('metadatas', []) or [None] * len(row_ids)
    row_docs = page.get('documents', []) or [None] * len(row_ids)

    items = [
        {
            'id': row_id,
            'document': row_docs[pos] if pos < len(row_docs) else None,
            'metadata': row_metas[pos] if pos < len(row_metas) else {},
        }
        for pos, row_id in enumerate(row_ids)
    ]

    # col.count() is the size of the whole collection; it is only used when
    # no filter narrows the listing.
    if filter:
        total = -1
    else:
        total = await asyncio.to_thread(col.count)
    return items, total
|
||||||
|
|
||||||
async def delete_collection(self, collection: str):
|
async def delete_collection(self, collection: str):
|
||||||
if collection in self._collections:
|
if collection in self._collections:
|
||||||
del self._collections[collection]
|
del self._collections[collection]
|
||||||
|
|||||||
@@ -4,8 +4,54 @@ from typing import Any, Dict
|
|||||||
from pymilvus import MilvusClient, DataType, CollectionSchema, FieldSchema
|
from pymilvus import MilvusClient, DataType, CollectionSchema, FieldSchema
|
||||||
from pymilvus.milvus_client.index import IndexParams
|
from pymilvus.milvus_client.index import IndexParams
|
||||||
from langbot.pkg.vector.vdb import VectorDatabase
|
from langbot.pkg.vector.vdb import VectorDatabase
|
||||||
|
from langbot.pkg.vector.filter_utils import normalize_filter, strip_unsupported_fields
|
||||||
from langbot.pkg.core import app
|
from langbot.pkg.core import app
|
||||||
|
|
||||||
|
# Milvus schema only stores these metadata fields; filters on other fields are
|
||||||
|
# dropped with a warning.
|
||||||
|
_MILVUS_SUPPORTED_FIELDS = {'text', 'file_id', 'chunk_uuid'}
|
||||||
|
|
||||||
|
# Callers use canonical metadata key 'uuid' but Milvus stores it as 'chunk_uuid'.
|
||||||
|
_MILVUS_FIELD_ALIASES = {'uuid': 'chunk_uuid'}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_milvus_expr(filter_dict: dict[str, Any]) -> str:
    """Render a canonical filter dict as a Milvus boolean expression.

    The filter is normalised into ``(field, operator, value)`` triples and
    restricted to the fields the Milvus schema stores. Each remaining triple
    becomes one clause and the clauses are joined with ``and``. An empty
    string is returned when nothing usable remains.
    """
    comparison_symbols = {
        '$eq': '==',
        '$ne': '!=',
        '$gt': '>',
        '$gte': '>=',
        '$lt': '<',
        '$lte': '<=',
    }
    membership_keywords = {'$in': 'in', '$nin': 'not in'}

    supported = strip_unsupported_fields(
        normalize_filter(filter_dict), _MILVUS_SUPPORTED_FIELDS, _MILVUS_FIELD_ALIASES
    )
    if not supported:
        return ''

    clauses: list[str] = []
    for name, operator, operand in supported:
        if operator in comparison_symbols:
            clauses.append(f'{name} {comparison_symbols[operator]} {_milvus_literal(operand)}')
        elif operator in membership_keywords:
            rendered = ', '.join(_milvus_literal(member) for member in operand)
            clauses.append(f'{name} {membership_keywords[operator]} [{rendered}]')
        # Any other operator contributes no clause.
    return ' and '.join(clauses)
|
||||||
|
|
||||||
|
|
||||||
|
def _milvus_literal(value: Any) -> str:
|
||||||
|
"""Format a Python value as a Milvus expression literal."""
|
||||||
|
if isinstance(value, str):
|
||||||
|
escaped = value.replace('\\', '\\\\').replace('"', '\\"')
|
||||||
|
return f'"{escaped}"'
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
class MilvusVectorDatabase(VectorDatabase):
|
class MilvusVectorDatabase(VectorDatabase):
|
||||||
"""Milvus vector database implementation"""
|
"""Milvus vector database implementation"""
|
||||||
@@ -155,6 +201,7 @@ class MilvusVectorDatabase(VectorDatabase):
|
|||||||
ids: list[str],
|
ids: list[str],
|
||||||
embeddings_list: list[list[float]],
|
embeddings_list: list[list[float]],
|
||||||
metadatas: list[dict[str, Any]],
|
metadatas: list[dict[str, Any]],
|
||||||
|
documents: list[str] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add vector embeddings to Milvus collection
|
"""Add vector embeddings to Milvus collection
|
||||||
|
|
||||||
@@ -200,7 +247,16 @@ class MilvusVectorDatabase(VectorDatabase):
|
|||||||
|
|
||||||
self.ap.logger.info(f"Added {len(ids)} embeddings to Milvus collection '{collection}'")
|
self.ap.logger.info(f"Added {len(ids)} embeddings to Milvus collection '{collection}'")
|
||||||
|
|
||||||
async def search(self, collection: str, query_embedding: list[float], k: int = 5) -> Dict[str, Any]:
|
async def search(
|
||||||
|
self,
|
||||||
|
collection: str,
|
||||||
|
query_embedding: list[float],
|
||||||
|
k: int = 5,
|
||||||
|
search_type: str = 'vector',
|
||||||
|
query_text: str = '',
|
||||||
|
filter: dict[str, Any] | None = None,
|
||||||
|
vector_weight: float | None = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Search for similar vectors in Milvus collection
|
"""Search for similar vectors in Milvus collection
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -217,14 +273,19 @@ class MilvusVectorDatabase(VectorDatabase):
|
|||||||
# Perform search
|
# Perform search
|
||||||
search_params = {'metric_type': 'COSINE', 'params': {}}
|
search_params = {'metric_type': 'COSINE', 'params': {}}
|
||||||
|
|
||||||
results = await asyncio.to_thread(
|
search_kwargs: dict[str, Any] = dict(
|
||||||
self.client.search,
|
|
||||||
collection_name=collection,
|
collection_name=collection,
|
||||||
data=[query_embedding],
|
data=[query_embedding],
|
||||||
limit=k,
|
limit=k,
|
||||||
search_params=search_params,
|
search_params=search_params,
|
||||||
output_fields=['text', 'file_id', 'chunk_uuid'],
|
output_fields=['text', 'file_id', 'chunk_uuid'],
|
||||||
)
|
)
|
||||||
|
if filter:
|
||||||
|
expr = _build_milvus_expr(filter)
|
||||||
|
if expr:
|
||||||
|
search_kwargs['filter'] = expr
|
||||||
|
|
||||||
|
results = await asyncio.to_thread(self.client.search, **search_kwargs)
|
||||||
|
|
||||||
# Convert results to Chroma-compatible format
|
# Convert results to Chroma-compatible format
|
||||||
# Milvus returns: [[ {id, distance, entity: {...}} ]]
|
# Milvus returns: [[ {id, distance, entity: {...}} ]]
|
||||||
@@ -268,6 +329,77 @@ class MilvusVectorDatabase(VectorDatabase):
|
|||||||
await asyncio.to_thread(self.client.delete, collection_name=collection, filter=f'file_id == "{file_id}"')
|
await asyncio.to_thread(self.client.delete, collection_name=collection, filter=f'file_id == "{file_id}"')
|
||||||
self.ap.logger.info(f"Deleted embeddings from Milvus collection '{collection}' with file_id: {file_id}")
|
self.ap.logger.info(f"Deleted embeddings from Milvus collection '{collection}' with file_id: {file_id}")
|
||||||
|
|
||||||
|
async def delete_by_filter(self, collection: str, filter: dict[str, Any]) -> int:
    """Delete the entities in a Milvus collection that match ``filter``.

    The collection name is normalised and the collection is fetched or created
    first. If the filter translates to an empty expression, nothing is deleted
    and a warning is logged.

    Returns:
        Always 0; no deletion count is reported.
    """
    collection = self._normalize_collection_name(collection)
    await self.get_or_create_collection(collection)

    expr = _build_milvus_expr(filter)
    if expr:
        await asyncio.to_thread(self.client.delete, collection_name=collection, filter=expr)
        self.ap.logger.info(f"Deleted embeddings from Milvus collection '{collection}' by filter")
    else:
        self.ap.logger.warning(
            f"Milvus delete_by_filter on '{collection}': filter produced empty expression, skipping"
        )
    return 0  # Milvus delete does not return a count
|
||||||
|
|
||||||
|
async def list_by_filter(
    self,
    collection: str,
    filter: dict[str, Any] | None = None,
    limit: int = 20,
    offset: int = 0,
) -> tuple[list[dict[str, Any]], int]:
    """Return one page of stored entities from a Milvus collection.

    Args:
        collection: Collection name; normalised before use.
        filter: Optional canonical metadata filter.
        limit: Maximum number of entities to return.
        offset: Number of entities to skip before the page starts.

    Returns:
        ``(items, total)``. Each item has ``id``, ``document`` and ``metadata``
        (``text``, ``file_id``, ``uuid``). ``total`` comes from a ``count(*)``
        query using the same filter, or -1 if that query fails.
    """
    collection = self._normalize_collection_name(collection)
    await self.get_or_create_collection(collection)

    # Translate the filter once and reuse it for the page and count queries.
    expr = _build_milvus_expr(filter) if filter else ''

    query_kwargs: dict[str, Any] = dict(
        collection_name=collection,
        output_fields=['text', 'file_id', 'chunk_uuid'],
        limit=limit,
        offset=offset,
    )
    if expr:
        query_kwargs['filter'] = expr

    results = await asyncio.to_thread(self.client.query, **query_kwargs)

    items = [
        {
            'id': row.get('id', ''),
            'document': row.get('text'),
            'metadata': {
                'text': row.get('text', ''),
                'file_id': row.get('file_id', ''),
                'uuid': row.get('chunk_uuid', ''),
            },
        }
        for row in results
    ]

    # Milvus query with count(*)
    total = -1
    count_kwargs: dict[str, Any] = dict(
        collection_name=collection,
        output_fields=['count(*)'],
    )
    if expr:
        count_kwargs['filter'] = expr
    try:
        count_result = await asyncio.to_thread(self.client.query, **count_kwargs)
        if count_result:
            total = count_result[0].get('count(*)', -1)
    except Exception as e:
        # Best effort: the page is still useful without a total, but leave a
        # trace instead of hiding the failure.
        self.ap.logger.warning(f"Milvus count(*) on collection '{collection}' failed: {e}")

    return items, total
|
||||||
|
|
||||||
async def delete_collection(self, collection: str):
|
async def delete_collection(self, collection: str):
|
||||||
"""Delete a Milvus collection
|
"""Delete a Milvus collection
|
||||||
|
|
||||||
|
|||||||
@@ -5,10 +5,24 @@ from sqlalchemy.orm import declarative_base
|
|||||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||||
from pgvector.sqlalchemy import Vector
|
from pgvector.sqlalchemy import Vector
|
||||||
from langbot.pkg.vector.vdb import VectorDatabase
|
from langbot.pkg.vector.vdb import VectorDatabase
|
||||||
|
from langbot.pkg.vector.filter_utils import normalize_filter, strip_unsupported_fields
|
||||||
from langbot.pkg.core import app
|
from langbot.pkg.core import app
|
||||||
|
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
|
# pgvector schema only stores these metadata fields.
|
||||||
|
_PG_SUPPORTED_FIELDS = {'text', 'file_id', 'chunk_uuid'}
|
||||||
|
|
||||||
|
# Callers use canonical metadata key 'uuid' but pgvector stores it as 'chunk_uuid'.
|
||||||
|
_PG_FIELD_ALIASES = {'uuid': 'chunk_uuid'}
|
||||||
|
|
||||||
|
# Map schema field names to SQLAlchemy columns (resolved lazily from PgVectorEntry).
|
||||||
|
_PG_COLUMN_MAP = {
|
||||||
|
'text': 'text',
|
||||||
|
'file_id': 'file_id',
|
||||||
|
'chunk_uuid': 'chunk_uuid',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class PgVectorEntry(Base):
|
class PgVectorEntry(Base):
|
||||||
"""SQLAlchemy model for pgvector entries"""
|
"""SQLAlchemy model for pgvector entries"""
|
||||||
@@ -23,6 +37,33 @@ class PgVectorEntry(Base):
|
|||||||
chunk_uuid = Column(String)
|
chunk_uuid = Column(String)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_pg_conditions(filter_dict: dict[str, Any]) -> list:
    """Turn a canonical filter dict into SQLAlchemy conditions on ``PgVectorEntry``.

    The filter is normalised into ``(field, operator, value)`` triples and
    restricted to the fields the pgvector schema stores. Each triple with a
    known operator yields one column expression.
    """
    builders = {
        '$eq': lambda column, operand: column == operand,
        '$ne': lambda column, operand: column != operand,
        '$gt': lambda column, operand: column > operand,
        '$gte': lambda column, operand: column >= operand,
        '$lt': lambda column, operand: column < operand,
        '$lte': lambda column, operand: column <= operand,
        '$in': lambda column, operand: column.in_(operand),
        '$nin': lambda column, operand: column.notin_(operand),
    }

    supported = strip_unsupported_fields(normalize_filter(filter_dict), _PG_SUPPORTED_FIELDS, _PG_FIELD_ALIASES)

    clauses = []
    for name, operator, operand in supported:
        column = getattr(PgVectorEntry, _PG_COLUMN_MAP[name])
        build = builders.get(operator)
        if build is not None:
            clauses.append(build(column, operand))
    return clauses
|
||||||
|
|
||||||
|
|
||||||
class PgVectorDatabase(VectorDatabase):
|
class PgVectorDatabase(VectorDatabase):
|
||||||
"""PostgreSQL with pgvector extension database implementation"""
|
"""PostgreSQL with pgvector extension database implementation"""
|
||||||
|
|
||||||
@@ -109,6 +150,7 @@ class PgVectorDatabase(VectorDatabase):
|
|||||||
ids: list[str],
|
ids: list[str],
|
||||||
embeddings_list: list[list[float]],
|
embeddings_list: list[list[float]],
|
||||||
metadatas: list[dict[str, Any]],
|
metadatas: list[dict[str, Any]],
|
||||||
|
documents: list[str] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add vector embeddings to pgvector
|
"""Add vector embeddings to pgvector
|
||||||
|
|
||||||
@@ -142,7 +184,16 @@ class PgVectorDatabase(VectorDatabase):
|
|||||||
self.ap.logger.error(f'Error adding embeddings to pgvector: {e}')
|
self.ap.logger.error(f'Error adding embeddings to pgvector: {e}')
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def search(self, collection: str, query_embedding: list[float], k: int = 5) -> Dict[str, Any]:
|
async def search(
|
||||||
|
self,
|
||||||
|
collection: str,
|
||||||
|
query_embedding: list[float],
|
||||||
|
k: int = 5,
|
||||||
|
search_type: str = 'vector',
|
||||||
|
query_text: str = '',
|
||||||
|
filter: dict[str, Any] | None = None,
|
||||||
|
vector_weight: float | None = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Search for similar vectors using cosine distance
|
"""Search for similar vectors using cosine distance
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -174,6 +225,10 @@ class PgVectorDatabase(VectorDatabase):
|
|||||||
.limit(k)
|
.limit(k)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if filter:
|
||||||
|
for cond in _build_pg_conditions(filter):
|
||||||
|
stmt = stmt.filter(cond)
|
||||||
|
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
rows = result.fetchall()
|
rows = result.fetchall()
|
||||||
|
|
||||||
@@ -225,6 +280,98 @@ class PgVectorDatabase(VectorDatabase):
|
|||||||
self.ap.logger.error(f'Error deleting from pgvector: {e}')
|
self.ap.logger.error(f'Error deleting from pgvector: {e}')
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def delete_by_filter(self, collection: str, filter: dict[str, Any]) -> int:
    """Remove the rows of a collection that satisfy a metadata filter.

    Args:
        collection: Collection name
        filter: Canonical metadata filter dict

    Returns:
        The row count reported by the DELETE statement, or 0 when the filter
        yields no conditions (nothing is deleted in that case).
    """
    conditions = _build_pg_conditions(filter)
    if not conditions:
        self.ap.logger.warning(
            f"pgvector delete_by_filter on '{collection}': filter produced no conditions, skipping"
        )
        return 0

    await self.get_or_create_collection(collection)

    async with self.AsyncSessionLocal() as session:
        try:
            from sqlalchemy import delete

            # All criteria are ANDed: the collection match plus every filter condition.
            stmt = delete(PgVectorEntry).where(PgVectorEntry.collection == collection, *conditions)
            outcome = await session.execute(stmt)
            await session.commit()
            deleted = outcome.rowcount
            self.ap.logger.info(f"Deleted {deleted} embeddings from pgvector collection '{collection}' by filter")
            return deleted
        except Exception as e:
            await session.rollback()
            self.ap.logger.error(f'Error deleting from pgvector by filter: {e}')
            raise
|
||||||
|
|
||||||
|
async def list_by_filter(
    self,
    collection: str,
    filter: dict[str, Any] | None = None,
    limit: int = 20,
    offset: int = 0,
) -> tuple[list[dict[str, Any]], int]:
    """Return one page of stored rows from a pgvector collection.

    Args:
        collection: Collection name.
        filter: Optional canonical metadata filter.
        limit: Maximum number of rows to return.
        offset: Number of rows to skip before the page starts.

    Returns:
        ``(items, total)``. Each item has ``id``, ``document`` and ``metadata``
        (``text``, ``file_id``, ``uuid``). ``total`` is the number of rows
        matching the collection and filter.

    Raises:
        Exception: Any database error is logged and re-raised.
    """
    await self.get_or_create_collection(collection)

    async with self.AsyncSessionLocal() as session:
        try:
            from sqlalchemy import select, func

            # Build the criteria once so the page query and the count query
            # always agree.
            conditions = [PgVectorEntry.collection == collection]
            if filter:
                conditions.extend(_build_pg_conditions(filter))

            stmt = (
                select(
                    PgVectorEntry.id,
                    PgVectorEntry.text,
                    PgVectorEntry.file_id,
                    PgVectorEntry.chunk_uuid,
                )
                .filter(*conditions)
                # OFFSET/LIMIT without ORDER BY has no guaranteed row order,
                # so pages could overlap or miss rows between calls.
                .order_by(PgVectorEntry.id)
                .offset(offset)
                .limit(limit)
            )
            count_stmt = select(func.count()).select_from(PgVectorEntry).filter(*conditions)

            result = await session.execute(stmt)
            rows = result.fetchall()

            count_result = await session.execute(count_stmt)
            total = count_result.scalar() or 0

            items = [
                {
                    'id': row.id,
                    'document': row.text or '',
                    'metadata': {
                        'text': row.text or '',
                        'file_id': row.file_id or '',
                        'uuid': row.chunk_uuid or '',
                    },
                }
                for row in rows
            ]
            return items, total
        except Exception as e:
            self.ap.logger.error(f'Error listing from pgvector: {e}')
            raise
|
||||||
|
|
||||||
async def delete_collection(self, collection: str):
|
async def delete_collection(self, collection: str):
|
||||||
"""Delete all vectors in a collection
|
"""Delete all vectors in a collection
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,37 @@ from typing import Any, Dict, List
|
|||||||
from qdrant_client import AsyncQdrantClient, models
|
from qdrant_client import AsyncQdrantClient, models
|
||||||
from langbot.pkg.core import app
|
from langbot.pkg.core import app
|
||||||
from langbot.pkg.vector.vdb import VectorDatabase
|
from langbot.pkg.vector.vdb import VectorDatabase
|
||||||
|
from langbot.pkg.vector.filter_utils import normalize_filter
|
||||||
|
|
||||||
|
|
||||||
|
def _build_qdrant_filter(filter_dict: dict[str, Any]) -> models.Filter:
    """Convert a canonical filter dict into a Qdrant ``models.Filter``.

    Positive conditions (``$eq``, ``$in`` and the range operators) go into
    ``must``; negated ones (``$ne``, ``$nin``) go into ``must_not``. A list
    that ends up empty is passed as ``None``.
    """
    required: list[models.Condition] = []
    excluded: list[models.Condition] = []

    for key, operator, operand in normalize_filter(filter_dict):
        if operator in ('$eq', '$ne'):
            condition = models.FieldCondition(key=key, match=models.MatchValue(value=operand))
            (required if operator == '$eq' else excluded).append(condition)
        elif operator in ('$in', '$nin'):
            condition = models.FieldCondition(key=key, match=models.MatchAny(any=operand))
            (required if operator == '$in' else excluded).append(condition)
        elif operator in ('$gt', '$gte', '$lt', '$lte'):
            # The Range keyword is the operator name without its leading '$'.
            bound: dict[str, Any] = {operator[1:]: operand}
            required.append(models.FieldCondition(key=key, range=models.Range(**bound)))

    return models.Filter(must=required or None, must_not=excluded or None)
|
||||||
|
|
||||||
|
|
||||||
class QdrantVectorDatabase(VectorDatabase):
|
class QdrantVectorDatabase(VectorDatabase):
|
||||||
@@ -48,6 +79,7 @@ class QdrantVectorDatabase(VectorDatabase):
|
|||||||
ids: List[str],
|
ids: List[str],
|
||||||
embeddings_list: List[List[float]],
|
embeddings_list: List[List[float]],
|
||||||
metadatas: List[Dict[str, Any]],
|
metadatas: List[Dict[str, Any]],
|
||||||
|
documents: List[str] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if not embeddings_list:
|
if not embeddings_list:
|
||||||
return
|
return
|
||||||
@@ -60,19 +92,30 @@ class QdrantVectorDatabase(VectorDatabase):
|
|||||||
await self.client.upsert(collection_name=collection, points=points)
|
await self.client.upsert(collection_name=collection, points=points)
|
||||||
self.ap.logger.info(f"Added {len(ids)} embeddings to Qdrant collection '{collection}'.")
|
self.ap.logger.info(f"Added {len(ids)} embeddings to Qdrant collection '{collection}'.")
|
||||||
|
|
||||||
async def search(self, collection: str, query_embedding: list[float], k: int = 5) -> dict[str, Any]:
|
async def search(
|
||||||
|
self,
|
||||||
|
collection: str,
|
||||||
|
query_embedding: list[float],
|
||||||
|
k: int = 5,
|
||||||
|
search_type: str = 'vector',
|
||||||
|
query_text: str = '',
|
||||||
|
filter: dict[str, Any] | None = None,
|
||||||
|
vector_weight: float | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
exists = await self.client.collection_exists(collection)
|
exists = await self.client.collection_exists(collection)
|
||||||
if not exists:
|
if not exists:
|
||||||
return {'ids': [[]], 'metadatas': [[]], 'distances': [[]]}
|
return {'ids': [[]], 'metadatas': [[]], 'distances': [[]]}
|
||||||
|
|
||||||
hits = (
|
query_kwargs: dict[str, Any] = dict(
|
||||||
await self.client.query_points(
|
collection_name=collection,
|
||||||
collection_name=collection,
|
query=query_embedding,
|
||||||
query=query_embedding,
|
limit=k,
|
||||||
limit=k,
|
with_payload=True,
|
||||||
with_payload=True,
|
)
|
||||||
)
|
if filter:
|
||||||
).points
|
query_kwargs['query_filter'] = _build_qdrant_filter(filter)
|
||||||
|
|
||||||
|
hits = (await self.client.query_points(**query_kwargs)).points
|
||||||
ids = [str(hit.id) for hit in hits]
|
ids = [str(hit.id) for hit in hits]
|
||||||
metadatas = [hit.payload or {} for hit in hits]
|
metadatas = [hit.payload or {} for hit in hits]
|
||||||
# Qdrant's score is similarity; convert to a pseudo-distance for consistency
|
# Qdrant's score is similarity; convert to a pseudo-distance for consistency
|
||||||
@@ -95,6 +138,110 @@ class QdrantVectorDatabase(VectorDatabase):
|
|||||||
)
|
)
|
||||||
self.ap.logger.info(f"Deleted embeddings from Qdrant collection '{collection}' with file_id: {file_id}")
|
self.ap.logger.info(f"Deleted embeddings from Qdrant collection '{collection}' with file_id: {file_id}")
|
||||||
|
|
||||||
|
async def delete_by_filter(self, collection: str, filter: dict[str, Any]) -> int:
    """Delete the points in a Qdrant collection that match ``filter``.

    Args:
        collection: Name of the Qdrant collection.
        filter: Canonical metadata filter dict.

    Returns:
        Always 0; no deletion count is reported. Nothing is deleted when the
        collection does not exist or the filter yields no conditions.
    """
    exists = await self.client.collection_exists(collection)
    if not exists:
        return 0

    qdrant_filter = _build_qdrant_filter(filter)
    # A filter with neither ``must`` nor ``must_not`` conditions would select
    # every point in the collection, so refuse to run the delete.
    if not qdrant_filter.must and not qdrant_filter.must_not:
        self.ap.logger.warning(
            f"Qdrant delete_by_filter on '{collection}': filter produced no conditions, skipping"
        )
        return 0

    await self.client.delete(
        collection_name=collection,
        points_selector=qdrant_filter,
    )
    self.ap.logger.info(f"Deleted embeddings from Qdrant collection '{collection}' by filter")
    return 0  # Qdrant delete does not return a count
|
||||||
|
|
||||||
|
async def list_by_filter(
    self,
    collection: str,
    filter: dict[str, Any] | None = None,
    limit: int = 20,
    offset: int = 0,
) -> tuple[list[dict[str, Any]], int]:
    """Return one page of stored points from a Qdrant collection.

    Args:
        collection: Name of the Qdrant collection.
        filter: Optional canonical metadata filter.
        limit: Maximum number of points to return.
        offset: Number of points to skip before the page starts.

    Returns:
        ``(items, total)``. Each item has ``id``, ``document`` and ``metadata``
        (the point payload). ``total`` is the exact filtered count, or -1 if
        counting fails. A missing collection yields ``([], 0)``.
    """
    exists = await self.client.collection_exists(collection)
    if not exists:
        return [], 0

    qdrant_filter = _build_qdrant_filter(filter) if filter else None

    # Qdrant scroll uses cursor-based pagination (offset = point ID),
    # not numeric skip. To support numeric offset we scroll through
    # `offset + limit` items and discard the first `offset`.
    remaining_to_skip = offset
    remaining_to_collect = limit
    cursor: int | str | None = None
    collected: list[dict[str, Any]] = []
    # Items taken from a batch scrolled without payload: string id -> raw point id.
    missing_payload: dict[str, Any] = {}

    while remaining_to_skip > 0 or remaining_to_collect > 0:
        # While points remain to be discarded, the batch is fetched without
        # payload to keep the skip phase cheap.
        skipping = remaining_to_skip > 0
        scroll_kwargs: dict[str, Any] = dict(
            collection_name=collection,
            limit=min(remaining_to_skip + remaining_to_collect, 256),
            with_payload=not skipping,
            with_vectors=False,
        )
        if qdrant_filter:
            scroll_kwargs['scroll_filter'] = qdrant_filter
        if cursor is not None:
            scroll_kwargs['offset'] = cursor

        points, next_cursor = await self.client.scroll(**scroll_kwargs)
        if not points:
            break

        for point in points:
            if remaining_to_skip > 0:
                remaining_to_skip -= 1
                continue
            if remaining_to_collect <= 0:
                break
            payload = point.payload or {}
            item_id = str(point.id)
            if skipping:
                # This point shared a batch with skipped ones, so it has no
                # payload yet; remember its raw id for the refetch below.
                missing_payload[item_id] = point.id
            collected.append(
                {
                    'id': item_id,
                    'document': payload.get('text') or payload.get('document'),
                    'metadata': payload,
                }
            )
            remaining_to_collect -= 1

        if next_cursor is None:
            break
        cursor = next_cursor

    # Fill in payloads for the items that were scrolled without one. The raw
    # ids are used so that the ids sent back keep their original type.
    if missing_payload:
        fetched_points = await self.client.retrieve(
            collection_name=collection,
            ids=list(missing_payload.values()),
            with_payload=True,
            with_vectors=False,
        )
        payload_map = {str(p.id): p.payload or {} for p in fetched_points}
        for item in collected:
            if item['id'] in missing_payload and item['id'] in payload_map:
                payload = payload_map[item['id']]
                item['metadata'] = payload
                item['document'] = payload.get('text') or payload.get('document')

    # Use count() for accurate total (supports filter)
    total = -1
    try:
        count_result = await self.client.count(
            collection_name=collection,
            count_filter=qdrant_filter,
            exact=True,
        )
        total = count_result.count
    except Exception as e:
        # Best effort: return the page without a total, but record the failure.
        self.ap.logger.warning(f"Qdrant count on collection '{collection}' failed: {e}")

    return collected, total
|
||||||
|
|
||||||
async def delete_collection(self, collection: str):
|
async def delete_collection(self, collection: str):
|
||||||
try:
|
try:
|
||||||
await self.client.delete_collection(collection)
|
await self.client.delete_collection(collection)
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from decimal import Decimal
|
||||||
|
import re
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
|
||||||
from langbot.pkg.core import app
|
from langbot.pkg.core import app
|
||||||
from langbot.pkg.vector.vdb import VectorDatabase
|
from langbot.pkg.vector.vdb import VectorDatabase, SearchType
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import pyseekdb
|
import pyseekdb
|
||||||
@@ -25,9 +27,13 @@ class SeekDBVectorDatabase(VectorDatabase):
|
|||||||
SeekDB is an AI-native search database by OceanBase that unifies
|
SeekDB is an AI-native search database by OceanBase that unifies
|
||||||
relational, vector, text, JSON and GIS in a single engine.
|
relational, vector, text, JSON and GIS in a single engine.
|
||||||
|
|
||||||
Supports both embedded mode and remote server mode.
|
Supports embedded mode, remote server mode, and full-text/hybrid search.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@classmethod
def supported_search_types(cls) -> list[SearchType]:
    """Report the search modes this backend offers: vector, full-text and hybrid."""
    modes = [SearchType.VECTOR, SearchType.FULL_TEXT, SearchType.HYBRID]
    return modes
|
||||||
|
|
||||||
def __init__(self, ap: app.Application):
|
def __init__(self, ap: app.Application):
|
||||||
if not SEEKDB_AVAILABLE:
|
if not SEEKDB_AVAILABLE:
|
||||||
raise ImportError('pyseekdb is not installed. Install it with: pip install pyseekdb')
|
raise ImportError('pyseekdb is not installed. Install it with: pip install pyseekdb')
|
||||||
@@ -89,6 +95,7 @@ class SeekDBVectorDatabase(VectorDatabase):
|
|||||||
{
|
{
|
||||||
'\x00': '',
|
'\x00': '',
|
||||||
'\\': '\\\\',
|
'\\': '\\\\',
|
||||||
|
"'": "''", # Standard SQL escaping (OceanBase NO_BACKSLASH_ESCAPES)
|
||||||
'"': '\\"',
|
'"': '\\"',
|
||||||
'\n': '\\n',
|
'\n': '\\n',
|
||||||
'\r': '\\r',
|
'\r': '\\r',
|
||||||
@@ -96,8 +103,28 @@ class SeekDBVectorDatabase(VectorDatabase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _normalize_collection_name(self, collection: str) -> str:
|
||||||
|
"""SeekDB only accepts [a-zA-Z0-9_], while LangBot uses UUID-like KB IDs."""
|
||||||
|
normalized = re.sub(r'[^A-Za-z0-9_]', '_', collection)
|
||||||
|
if normalized != collection:
|
||||||
|
self.ap.logger.info(f"Normalized SeekDB collection name: '{collection}' -> '{normalized}'")
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
def _json_safe(self, value: Any) -> Any:
|
||||||
|
"""Convert SeekDB result values into JSON-serializable Python primitives."""
|
||||||
|
if isinstance(value, Decimal):
|
||||||
|
return float(value)
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return {k: self._json_safe(v) for k, v in value.items()}
|
||||||
|
if isinstance(value, list):
|
||||||
|
return [self._json_safe(v) for v in value]
|
||||||
|
if isinstance(value, tuple):
|
||||||
|
return [self._json_safe(v) for v in value]
|
||||||
|
return value
|
||||||
|
|
||||||
async def _get_or_create_collection_internal(self, collection: str, vector_size: int = None) -> Any:
|
async def _get_or_create_collection_internal(self, collection: str, vector_size: int = None) -> Any:
|
||||||
"""Internal method to get or create a collection with proper configuration."""
|
"""Internal method to get or create a collection with proper configuration."""
|
||||||
|
collection = self._normalize_collection_name(collection)
|
||||||
if collection in self._collections:
|
if collection in self._collections:
|
||||||
return self._collections[collection]
|
return self._collections[collection]
|
||||||
|
|
||||||
@@ -111,8 +138,10 @@ class SeekDBVectorDatabase(VectorDatabase):
|
|||||||
|
|
||||||
# Collection doesn't exist, create it
|
# Collection doesn't exist, create it
|
||||||
if vector_size is None:
|
if vector_size is None:
|
||||||
# Default dimension if not specified
|
raise ValueError(
|
||||||
vector_size = 384
|
f"Cannot create SeekDB collection '{collection}' without knowing the vector dimension. "
|
||||||
|
'Ensure add_embeddings is called before any standalone get_or_create_collection.'
|
||||||
|
)
|
||||||
|
|
||||||
# Create HNSW configuration
|
# Create HNSW configuration
|
||||||
config = HNSWConfiguration(dimension=vector_size, distance='cosine')
|
config = HNSWConfiguration(dimension=vector_size, distance='cosine')
|
||||||
@@ -147,7 +176,12 @@ class SeekDBVectorDatabase(VectorDatabase):
|
|||||||
return await self._get_or_create_collection_internal(collection)
|
return await self._get_or_create_collection_internal(collection)
|
||||||
|
|
||||||
async def add_embeddings(
|
async def add_embeddings(
|
||||||
self, collection: str, ids: List[str], embeddings_list: List[List[float]], metadatas: List[Dict[str, Any]]
|
self,
|
||||||
|
collection: str,
|
||||||
|
ids: List[str],
|
||||||
|
embeddings_list: List[List[float]],
|
||||||
|
metadatas: List[Dict[str, Any]],
|
||||||
|
documents: List[str] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add vector embeddings to the specified collection.
|
"""Add vector embeddings to the specified collection.
|
||||||
|
|
||||||
@@ -156,31 +190,51 @@ class SeekDBVectorDatabase(VectorDatabase):
|
|||||||
ids: List of document IDs
|
ids: List of document IDs
|
||||||
embeddings_list: List of embedding vectors
|
embeddings_list: List of embedding vectors
|
||||||
metadatas: List of metadata dictionaries
|
metadatas: List of metadata dictionaries
|
||||||
|
documents: Optional raw text documents for full-text search support
|
||||||
"""
|
"""
|
||||||
if not embeddings_list:
|
if not embeddings_list:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
collection = self._normalize_collection_name(collection)
|
||||||
# Ensure collection exists with correct dimension
|
# Ensure collection exists with correct dimension
|
||||||
vector_size = len(embeddings_list[0])
|
vector_size = len(embeddings_list[0])
|
||||||
coll = await self._get_or_create_collection_internal(collection, vector_size)
|
coll = await self._get_or_create_collection_internal(collection, vector_size)
|
||||||
|
|
||||||
cleaned_metadatas = [self._clean_metadata(meta) for meta in metadatas]
|
cleaned_metadatas = [self._clean_metadata(meta) for meta in metadatas]
|
||||||
|
|
||||||
await asyncio.to_thread(coll.add, ids=ids, embeddings=embeddings_list, metadatas=cleaned_metadatas)
|
kwargs: Dict[str, Any] = dict(ids=ids, embeddings=embeddings_list, metadatas=cleaned_metadatas)
|
||||||
|
if documents is not None:
|
||||||
|
kwargs['documents'] = [doc.translate(self._escape_table) for doc in documents]
|
||||||
|
await asyncio.to_thread(coll.add, **kwargs)
|
||||||
|
|
||||||
self.ap.logger.info(f"Added {len(ids)} embeddings to SeekDB collection '{collection}'")
|
self.ap.logger.info(f"Added {len(ids)} embeddings to SeekDB collection '{collection}'")
|
||||||
|
|
||||||
async def search(self, collection: str, query_embedding: List[float], k: int = 5) -> Dict[str, Any]:
|
async def search(
|
||||||
|
self,
|
||||||
|
collection: str,
|
||||||
|
query_embedding: List[float],
|
||||||
|
k: int = 5,
|
||||||
|
search_type: str = 'vector',
|
||||||
|
query_text: str = '',
|
||||||
|
filter: Dict[str, Any] | None = None,
|
||||||
|
vector_weight: float | None = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Search for the most similar vectors in the specified collection.
|
"""Search for the most similar vectors in the specified collection.
|
||||||
|
|
||||||
|
SeekDB supports vector, full-text, and hybrid search modes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
collection: Collection name
|
collection: Collection name
|
||||||
query_embedding: Query vector
|
query_embedding: Query vector (used for vector and hybrid modes)
|
||||||
k: Number of results to return
|
k: Number of results to return
|
||||||
|
search_type: One of 'vector', 'full_text', 'hybrid'
|
||||||
|
query_text: Raw query text (used for full_text and hybrid modes)
|
||||||
|
filter: Optional metadata filters (Chroma-style ``where`` syntax).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary with 'ids', 'metadatas', 'distances' keys
|
Dictionary with 'ids', 'metadatas', 'distances' keys
|
||||||
"""
|
"""
|
||||||
|
collection = self._normalize_collection_name(collection)
|
||||||
# Check if collection exists
|
# Check if collection exists
|
||||||
exists = await asyncio.to_thread(self.client.has_collection, collection)
|
exists = await asyncio.to_thread(self.client.has_collection, collection)
|
||||||
if not exists:
|
if not exists:
|
||||||
@@ -193,11 +247,88 @@ class SeekDBVectorDatabase(VectorDatabase):
|
|||||||
else:
|
else:
|
||||||
coll = self._collections[collection]
|
coll = self._collections[collection]
|
||||||
|
|
||||||
# Perform query
|
# Route by search type.
|
||||||
# SeekDB's query() returns: {'ids': [[...]], 'metadatas': [[...]], 'distances': [[...]]}
|
# pyseekdb's query() always requires embeddings, so full-text and
|
||||||
results = await asyncio.to_thread(coll.query, query_embeddings=query_embedding, n_results=k)
|
# hybrid modes use hybrid_search() which supports text-only queries
|
||||||
|
# and returns the same nested-list format with distances.
|
||||||
|
if search_type == SearchType.FULL_TEXT:
|
||||||
|
if not query_text:
|
||||||
|
return {'ids': [[]], 'metadatas': [[]], 'distances': [[]]}
|
||||||
|
|
||||||
self.ap.logger.info(f"SeekDB search in '{collection}' returned {len(results.get('ids', [[]])[0])} results")
|
query_cfg: Dict[str, Any] = {
|
||||||
|
'where_document': {'$contains': query_text},
|
||||||
|
'n_results': k,
|
||||||
|
}
|
||||||
|
if filter:
|
||||||
|
query_cfg['where'] = filter
|
||||||
|
|
||||||
|
# TODO: pyseekdb hybrid_search with query-only (no knn) returns None
|
||||||
|
# for IDs due to column name mismatch (*/_id vs _id).
|
||||||
|
# See: https://github.com/oceanbase/pyseekdb/issues/171
|
||||||
|
results = await asyncio.to_thread(
|
||||||
|
coll.hybrid_search,
|
||||||
|
query=query_cfg,
|
||||||
|
knn=None,
|
||||||
|
n_results=k,
|
||||||
|
include=['documents', 'metadatas'],
|
||||||
|
)
|
||||||
|
|
||||||
|
elif search_type == SearchType.HYBRID:
|
||||||
|
if not query_text:
|
||||||
|
# Fall back to pure vector search when no text is provided
|
||||||
|
query_kwargs: Dict[str, Any] = {
|
||||||
|
'n_results': k,
|
||||||
|
'query_embeddings': query_embedding,
|
||||||
|
}
|
||||||
|
if filter:
|
||||||
|
query_kwargs['where'] = filter
|
||||||
|
results = await asyncio.to_thread(coll.query, **query_kwargs)
|
||||||
|
else:
|
||||||
|
query_cfg = {
|
||||||
|
'where_document': {'$contains': query_text},
|
||||||
|
'n_results': k,
|
||||||
|
}
|
||||||
|
knn_cfg: Dict[str, Any] = {
|
||||||
|
'query_embeddings': query_embedding,
|
||||||
|
'n_results': k,
|
||||||
|
}
|
||||||
|
if filter:
|
||||||
|
query_cfg['where'] = filter
|
||||||
|
knn_cfg['where'] = filter
|
||||||
|
|
||||||
|
# Apply vector_weight via pyseekdb's native boost parameter
|
||||||
|
if vector_weight is not None:
|
||||||
|
knn_cfg['boost'] = vector_weight
|
||||||
|
query_cfg['boost'] = 1.0 - vector_weight
|
||||||
|
self.ap.logger.info(
|
||||||
|
f"SeekDB hybrid fusion config in '{collection}': "
|
||||||
|
f'vector_weight={vector_weight}, '
|
||||||
|
f'knn_boost={knn_cfg.get("boost", 1.0)}, '
|
||||||
|
f'query_boost={query_cfg.get("boost", 1.0)}'
|
||||||
|
)
|
||||||
|
|
||||||
|
results = await asyncio.to_thread(
|
||||||
|
coll.hybrid_search,
|
||||||
|
query=query_cfg,
|
||||||
|
knn=knn_cfg,
|
||||||
|
rank={'rrf': {}},
|
||||||
|
n_results=k,
|
||||||
|
include=['documents', 'metadatas'],
|
||||||
|
)
|
||||||
|
self.ap.logger.info(
|
||||||
|
f"SeekDB hybrid search in '{collection}' returned {len(results.get('ids', [[]])[0])} results."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Default: vector search via query()
|
||||||
|
query_kwargs = {'n_results': k, 'query_embeddings': query_embedding}
|
||||||
|
if filter:
|
||||||
|
query_kwargs['where'] = filter
|
||||||
|
results = await asyncio.to_thread(coll.query, **query_kwargs)
|
||||||
|
|
||||||
|
results = self._json_safe(results)
|
||||||
|
self.ap.logger.info(
|
||||||
|
f"SeekDB {search_type} search in '{collection}' returned {len(results.get('ids', [[]])[0])} results"
|
||||||
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@@ -208,6 +339,7 @@ class SeekDBVectorDatabase(VectorDatabase):
|
|||||||
collection: Collection name
|
collection: Collection name
|
||||||
file_id: File ID to delete
|
file_id: File ID to delete
|
||||||
"""
|
"""
|
||||||
|
collection = self._normalize_collection_name(collection)
|
||||||
# Check if collection exists
|
# Check if collection exists
|
||||||
exists = await asyncio.to_thread(self.client.has_collection, collection)
|
exists = await asyncio.to_thread(self.client.has_collection, collection)
|
||||||
if not exists:
|
if not exists:
|
||||||
@@ -227,12 +359,82 @@ class SeekDBVectorDatabase(VectorDatabase):
|
|||||||
|
|
||||||
self.ap.logger.info(f"Deleted embeddings from SeekDB collection '{collection}' with file_id: {file_id}")
|
self.ap.logger.info(f"Deleted embeddings from SeekDB collection '{collection}' with file_id: {file_id}")
|
||||||
|
|
||||||
|
async def delete_by_filter(self, collection: str, filter: Dict[str, Any]) -> int:
|
||||||
|
"""Delete vectors from the collection by metadata filter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
collection: Collection name
|
||||||
|
filter: Chroma-style ``where`` filter dict
|
||||||
|
"""
|
||||||
|
collection = self._normalize_collection_name(collection)
|
||||||
|
exists = await asyncio.to_thread(self.client.has_collection, collection)
|
||||||
|
if not exists:
|
||||||
|
self.ap.logger.warning(f"SeekDB collection '{collection}' not found for deletion")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
if collection not in self._collections:
|
||||||
|
coll = await asyncio.to_thread(self.client.get_collection, collection, embedding_function=None)
|
||||||
|
self._collections[collection] = coll
|
||||||
|
else:
|
||||||
|
coll = self._collections[collection]
|
||||||
|
|
||||||
|
await asyncio.to_thread(coll.delete, where=filter)
|
||||||
|
self.ap.logger.info(f"Deleted embeddings from SeekDB collection '{collection}' by filter")
|
||||||
|
return 0 # SeekDB delete does not return a count
|
||||||
|
|
||||||
|
async def list_by_filter(
|
||||||
|
self,
|
||||||
|
collection: str,
|
||||||
|
filter: Dict[str, Any] | None = None,
|
||||||
|
limit: int = 20,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> tuple[list[Dict[str, Any]], int]:
|
||||||
|
collection = self._normalize_collection_name(collection)
|
||||||
|
exists = await asyncio.to_thread(self.client.has_collection, collection)
|
||||||
|
if not exists:
|
||||||
|
return [], 0
|
||||||
|
|
||||||
|
if collection not in self._collections:
|
||||||
|
coll = await asyncio.to_thread(self.client.get_collection, collection, embedding_function=None)
|
||||||
|
self._collections[collection] = coll
|
||||||
|
else:
|
||||||
|
coll = self._collections[collection]
|
||||||
|
|
||||||
|
get_kwargs: Dict[str, Any] = dict(
|
||||||
|
include=['metadatas', 'documents'],
|
||||||
|
limit=limit,
|
||||||
|
offset=offset,
|
||||||
|
)
|
||||||
|
if filter:
|
||||||
|
get_kwargs['where'] = filter
|
||||||
|
|
||||||
|
results = await asyncio.to_thread(coll.get, **get_kwargs)
|
||||||
|
|
||||||
|
results = self._json_safe(results)
|
||||||
|
ids = results.get('ids', [])
|
||||||
|
metadatas = results.get('metadatas', []) or [None] * len(ids)
|
||||||
|
documents = results.get('documents', []) or [None] * len(ids)
|
||||||
|
|
||||||
|
items = []
|
||||||
|
for i, vid in enumerate(ids):
|
||||||
|
items.append(
|
||||||
|
{
|
||||||
|
'id': vid,
|
||||||
|
'document': documents[i] if i < len(documents) else None,
|
||||||
|
'metadata': metadatas[i] if i < len(metadatas) else {},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
total = await asyncio.to_thread(coll.count) if not filter else -1
|
||||||
|
return items, total
|
||||||
|
|
||||||
async def delete_collection(self, collection: str):
|
async def delete_collection(self, collection: str):
|
||||||
"""Delete the entire collection.
|
"""Delete the entire collection.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
collection: Collection name
|
collection: Collection name
|
||||||
"""
|
"""
|
||||||
|
collection = self._normalize_collection_name(collection)
|
||||||
# Remove from cache
|
# Remove from cache
|
||||||
if collection in self._collections:
|
if collection in self._collections:
|
||||||
del self._collections[collection]
|
del self._collections[collection]
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ admins: []
|
|||||||
api:
|
api:
|
||||||
port: 5300
|
port: 5300
|
||||||
webhook_prefix: 'http://127.0.0.1:5300'
|
webhook_prefix: 'http://127.0.0.1:5300'
|
||||||
|
extra_webhook_prefix: ''
|
||||||
command:
|
command:
|
||||||
enable: true
|
enable: true
|
||||||
prefix:
|
prefix:
|
||||||
@@ -15,6 +16,7 @@ proxy:
|
|||||||
http: ''
|
http: ''
|
||||||
https: ''
|
https: ''
|
||||||
system:
|
system:
|
||||||
|
instance_id: ''
|
||||||
edition: community
|
edition: community
|
||||||
recovery_key: ''
|
recovery_key: ''
|
||||||
allow_modify_login_info: true
|
allow_modify_login_info: true
|
||||||
|
|||||||
@@ -41,7 +41,10 @@
|
|||||||
"runner": "local-agent"
|
"runner": "local-agent"
|
||||||
},
|
},
|
||||||
"local-agent": {
|
"local-agent": {
|
||||||
"model": "",
|
"model": {
|
||||||
|
"primary": "",
|
||||||
|
"fallbacks": []
|
||||||
|
},
|
||||||
"max-round": 10,
|
"max-round": 10,
|
||||||
"prompt": [
|
"prompt": [
|
||||||
{
|
{
|
||||||
@@ -95,11 +98,12 @@
|
|||||||
"max": 0
|
"max": 0
|
||||||
},
|
},
|
||||||
"misc": {
|
"misc": {
|
||||||
"hide-exception": true,
|
"exception-handling": "show-hint",
|
||||||
|
"failure-hint": "Request failed.",
|
||||||
"at-sender": true,
|
"at-sender": true,
|
||||||
"quote-origin": true,
|
"quote-origin": true,
|
||||||
"track-function-calls": false,
|
"track-function-calls": false,
|
||||||
"remove-think": false
|
"remove-think": false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -59,8 +59,11 @@ stages:
|
|||||||
label:
|
label:
|
||||||
en_US: Model
|
en_US: Model
|
||||||
zh_Hans: 模型
|
zh_Hans: 模型
|
||||||
type: llm-model-selector
|
type: model-fallback-selector
|
||||||
required: true
|
required: true
|
||||||
|
default:
|
||||||
|
primary: ''
|
||||||
|
fallbacks: []
|
||||||
- name: max-round
|
- name: max-round
|
||||||
label:
|
label:
|
||||||
en_US: Max Round
|
en_US: Max Round
|
||||||
|
|||||||
@@ -78,13 +78,39 @@ stages:
|
|||||||
en_US: Misc
|
en_US: Misc
|
||||||
zh_Hans: 杂项
|
zh_Hans: 杂项
|
||||||
config:
|
config:
|
||||||
- name: hide-exception
|
- name: exception-handling
|
||||||
label:
|
label:
|
||||||
en_US: Hide Exception
|
en_US: Exception Handling Strategy
|
||||||
zh_Hans: 不输出异常信息给用户
|
zh_Hans: 异常处理策略
|
||||||
type: boolean
|
description:
|
||||||
|
en_US: Controls how error messages are displayed to the user when an AI request fails
|
||||||
|
zh_Hans: 控制 AI 请求失败时向用户展示错误信息的方式
|
||||||
|
type: select
|
||||||
required: true
|
required: true
|
||||||
default: true
|
default: show-hint
|
||||||
|
options:
|
||||||
|
- name: show-error
|
||||||
|
label:
|
||||||
|
en_US: Show Full Error
|
||||||
|
zh_Hans: 显示完整报错信息
|
||||||
|
- name: show-hint
|
||||||
|
label:
|
||||||
|
en_US: Show Failure Hint
|
||||||
|
zh_Hans: 仅文字提示
|
||||||
|
- name: hide
|
||||||
|
label:
|
||||||
|
en_US: Hide All
|
||||||
|
zh_Hans: 不显示任何异常信息
|
||||||
|
- name: failure-hint
|
||||||
|
label:
|
||||||
|
en_US: Failure Hint Text
|
||||||
|
zh_Hans: 失败提示文本
|
||||||
|
description:
|
||||||
|
en_US: The text to display when a request fails. Only effective when Exception Handling Strategy is set to "Show Failure Hint"
|
||||||
|
zh_Hans: 请求失败时显示的提示文本,仅在异常处理策略设置为"仅文字提示"时生效
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
default: 'Request failed.'
|
||||||
- name: at-sender
|
- name: at-sender
|
||||||
label:
|
label:
|
||||||
en_US: At Sender
|
en_US: At Sender
|
||||||
@@ -119,3 +145,4 @@ stages:
|
|||||||
type: boolean
|
type: boolean
|
||||||
required: true
|
required: true
|
||||||
default: false
|
default: false
|
||||||
|
|
||||||
|
|||||||
@@ -2,77 +2,5 @@
|
|||||||
"说明": "mask将替换敏感词中的每一个字,若mask_word值不为空,则将敏感词整个替换为mask_word的值",
|
"说明": "mask将替换敏感词中的每一个字,若mask_word值不为空,则将敏感词整个替换为mask_word的值",
|
||||||
"mask": "*",
|
"mask": "*",
|
||||||
"mask_word": "",
|
"mask_word": "",
|
||||||
"words": [
|
"words": []
|
||||||
"习近平",
|
|
||||||
"胡锦涛",
|
|
||||||
"江泽民",
|
|
||||||
"温家宝",
|
|
||||||
"李克强",
|
|
||||||
"李长春",
|
|
||||||
"毛泽东",
|
|
||||||
"邓小平",
|
|
||||||
"周恩来",
|
|
||||||
"马克思",
|
|
||||||
"社会主义",
|
|
||||||
"共产党",
|
|
||||||
"共产主义",
|
|
||||||
"大陆官方",
|
|
||||||
"北京政权",
|
|
||||||
"中华帝国",
|
|
||||||
"中国政府",
|
|
||||||
"共狗",
|
|
||||||
"六四事件",
|
|
||||||
"天安门",
|
|
||||||
"六四",
|
|
||||||
"政治局常委",
|
|
||||||
"两会",
|
|
||||||
"共青团",
|
|
||||||
"学潮",
|
|
||||||
"八九",
|
|
||||||
"二十大",
|
|
||||||
"民进党",
|
|
||||||
"台独",
|
|
||||||
"台湾独立",
|
|
||||||
"台湾国",
|
|
||||||
"国民党",
|
|
||||||
"台湾民国",
|
|
||||||
"中华民国",
|
|
||||||
"pornhub",
|
|
||||||
"Pornhub",
|
|
||||||
"[Yy]ou[Pp]orn",
|
|
||||||
"porn",
|
|
||||||
"Porn",
|
|
||||||
"[Xx][Vv]ideos",
|
|
||||||
"[Rr]ed[Tt]ube",
|
|
||||||
"[Xx][Hh]amster",
|
|
||||||
"[Ss]pank[Ww]ire",
|
|
||||||
"[Ss]pank[Bb]ang",
|
|
||||||
"[Tt]ube8",
|
|
||||||
"[Yy]ou[Jj]izz",
|
|
||||||
"[Bb]razzers",
|
|
||||||
"[Nn]aughty[ ]?[Aa]merica",
|
|
||||||
"作爱",
|
|
||||||
"做爱",
|
|
||||||
"性交",
|
|
||||||
"性爱",
|
|
||||||
"自慰",
|
|
||||||
"阴茎",
|
|
||||||
"淫妇",
|
|
||||||
"肛交",
|
|
||||||
"交配",
|
|
||||||
"性关系",
|
|
||||||
"性活动",
|
|
||||||
"色情",
|
|
||||||
"色图",
|
|
||||||
"涩图",
|
|
||||||
"裸体",
|
|
||||||
"小穴",
|
|
||||||
"淫荡",
|
|
||||||
"性爱",
|
|
||||||
"翻墙",
|
|
||||||
"VPN",
|
|
||||||
"科学上网",
|
|
||||||
"挂梯子",
|
|
||||||
"GFW"
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
@@ -91,14 +91,15 @@ class TestWebhookDisplayPrefix:
|
|||||||
|
|
||||||
def test_default_webhook_prefix(self):
|
def test_default_webhook_prefix(self):
|
||||||
"""Test that the default webhook display prefix is correctly set"""
|
"""Test that the default webhook display prefix is correctly set"""
|
||||||
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300'}}
|
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300', 'extra_webhook_prefix': ''}}
|
||||||
|
|
||||||
# Should have the default value
|
# Should have the default value
|
||||||
assert cfg['api']['webhook_prefix'] == 'http://127.0.0.1:5300'
|
assert cfg['api']['webhook_prefix'] == 'http://127.0.0.1:5300'
|
||||||
|
assert cfg['api']['extra_webhook_prefix'] == ''
|
||||||
|
|
||||||
def test_webhook_prefix_env_override(self):
|
def test_webhook_prefix_env_override(self):
|
||||||
"""Test overriding webhook_prefix via environment variable"""
|
"""Test overriding webhook_prefix via environment variable"""
|
||||||
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300'}}
|
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300', 'extra_webhook_prefix': ''}}
|
||||||
|
|
||||||
# Set environment variable
|
# Set environment variable
|
||||||
os.environ['API__WEBHOOK_PREFIX'] = 'https://example.com:8080'
|
os.environ['API__WEBHOOK_PREFIX'] = 'https://example.com:8080'
|
||||||
@@ -112,7 +113,7 @@ class TestWebhookDisplayPrefix:
|
|||||||
|
|
||||||
def test_webhook_prefix_with_custom_domain(self):
|
def test_webhook_prefix_with_custom_domain(self):
|
||||||
"""Test webhook_prefix with custom domain"""
|
"""Test webhook_prefix with custom domain"""
|
||||||
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300'}}
|
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300', 'extra_webhook_prefix': ''}}
|
||||||
|
|
||||||
# Set to a custom domain
|
# Set to a custom domain
|
||||||
os.environ['API__WEBHOOK_PREFIX'] = 'https://bot.mycompany.com'
|
os.environ['API__WEBHOOK_PREFIX'] = 'https://bot.mycompany.com'
|
||||||
@@ -126,7 +127,7 @@ class TestWebhookDisplayPrefix:
|
|||||||
|
|
||||||
def test_webhook_prefix_with_subdirectory(self):
|
def test_webhook_prefix_with_subdirectory(self):
|
||||||
"""Test webhook_prefix with subdirectory path"""
|
"""Test webhook_prefix with subdirectory path"""
|
||||||
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300'}}
|
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300', 'extra_webhook_prefix': ''}}
|
||||||
|
|
||||||
# Set to a URL with subdirectory
|
# Set to a URL with subdirectory
|
||||||
os.environ['API__WEBHOOK_PREFIX'] = 'https://example.com/langbot'
|
os.environ['API__WEBHOOK_PREFIX'] = 'https://example.com/langbot'
|
||||||
@@ -138,6 +139,37 @@ class TestWebhookDisplayPrefix:
|
|||||||
# Cleanup
|
# Cleanup
|
||||||
del os.environ['API__WEBHOOK_PREFIX']
|
del os.environ['API__WEBHOOK_PREFIX']
|
||||||
|
|
||||||
|
def test_extra_webhook_prefix_default_empty(self):
|
||||||
|
"""Test that extra_webhook_prefix defaults to empty string"""
|
||||||
|
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300', 'extra_webhook_prefix': ''}}
|
||||||
|
|
||||||
|
bot_uuid = 'test-bot-uuid'
|
||||||
|
webhook_prefix = cfg['api'].get('webhook_prefix', 'http://127.0.0.1:5300')
|
||||||
|
extra_webhook_prefix = cfg['api'].get('extra_webhook_prefix', '')
|
||||||
|
webhook_url = f'/bots/{bot_uuid}'
|
||||||
|
|
||||||
|
assert f'{webhook_prefix}{webhook_url}' == 'http://127.0.0.1:5300/bots/test-bot-uuid'
|
||||||
|
# extra should be empty when not configured
|
||||||
|
assert extra_webhook_prefix == ''
|
||||||
|
|
||||||
|
def test_extra_webhook_prefix_env_override(self):
|
||||||
|
"""Test overriding extra_webhook_prefix via environment variable"""
|
||||||
|
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300', 'extra_webhook_prefix': ''}}
|
||||||
|
|
||||||
|
os.environ['API__EXTRA_WEBHOOK_PREFIX'] = 'https://extra.example.com'
|
||||||
|
|
||||||
|
result = _apply_env_overrides_to_config(cfg)
|
||||||
|
|
||||||
|
assert result['api']['extra_webhook_prefix'] == 'https://extra.example.com'
|
||||||
|
|
||||||
|
bot_uuid = 'test-bot-uuid'
|
||||||
|
extra_prefix = result['api']['extra_webhook_prefix']
|
||||||
|
webhook_url = f'/bots/{bot_uuid}'
|
||||||
|
assert f'{extra_prefix}{webhook_url}' == 'https://extra.example.com/bots/test-bot-uuid'
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
del os.environ['API__EXTRA_WEBHOOK_PREFIX']
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pytest.main([__file__, '-v'])
|
pytest.main([__file__, '-v'])
|
||||||
|
|||||||
@@ -194,7 +194,7 @@ def sample_query(sample_message_chain, sample_message_event, mock_adapter):
|
|||||||
pipeline_config={
|
pipeline_config={
|
||||||
'ai': {
|
'ai': {
|
||||||
'runner': {'runner': 'local-agent'},
|
'runner': {'runner': 'local-agent'},
|
||||||
'local-agent': {'model': 'test-model-uuid', 'prompt': 'test-prompt'},
|
'local-agent': {'model': {'primary': 'test-model-uuid', 'fallbacks': []}, 'prompt': 'test-prompt'},
|
||||||
},
|
},
|
||||||
'output': {'misc': {'at-sender': False, 'quote-origin': False}},
|
'output': {'misc': {'at-sender': False, 'quote-origin': False}},
|
||||||
'trigger': {'misc': {'combine-quote-message': False}},
|
'trigger': {'misc': {'combine-quote-message': False}},
|
||||||
@@ -219,7 +219,7 @@ def sample_pipeline_config():
|
|||||||
return {
|
return {
|
||||||
'ai': {
|
'ai': {
|
||||||
'runner': {'runner': 'local-agent'},
|
'runner': {'runner': 'local-agent'},
|
||||||
'local-agent': {'model': 'test-model-uuid', 'prompt': 'test-prompt'},
|
'local-agent': {'model': {'primary': 'test-model-uuid', 'fallbacks': []}, 'prompt': 'test-prompt'},
|
||||||
},
|
},
|
||||||
'output': {'misc': {'at-sender': False, 'quote-origin': False}},
|
'output': {'misc': {'at-sender': False, 'quote-origin': False}},
|
||||||
'trigger': {'misc': {'combine-quote-message': False}},
|
'trigger': {'misc': {'combine-quote-message': False}},
|
||||||
|
|||||||
113
tests/unit_tests/pipeline/test_config_coercion.py
Normal file
113
tests/unit_tests/pipeline/test_config_coercion.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
"""Unit tests for config_coercion module"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langbot.pkg.pipeline.config_coercion import _coerce_value, coerce_pipeline_config
|
||||||
|
|
||||||
|
|
||||||
|
class TestCoerceValue:
|
||||||
|
"""Tests for _coerce_value function"""
|
||||||
|
|
||||||
|
def test_none_passthrough(self):
|
||||||
|
assert _coerce_value(None, 'integer') is None
|
||||||
|
assert _coerce_value(None, 'boolean') is None
|
||||||
|
|
||||||
|
def test_string_to_integer(self):
|
||||||
|
assert _coerce_value('120', 'integer') == 120
|
||||||
|
assert _coerce_value('0', 'integer') == 0
|
||||||
|
assert _coerce_value('-5', 'integer') == -5
|
||||||
|
|
||||||
|
def test_integer_passthrough(self):
|
||||||
|
assert _coerce_value(42, 'integer') == 42
|
||||||
|
|
||||||
|
def test_string_to_float(self):
|
||||||
|
assert _coerce_value('3.14', 'number') == 3.14
|
||||||
|
assert _coerce_value('3.14', 'float') == 3.14
|
||||||
|
|
||||||
|
def test_int_to_float(self):
|
||||||
|
assert _coerce_value(3, 'number') == 3.0
|
||||||
|
assert isinstance(_coerce_value(3, 'number'), float)
|
||||||
|
|
||||||
|
def test_float_passthrough(self):
|
||||||
|
assert _coerce_value(3.14, 'float') == 3.14
|
||||||
|
|
||||||
|
def test_string_to_bool(self):
|
||||||
|
assert _coerce_value('true', 'boolean') is True
|
||||||
|
assert _coerce_value('True', 'boolean') is True
|
||||||
|
assert _coerce_value('false', 'boolean') is False
|
||||||
|
assert _coerce_value('False', 'boolean') is False
|
||||||
|
|
||||||
|
def test_bool_passthrough(self):
|
||||||
|
assert _coerce_value(True, 'boolean') is True
|
||||||
|
assert _coerce_value(False, 'boolean') is False
|
||||||
|
|
||||||
|
def test_invalid_bool_string_raises(self):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_coerce_value('notabool', 'boolean')
|
||||||
|
|
||||||
|
def test_unknown_type_passthrough(self):
|
||||||
|
assert _coerce_value('hello', 'string') == 'hello'
|
||||||
|
assert _coerce_value('hello', 'unknown') == 'hello'
|
||||||
|
|
||||||
|
def test_invalid_integer_raises(self):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_coerce_value('abc', 'integer')
|
||||||
|
|
||||||
|
|
||||||
|
class TestCoercePipelineConfig:
|
||||||
|
"""Tests for coerce_pipeline_config function"""
|
||||||
|
|
||||||
|
def _make_meta(self, section_name: str, stage_name: str, fields: list[dict]) -> dict:
|
||||||
|
return {
|
||||||
|
'name': section_name,
|
||||||
|
'stages': [{'name': stage_name, 'config': fields}],
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_coerce_integer_in_config(self):
|
||||||
|
config = {'trigger': {'misc': {'timeout': '120'}}}
|
||||||
|
meta = self._make_meta('trigger', 'misc', [{'name': 'timeout', 'type': 'integer'}])
|
||||||
|
coerce_pipeline_config(config, meta)
|
||||||
|
assert config['trigger']['misc']['timeout'] == 120
|
||||||
|
|
||||||
|
def test_coerce_boolean_in_config(self):
|
||||||
|
config = {'output': {'misc': {'at-sender': 'true'}}}
|
||||||
|
meta = self._make_meta('output', 'misc', [{'name': 'at-sender', 'type': 'boolean'}])
|
||||||
|
coerce_pipeline_config(config, meta)
|
||||||
|
assert config['output']['misc']['at-sender'] is True
|
||||||
|
|
||||||
|
def test_missing_section_skipped(self):
|
||||||
|
config = {'ai': {}}
|
||||||
|
meta = self._make_meta('trigger', 'misc', [{'name': 'x', 'type': 'integer'}])
|
||||||
|
coerce_pipeline_config(config, meta) # should not raise
|
||||||
|
|
||||||
|
def test_missing_field_skipped(self):
|
||||||
|
config = {'trigger': {'misc': {}}}
|
||||||
|
meta = self._make_meta('trigger', 'misc', [{'name': 'nonexistent', 'type': 'integer'}])
|
||||||
|
coerce_pipeline_config(config, meta) # should not raise
|
||||||
|
|
||||||
|
def test_invalid_value_logs_warning(self, caplog):
|
||||||
|
config = {'trigger': {'misc': {'timeout': 'abc'}}}
|
||||||
|
meta = self._make_meta('trigger', 'misc', [{'name': 'timeout', 'type': 'integer'}])
|
||||||
|
import logging
|
||||||
|
|
||||||
|
with caplog.at_level(logging.WARNING):
|
||||||
|
coerce_pipeline_config(config, meta)
|
||||||
|
assert config['trigger']['misc']['timeout'] == 'abc' # unchanged
|
||||||
|
assert 'Failed to coerce' in caplog.text
|
||||||
|
|
||||||
|
def test_empty_metadata(self):
|
||||||
|
config = {'trigger': {'misc': {'timeout': '120'}}}
|
||||||
|
coerce_pipeline_config(config) # no metadata args, should not raise
|
||||||
|
|
||||||
|
def test_multiple_metadata(self):
|
||||||
|
config = {
|
||||||
|
'trigger': {'misc': {'timeout': '120'}},
|
||||||
|
'output': {'misc': {'at-sender': 'false'}},
|
||||||
|
}
|
||||||
|
meta_trigger = self._make_meta('trigger', 'misc', [{'name': 'timeout', 'type': 'integer'}])
|
||||||
|
meta_output = self._make_meta('output', 'misc', [{'name': 'at-sender', 'type': 'boolean'}])
|
||||||
|
coerce_pipeline_config(config, meta_trigger, meta_output)
|
||||||
|
assert config['trigger']['misc']['timeout'] == 120
|
||||||
|
assert config['output']['misc']['at-sender'] is False
|
||||||
@@ -38,13 +38,11 @@ async def test_plugin_list_filter_by_component_kinds():
|
|||||||
'manifest': {
|
'manifest': {
|
||||||
'metadata': {
|
'metadata': {
|
||||||
'author': 'author2',
|
'author': 'author2',
|
||||||
'name': 'plugin_with_knowledge_retriever_only',
|
'name': 'plugin_with_knowledge_engine_only',
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
'components': [
|
'components': [{'manifest': {'manifest': {'kind': 'KnowledgeEngine', 'metadata': {'name': 'retriever1'}}}}],
|
||||||
{'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever1'}}}}
|
|
||||||
],
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'debug': False,
|
'debug': False,
|
||||||
@@ -81,7 +79,7 @@ async def test_plugin_list_filter_by_component_kinds():
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
'components': [
|
'components': [
|
||||||
{'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever2'}}}},
|
{'manifest': {'manifest': {'kind': 'KnowledgeEngine', 'metadata': {'name': 'retriever2'}}}},
|
||||||
{'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool2'}}}},
|
{'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool2'}}}},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
@@ -108,8 +106,8 @@ async def test_plugin_list_filter_by_component_kinds():
|
|||||||
assert 'plugin_with_command' in plugin_names
|
assert 'plugin_with_command' in plugin_names
|
||||||
assert 'plugin_with_event_listener' in plugin_names
|
assert 'plugin_with_event_listener' in plugin_names
|
||||||
assert 'plugin_with_mixed_components' in plugin_names
|
assert 'plugin_with_mixed_components' in plugin_names
|
||||||
# Plugin with only KnowledgeRetriever should NOT be included
|
# Plugin with only KnowledgeEngine should NOT be included
|
||||||
assert 'plugin_with_knowledge_retriever_only' not in plugin_names
|
assert 'plugin_with_knowledge_engine_only' not in plugin_names
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -150,9 +148,7 @@ async def test_plugin_list_filter_no_filter():
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
'components': [
|
'components': [{'manifest': {'manifest': {'kind': 'KnowledgeEngine', 'metadata': {'name': 'retriever1'}}}}],
|
||||||
{'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever1'}}}}
|
|
||||||
],
|
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -189,7 +185,7 @@ async def test_plugin_list_filter_empty_result():
|
|||||||
connector = PluginRuntimeConnector(mock_app, AsyncMock())
|
connector = PluginRuntimeConnector(mock_app, AsyncMock())
|
||||||
connector.handler = MagicMock()
|
connector.handler = MagicMock()
|
||||||
|
|
||||||
# Mock plugin data - only KnowledgeRetriever plugins
|
# Mock plugin data - only KnowledgeEngine plugins
|
||||||
mock_plugins = [
|
mock_plugins = [
|
||||||
{
|
{
|
||||||
'debug': False,
|
'debug': False,
|
||||||
@@ -201,9 +197,7 @@ async def test_plugin_list_filter_empty_result():
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
'components': [
|
'components': [{'manifest': {'manifest': {'kind': 'KnowledgeEngine', 'metadata': {'name': 'retriever1'}}}}],
|
||||||
{'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever1'}}}}
|
|
||||||
],
|
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user