mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-05 05:16:03 +00:00
Compare commits
13 Commits
feat/agent
...
v4.9.2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6896a55485 | ||
|
|
4b0fad233e | ||
|
|
52eb991a70 | ||
|
|
10c716be0c | ||
|
|
6e77351eda | ||
|
|
20f5ebd9b8 | ||
|
|
d2c75329cf | ||
|
|
7e2fe082f0 | ||
|
|
d451b059fd | ||
|
|
93c52fcd4c | ||
|
|
f1608682e6 | ||
|
|
077e631c13 | ||
|
|
d7df1f05d1 |
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "langbot"
|
name = "langbot"
|
||||||
version = "4.9.0"
|
version = "4.9.2"
|
||||||
description = "Production-grade platform for building agentic IM bots"
|
description = "Production-grade platform for building agentic IM bots"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license-files = ["LICENSE"]
|
license-files = ["LICENSE"]
|
||||||
@@ -64,7 +64,7 @@ dependencies = [
|
|||||||
"chromadb>=1.0.0,<2.0.0",
|
"chromadb>=1.0.0,<2.0.0",
|
||||||
"qdrant-client (>=1.15.1,<2.0.0)",
|
"qdrant-client (>=1.15.1,<2.0.0)",
|
||||||
"pyseekdb==1.1.0.post3",
|
"pyseekdb==1.1.0.post3",
|
||||||
"langbot-plugin==0.3.0",
|
"langbot-plugin==0.3.1",
|
||||||
"asyncpg>=0.30.0",
|
"asyncpg>=0.30.0",
|
||||||
"line-bot-sdk>=3.19.0",
|
"line-bot-sdk>=3.19.0",
|
||||||
"tboxsdk>=0.0.10",
|
"tboxsdk>=0.0.10",
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
"""LangBot - Production-grade platform for building agentic IM bots"""
|
"""LangBot - Production-grade platform for building agentic IM bots"""
|
||||||
|
|
||||||
__version__ = '4.9.0'
|
__version__ = '4.9.2'
|
||||||
|
|||||||
@@ -199,6 +199,253 @@ class StreamSessionManager:
|
|||||||
self._msg_index.pop(msg_id, None)
|
self._msg_index.pop(msg_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
async def download_encrypted_file(download_url: str, encoding_aes_key: str, logger: EventLogger) -> Optional[str]:
|
||||||
|
"""Download an AES-encrypted file from WeChat Work and return as data URI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
download_url: The encrypted file download URL.
|
||||||
|
encoding_aes_key: The AES key used for decryption (base64-encoded, without trailing '=').
|
||||||
|
logger: Logger instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A data URI string (e.g. 'data:image/jpeg;base64,...') or None on failure.
|
||||||
|
"""
|
||||||
|
if not download_url:
|
||||||
|
return None
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(download_url)
|
||||||
|
if response.status_code != 200:
|
||||||
|
await logger.error(f'failed to get file: {response.text}')
|
||||||
|
return None
|
||||||
|
encrypted_bytes = response.content
|
||||||
|
|
||||||
|
aes_key = base64.b64decode(encoding_aes_key + '=')
|
||||||
|
iv = aes_key[:16]
|
||||||
|
|
||||||
|
cipher = AES.new(aes_key, AES.MODE_CBC, iv)
|
||||||
|
decrypted = cipher.decrypt(encrypted_bytes)
|
||||||
|
|
||||||
|
pad_len = decrypted[-1]
|
||||||
|
decrypted = decrypted[:-pad_len]
|
||||||
|
|
||||||
|
if decrypted.startswith(b'\xff\xd8'):
|
||||||
|
mime_type = 'image/jpeg'
|
||||||
|
elif decrypted.startswith(b'\x89PNG'):
|
||||||
|
mime_type = 'image/png'
|
||||||
|
elif decrypted.startswith((b'GIF87a', b'GIF89a')):
|
||||||
|
mime_type = 'image/gif'
|
||||||
|
elif decrypted.startswith(b'BM'):
|
||||||
|
mime_type = 'image/bmp'
|
||||||
|
elif decrypted.startswith(b'II*\x00') or decrypted.startswith(b'MM\x00*'):
|
||||||
|
mime_type = 'image/tiff'
|
||||||
|
else:
|
||||||
|
mime_type = 'application/octet-stream'
|
||||||
|
|
||||||
|
base64_str = base64.b64encode(decrypted).decode('utf-8')
|
||||||
|
return f'data:{mime_type};base64,{base64_str}'
|
||||||
|
|
||||||
|
|
||||||
|
async def parse_wecom_bot_message(
|
||||||
|
msg_json: dict[str, Any], encoding_aes_key: str, logger: EventLogger
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Parse a decrypted WeChat Work AI Bot message JSON into a unified message dict.
|
||||||
|
|
||||||
|
This is the shared message parsing logic used by both webhook and WebSocket modes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
msg_json: The decrypted message JSON from WeChat Work.
|
||||||
|
encoding_aes_key: AES key for file decryption.
|
||||||
|
logger: Logger instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dict suitable for constructing a WecomBotEvent.
|
||||||
|
"""
|
||||||
|
message_data: dict[str, Any] = {}
|
||||||
|
|
||||||
|
msg_type = msg_json.get('msgtype', '')
|
||||||
|
if msg_type:
|
||||||
|
message_data['msgtype'] = msg_type
|
||||||
|
|
||||||
|
if msg_json.get('chattype', '') == 'single':
|
||||||
|
message_data['type'] = 'single'
|
||||||
|
elif msg_json.get('chattype', '') == 'group':
|
||||||
|
message_data['type'] = 'group'
|
||||||
|
|
||||||
|
max_inline_file_size = 5 * 1024 * 1024
|
||||||
|
|
||||||
|
async def _safe_download(url: str):
|
||||||
|
if not url:
|
||||||
|
return None
|
||||||
|
return await download_encrypted_file(url, encoding_aes_key, logger)
|
||||||
|
|
||||||
|
if msg_type == 'text':
|
||||||
|
message_data['content'] = msg_json.get('text', {}).get('content')
|
||||||
|
elif msg_type == 'markdown':
|
||||||
|
message_data['content'] = msg_json.get('markdown', {}).get('content') or msg_json.get('text', {}).get(
|
||||||
|
'content', ''
|
||||||
|
)
|
||||||
|
elif msg_type == 'image':
|
||||||
|
picurl = msg_json.get('image', {}).get('url', '')
|
||||||
|
base64_data = await _safe_download(picurl)
|
||||||
|
if base64_data:
|
||||||
|
message_data['picurl'] = base64_data
|
||||||
|
message_data['images'] = [base64_data]
|
||||||
|
elif msg_type == 'voice':
|
||||||
|
voice_info = msg_json.get('voice', {}) or {}
|
||||||
|
download_url = voice_info.get('url')
|
||||||
|
message_data['voice'] = {
|
||||||
|
'url': download_url,
|
||||||
|
'md5sum': voice_info.get('md5sum') or voice_info.get('md5'),
|
||||||
|
'filesize': voice_info.get('filesize') or voice_info.get('size'),
|
||||||
|
'sdkfileid': voice_info.get('sdkfileid') or voice_info.get('fileid'),
|
||||||
|
}
|
||||||
|
if voice_info.get('content'):
|
||||||
|
message_data['content'] = voice_info.get('content')
|
||||||
|
if (message_data['voice'].get('filesize') or 0) <= max_inline_file_size:
|
||||||
|
voice_base64 = await _safe_download(download_url)
|
||||||
|
if voice_base64:
|
||||||
|
message_data['voice']['base64'] = voice_base64
|
||||||
|
elif msg_type == 'video':
|
||||||
|
video_info = msg_json.get('video', {}) or {}
|
||||||
|
download_url = video_info.get('url')
|
||||||
|
video_data = {
|
||||||
|
'url': download_url,
|
||||||
|
'filesize': video_info.get('filesize') or video_info.get('size'),
|
||||||
|
'sdkfileid': video_info.get('sdkfileid') or video_info.get('fileid'),
|
||||||
|
'md5sum': video_info.get('md5sum') or video_info.get('md5'),
|
||||||
|
'filename': video_info.get('filename') or video_info.get('name'),
|
||||||
|
}
|
||||||
|
if (video_data.get('filesize') or 0) <= max_inline_file_size:
|
||||||
|
video_base64 = await _safe_download(download_url)
|
||||||
|
if video_base64:
|
||||||
|
video_data['base64'] = video_base64
|
||||||
|
message_data['video'] = video_data
|
||||||
|
elif msg_type == 'file':
|
||||||
|
file_info = msg_json.get('file', {}) or {}
|
||||||
|
download_url = file_info.get('url') or file_info.get('fileurl')
|
||||||
|
file_data = {
|
||||||
|
'filename': file_info.get('filename') or file_info.get('name'),
|
||||||
|
'filesize': file_info.get('filesize') or file_info.get('size'),
|
||||||
|
'md5sum': file_info.get('md5sum') or file_info.get('md5'),
|
||||||
|
'sdkfileid': file_info.get('sdkfileid') or file_info.get('fileid'),
|
||||||
|
'download_url': download_url,
|
||||||
|
'extra': file_info,
|
||||||
|
}
|
||||||
|
if (file_data.get('filesize') or 0) <= max_inline_file_size:
|
||||||
|
file_base64 = await _safe_download(download_url)
|
||||||
|
if file_base64:
|
||||||
|
file_data['base64'] = file_base64
|
||||||
|
message_data['file'] = file_data
|
||||||
|
elif msg_type == 'link':
|
||||||
|
message_data['link'] = msg_json.get('link', {})
|
||||||
|
if not message_data.get('content'):
|
||||||
|
title = message_data['link'].get('title', '')
|
||||||
|
desc = message_data['link'].get('description') or message_data['link'].get('digest', '')
|
||||||
|
message_data['content'] = '\n'.join(filter(None, [title, desc]))
|
||||||
|
elif msg_type == 'mixed':
|
||||||
|
items = msg_json.get('mixed', {}).get('msg_item', [])
|
||||||
|
texts = []
|
||||||
|
images = []
|
||||||
|
files = []
|
||||||
|
voices = []
|
||||||
|
videos = []
|
||||||
|
links = []
|
||||||
|
for item in items:
|
||||||
|
item_type = item.get('msgtype')
|
||||||
|
if item_type == 'text':
|
||||||
|
texts.append(item.get('text', {}).get('content', ''))
|
||||||
|
elif item_type == 'image':
|
||||||
|
img_url = item.get('image', {}).get('url')
|
||||||
|
base64_data = await _safe_download(img_url)
|
||||||
|
if base64_data:
|
||||||
|
images.append(base64_data)
|
||||||
|
elif item_type == 'file':
|
||||||
|
file_info = item.get('file', {}) or {}
|
||||||
|
download_url = file_info.get('url') or file_info.get('fileurl')
|
||||||
|
file_data = {
|
||||||
|
'filename': file_info.get('filename') or file_info.get('name'),
|
||||||
|
'filesize': file_info.get('filesize') or file_info.get('size'),
|
||||||
|
'md5sum': file_info.get('md5sum') or file_info.get('md5'),
|
||||||
|
'sdkfileid': file_info.get('sdkfileid') or file_info.get('fileid'),
|
||||||
|
'download_url': download_url,
|
||||||
|
'extra': file_info,
|
||||||
|
}
|
||||||
|
if (file_data.get('filesize') or 0) <= max_inline_file_size:
|
||||||
|
file_base64 = await _safe_download(download_url)
|
||||||
|
if file_base64:
|
||||||
|
file_data['base64'] = file_base64
|
||||||
|
files.append(file_data)
|
||||||
|
elif item_type == 'voice':
|
||||||
|
voice_info = item.get('voice', {}) or {}
|
||||||
|
download_url = voice_info.get('url')
|
||||||
|
voice_data = {
|
||||||
|
'url': download_url,
|
||||||
|
'md5sum': voice_info.get('md5sum') or voice_info.get('md5'),
|
||||||
|
'filesize': voice_info.get('filesize') or voice_info.get('size'),
|
||||||
|
'sdkfileid': voice_info.get('sdkfileid') or voice_info.get('fileid'),
|
||||||
|
}
|
||||||
|
if voice_info.get('content'):
|
||||||
|
texts.append(voice_info.get('content'))
|
||||||
|
if (voice_data.get('filesize') or 0) <= max_inline_file_size:
|
||||||
|
voice_base64 = await _safe_download(download_url)
|
||||||
|
if voice_base64:
|
||||||
|
voice_data['base64'] = voice_base64
|
||||||
|
voices.append(voice_data)
|
||||||
|
elif item_type == 'video':
|
||||||
|
video_info = item.get('video', {}) or {}
|
||||||
|
download_url = video_info.get('url')
|
||||||
|
video_data = {
|
||||||
|
'url': download_url,
|
||||||
|
'filesize': video_info.get('filesize') or video_info.get('size'),
|
||||||
|
'sdkfileid': video_info.get('sdkfileid') or video_info.get('fileid'),
|
||||||
|
'md5sum': video_info.get('md5sum') or video_info.get('md5'),
|
||||||
|
'filename': video_info.get('filename') or video_info.get('name'),
|
||||||
|
}
|
||||||
|
if (video_data.get('filesize') or 0) <= max_inline_file_size:
|
||||||
|
video_base64 = await _safe_download(download_url)
|
||||||
|
if video_base64:
|
||||||
|
video_data['base64'] = video_base64
|
||||||
|
videos.append(video_data)
|
||||||
|
elif item_type == 'link':
|
||||||
|
links.append(item.get('link', {}))
|
||||||
|
|
||||||
|
if texts:
|
||||||
|
message_data['content'] = ' '.join(texts)
|
||||||
|
if images:
|
||||||
|
message_data['images'] = images
|
||||||
|
message_data['picurl'] = images[0]
|
||||||
|
if files:
|
||||||
|
message_data['files'] = files
|
||||||
|
message_data['file'] = files[0]
|
||||||
|
if voices:
|
||||||
|
message_data['voices'] = voices
|
||||||
|
message_data['voice'] = voices[0]
|
||||||
|
if videos:
|
||||||
|
message_data['videos'] = videos
|
||||||
|
message_data['video'] = videos[0]
|
||||||
|
if links:
|
||||||
|
message_data['link'] = links[0]
|
||||||
|
if items:
|
||||||
|
message_data['attachments'] = items
|
||||||
|
else:
|
||||||
|
message_data['raw_msg'] = msg_json
|
||||||
|
|
||||||
|
from_info = msg_json.get('from', {})
|
||||||
|
message_data['userid'] = from_info.get('userid', '')
|
||||||
|
message_data['username'] = from_info.get('alias', '') or from_info.get('name', '') or from_info.get('userid', '')
|
||||||
|
|
||||||
|
if msg_json.get('chattype', '') == 'group':
|
||||||
|
message_data['chatid'] = msg_json.get('chatid', '')
|
||||||
|
message_data['chatname'] = msg_json.get('chatname', '') or msg_json.get('chatid', '')
|
||||||
|
|
||||||
|
message_data['msgid'] = msg_json.get('msgid', '')
|
||||||
|
|
||||||
|
if msg_json.get('aibotid'):
|
||||||
|
message_data['aibotid'] = msg_json.get('aibotid', '')
|
||||||
|
|
||||||
|
return message_data
|
||||||
|
|
||||||
|
|
||||||
class WecomBotClient:
|
class WecomBotClient:
|
||||||
def __init__(self, Token: str, EnCodingAESKey: str, Corpid: str, logger: EventLogger, unified_mode: bool = False):
|
def __init__(self, Token: str, EnCodingAESKey: str, Corpid: str, logger: EventLogger, unified_mode: bool = False):
|
||||||
"""企业微信智能机器人客户端。
|
"""企业微信智能机器人客户端。
|
||||||
@@ -455,196 +702,7 @@ class WecomBotClient:
|
|||||||
return await self._handle_post_initial_response(msg_json, nonce)
|
return await self._handle_post_initial_response(msg_json, nonce)
|
||||||
|
|
||||||
async def get_message(self, msg_json):
|
async def get_message(self, msg_json):
|
||||||
message_data = {}
|
return await parse_wecom_bot_message(msg_json, self.EnCodingAESKey, self.logger)
|
||||||
|
|
||||||
msg_type = msg_json.get('msgtype', '')
|
|
||||||
if msg_type:
|
|
||||||
message_data['msgtype'] = msg_type
|
|
||||||
|
|
||||||
if msg_json.get('chattype', '') == 'single':
|
|
||||||
message_data['type'] = 'single'
|
|
||||||
elif msg_json.get('chattype', '') == 'group':
|
|
||||||
message_data['type'] = 'group'
|
|
||||||
|
|
||||||
max_inline_file_size = 5 * 1024 * 1024 # avoid decoding very large payloads by default
|
|
||||||
|
|
||||||
async def _safe_download(url: str):
|
|
||||||
if not url:
|
|
||||||
return None
|
|
||||||
return await self.download_url_to_base64(url, self.EnCodingAESKey)
|
|
||||||
|
|
||||||
if msg_type == 'text':
|
|
||||||
message_data['content'] = msg_json.get('text', {}).get('content')
|
|
||||||
elif msg_type == 'markdown':
|
|
||||||
message_data['content'] = msg_json.get('markdown', {}).get('content') or msg_json.get('text', {}).get(
|
|
||||||
'content', ''
|
|
||||||
)
|
|
||||||
elif msg_type == 'image':
|
|
||||||
picurl = msg_json.get('image', {}).get('url', '')
|
|
||||||
base64_data = await _safe_download(picurl)
|
|
||||||
if base64_data:
|
|
||||||
message_data['picurl'] = base64_data
|
|
||||||
message_data['images'] = [base64_data]
|
|
||||||
elif msg_type == 'voice':
|
|
||||||
voice_info = msg_json.get('voice', {}) or {}
|
|
||||||
download_url = voice_info.get('url')
|
|
||||||
message_data['voice'] = {
|
|
||||||
'url': download_url,
|
|
||||||
'md5sum': voice_info.get('md5sum') or voice_info.get('md5'),
|
|
||||||
'filesize': voice_info.get('filesize') or voice_info.get('size'),
|
|
||||||
'sdkfileid': voice_info.get('sdkfileid') or voice_info.get('fileid'),
|
|
||||||
}
|
|
||||||
# 企业微信智能转写文本(如果已有)直接复用,避免重复转写
|
|
||||||
if voice_info.get('content'):
|
|
||||||
message_data['content'] = voice_info.get('content')
|
|
||||||
if (message_data['voice'].get('filesize') or 0) <= max_inline_file_size:
|
|
||||||
voice_base64 = await _safe_download(download_url)
|
|
||||||
if voice_base64:
|
|
||||||
message_data['voice']['base64'] = voice_base64
|
|
||||||
elif msg_type == 'video':
|
|
||||||
video_info = msg_json.get('video', {}) or {}
|
|
||||||
download_url = video_info.get('url')
|
|
||||||
video_data = {
|
|
||||||
'url': download_url,
|
|
||||||
'filesize': video_info.get('filesize') or video_info.get('size'),
|
|
||||||
'sdkfileid': video_info.get('sdkfileid') or video_info.get('fileid'),
|
|
||||||
'md5sum': video_info.get('md5sum') or video_info.get('md5'),
|
|
||||||
'filename': video_info.get('filename') or video_info.get('name'),
|
|
||||||
}
|
|
||||||
if (video_data.get('filesize') or 0) <= max_inline_file_size:
|
|
||||||
video_base64 = await _safe_download(download_url)
|
|
||||||
if video_base64:
|
|
||||||
video_data['base64'] = video_base64
|
|
||||||
message_data['video'] = video_data
|
|
||||||
elif msg_type == 'file':
|
|
||||||
file_info = msg_json.get('file', {}) or {}
|
|
||||||
download_url = file_info.get('url') or file_info.get('fileurl')
|
|
||||||
file_data = {
|
|
||||||
'filename': file_info.get('filename') or file_info.get('name'),
|
|
||||||
'filesize': file_info.get('filesize') or file_info.get('size'),
|
|
||||||
'md5sum': file_info.get('md5sum') or file_info.get('md5'),
|
|
||||||
'sdkfileid': file_info.get('sdkfileid') or file_info.get('fileid'),
|
|
||||||
'download_url': download_url,
|
|
||||||
'extra': file_info,
|
|
||||||
}
|
|
||||||
if (file_data.get('filesize') or 0) <= max_inline_file_size:
|
|
||||||
file_base64 = await _safe_download(download_url)
|
|
||||||
if file_base64:
|
|
||||||
file_data['base64'] = file_base64
|
|
||||||
message_data['file'] = file_data
|
|
||||||
elif msg_type == 'link':
|
|
||||||
message_data['link'] = msg_json.get('link', {})
|
|
||||||
if not message_data.get('content'):
|
|
||||||
title = message_data['link'].get('title', '')
|
|
||||||
desc = message_data['link'].get('description') or message_data['link'].get('digest', '')
|
|
||||||
message_data['content'] = '\n'.join(filter(None, [title, desc]))
|
|
||||||
elif msg_type == 'mixed':
|
|
||||||
items = msg_json.get('mixed', {}).get('msg_item', [])
|
|
||||||
texts = []
|
|
||||||
images = []
|
|
||||||
files = []
|
|
||||||
voices = []
|
|
||||||
videos = []
|
|
||||||
links = []
|
|
||||||
for item in items:
|
|
||||||
item_type = item.get('msgtype')
|
|
||||||
if item_type == 'text':
|
|
||||||
texts.append(item.get('text', {}).get('content', ''))
|
|
||||||
elif item_type == 'image':
|
|
||||||
img_url = item.get('image', {}).get('url')
|
|
||||||
base64_data = await _safe_download(img_url)
|
|
||||||
if base64_data:
|
|
||||||
images.append(base64_data)
|
|
||||||
elif item_type == 'file':
|
|
||||||
file_info = item.get('file', {}) or {}
|
|
||||||
download_url = file_info.get('url') or file_info.get('fileurl')
|
|
||||||
file_data = {
|
|
||||||
'filename': file_info.get('filename') or file_info.get('name'),
|
|
||||||
'filesize': file_info.get('filesize') or file_info.get('size'),
|
|
||||||
'md5sum': file_info.get('md5sum') or file_info.get('md5'),
|
|
||||||
'sdkfileid': file_info.get('sdkfileid') or file_info.get('fileid'),
|
|
||||||
'download_url': download_url,
|
|
||||||
'extra': file_info,
|
|
||||||
}
|
|
||||||
if (file_data.get('filesize') or 0) <= max_inline_file_size:
|
|
||||||
file_base64 = await _safe_download(download_url)
|
|
||||||
if file_base64:
|
|
||||||
file_data['base64'] = file_base64
|
|
||||||
files.append(file_data)
|
|
||||||
elif item_type == 'voice':
|
|
||||||
voice_info = item.get('voice', {}) or {}
|
|
||||||
download_url = voice_info.get('url')
|
|
||||||
voice_data = {
|
|
||||||
'url': download_url,
|
|
||||||
'md5sum': voice_info.get('md5sum') or voice_info.get('md5'),
|
|
||||||
'filesize': voice_info.get('filesize') or voice_info.get('size'),
|
|
||||||
'sdkfileid': voice_info.get('sdkfileid') or voice_info.get('fileid'),
|
|
||||||
}
|
|
||||||
if voice_info.get('content'):
|
|
||||||
texts.append(voice_info.get('content'))
|
|
||||||
if (voice_data.get('filesize') or 0) <= max_inline_file_size:
|
|
||||||
voice_base64 = await _safe_download(download_url)
|
|
||||||
if voice_base64:
|
|
||||||
voice_data['base64'] = voice_base64
|
|
||||||
voices.append(voice_data)
|
|
||||||
elif item_type == 'video':
|
|
||||||
video_info = item.get('video', {}) or {}
|
|
||||||
download_url = video_info.get('url')
|
|
||||||
video_data = {
|
|
||||||
'url': download_url,
|
|
||||||
'filesize': video_info.get('filesize') or video_info.get('size'),
|
|
||||||
'sdkfileid': video_info.get('sdkfileid') or video_info.get('fileid'),
|
|
||||||
'md5sum': video_info.get('md5sum') or video_info.get('md5'),
|
|
||||||
'filename': video_info.get('filename') or video_info.get('name'),
|
|
||||||
}
|
|
||||||
if (video_data.get('filesize') or 0) <= max_inline_file_size:
|
|
||||||
video_base64 = await _safe_download(download_url)
|
|
||||||
if video_base64:
|
|
||||||
video_data['base64'] = video_base64
|
|
||||||
videos.append(video_data)
|
|
||||||
elif item_type == 'link':
|
|
||||||
links.append(item.get('link', {}))
|
|
||||||
|
|
||||||
if texts:
|
|
||||||
message_data['content'] = ' '.join(texts) # 拼接所有 text
|
|
||||||
if images:
|
|
||||||
message_data['images'] = images
|
|
||||||
message_data['picurl'] = images[0] # 只保留第一个 image
|
|
||||||
if files:
|
|
||||||
message_data['files'] = files
|
|
||||||
message_data['file'] = files[0]
|
|
||||||
if voices:
|
|
||||||
message_data['voices'] = voices
|
|
||||||
message_data['voice'] = voices[0]
|
|
||||||
if videos:
|
|
||||||
message_data['videos'] = videos
|
|
||||||
message_data['video'] = videos[0]
|
|
||||||
if links:
|
|
||||||
message_data['link'] = links[0]
|
|
||||||
if items:
|
|
||||||
message_data['attachments'] = items
|
|
||||||
else:
|
|
||||||
message_data['raw_msg'] = msg_json
|
|
||||||
|
|
||||||
# Extract user information
|
|
||||||
from_info = msg_json.get('from', {})
|
|
||||||
message_data['userid'] = from_info.get('userid', '')
|
|
||||||
message_data['username'] = (
|
|
||||||
from_info.get('alias', '') or from_info.get('name', '') or from_info.get('userid', '')
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract chat/group information
|
|
||||||
if msg_json.get('chattype', '') == 'group':
|
|
||||||
message_data['chatid'] = msg_json.get('chatid', '')
|
|
||||||
# Try to get group name if available
|
|
||||||
message_data['chatname'] = msg_json.get('chatname', '') or msg_json.get('chatid', '')
|
|
||||||
|
|
||||||
message_data['msgid'] = msg_json.get('msgid', '')
|
|
||||||
|
|
||||||
if msg_json.get('aibotid'):
|
|
||||||
message_data['aibotid'] = msg_json.get('aibotid', '')
|
|
||||||
|
|
||||||
return message_data
|
|
||||||
|
|
||||||
async def _handle_message(self, event: wecombotevent.WecomBotEvent):
|
async def _handle_message(self, event: wecombotevent.WecomBotEvent):
|
||||||
"""
|
"""
|
||||||
@@ -712,39 +770,7 @@ class WecomBotClient:
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
async def download_url_to_base64(self, download_url, encoding_aes_key):
|
async def download_url_to_base64(self, download_url, encoding_aes_key):
|
||||||
async with httpx.AsyncClient() as client:
|
return await download_encrypted_file(download_url, encoding_aes_key, self.logger)
|
||||||
response = await client.get(download_url)
|
|
||||||
if response.status_code != 200:
|
|
||||||
await self.logger.error(f'failed to get file: {response.text}')
|
|
||||||
return None
|
|
||||||
|
|
||||||
encrypted_bytes = response.content
|
|
||||||
|
|
||||||
aes_key = base64.b64decode(encoding_aes_key + '=') # base64 补齐
|
|
||||||
iv = aes_key[:16]
|
|
||||||
|
|
||||||
cipher = AES.new(aes_key, AES.MODE_CBC, iv)
|
|
||||||
decrypted = cipher.decrypt(encrypted_bytes)
|
|
||||||
|
|
||||||
pad_len = decrypted[-1]
|
|
||||||
decrypted = decrypted[:-pad_len]
|
|
||||||
|
|
||||||
if decrypted.startswith(b'\xff\xd8'): # JPEG
|
|
||||||
mime_type = 'image/jpeg'
|
|
||||||
elif decrypted.startswith(b'\x89PNG'): # PNG
|
|
||||||
mime_type = 'image/png'
|
|
||||||
elif decrypted.startswith((b'GIF87a', b'GIF89a')): # GIF
|
|
||||||
mime_type = 'image/gif'
|
|
||||||
elif decrypted.startswith(b'BM'): # BMP
|
|
||||||
mime_type = 'image/bmp'
|
|
||||||
elif decrypted.startswith(b'II*\x00') or decrypted.startswith(b'MM\x00*'): # TIFF
|
|
||||||
mime_type = 'image/tiff'
|
|
||||||
else:
|
|
||||||
mime_type = 'application/octet-stream'
|
|
||||||
|
|
||||||
# 转 base64
|
|
||||||
base64_str = base64.b64encode(decrypted).decode('utf-8')
|
|
||||||
return f'data:{mime_type};base64,{base64_str}'
|
|
||||||
|
|
||||||
async def run_task(self, host: str, port: int, *args, **kwargs):
|
async def run_task(self, host: str, port: int, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
|||||||
596
src/langbot/libs/wecom_ai_bot_api/ws_client.py
Normal file
596
src/langbot/libs/wecom_ai_bot_api/ws_client.py
Normal file
@@ -0,0 +1,596 @@
|
|||||||
|
"""WeChat Work AI Bot WebSocket long connection client.
|
||||||
|
|
||||||
|
Implements the WebSocket protocol for receiving messages and sending replies
|
||||||
|
via a persistent connection to wss://openws.work.weixin.qq.com, as an
|
||||||
|
alternative to the HTTP callback (webhook) mode.
|
||||||
|
|
||||||
|
Protocol reference: https://developer.work.weixin.qq.com/document/path/101463
|
||||||
|
Official Node.js SDK: https://github.com/WecomTeam/aibot-node-sdk
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import secrets
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from langbot.libs.wecom_ai_bot_api import wecombotevent
|
||||||
|
from langbot.libs.wecom_ai_bot_api.api import parse_wecom_bot_message
|
||||||
|
from langbot.pkg.platform.logger import EventLogger
|
||||||
|
|
||||||
|
DEFAULT_WS_URL = 'wss://openws.work.weixin.qq.com'
|
||||||
|
|
||||||
|
# WebSocket frame command constants
|
||||||
|
CMD_SUBSCRIBE = 'aibot_subscribe'
|
||||||
|
CMD_HEARTBEAT = 'ping'
|
||||||
|
CMD_MSG_CALLBACK = 'aibot_msg_callback'
|
||||||
|
CMD_EVENT_CALLBACK = 'aibot_event_callback'
|
||||||
|
CMD_RESPOND_MSG = 'aibot_respond_msg'
|
||||||
|
CMD_RESPOND_WELCOME = 'aibot_respond_welcome_msg'
|
||||||
|
CMD_RESPOND_UPDATE = 'aibot_respond_update_msg'
|
||||||
|
CMD_SEND_MSG = 'aibot_send_msg'
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_req_id(prefix: str) -> str:
|
||||||
|
"""Generate a unique request ID in the format: {prefix}_{timestamp}_{random}."""
|
||||||
|
ts = int(time.time() * 1000)
|
||||||
|
rand = secrets.token_hex(4)
|
||||||
|
return f'{prefix}_{ts}_{rand}'
|
||||||
|
|
||||||
|
|
||||||
|
class WecomBotWsClient:
|
||||||
|
"""WeChat Work AI Bot WebSocket long connection client.
|
||||||
|
|
||||||
|
Provides message receiving, streaming reply, proactive message sending,
|
||||||
|
and event callback handling over a persistent WebSocket connection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
bot_id: str,
|
||||||
|
secret: str,
|
||||||
|
logger: EventLogger,
|
||||||
|
encoding_aes_key: str = '',
|
||||||
|
ws_url: str = DEFAULT_WS_URL,
|
||||||
|
heartbeat_interval: float = 30.0,
|
||||||
|
max_reconnect_attempts: int = -1,
|
||||||
|
reconnect_base_delay: float = 1.0,
|
||||||
|
reconnect_max_delay: float = 30.0,
|
||||||
|
):
|
||||||
|
self.bot_id = bot_id
|
||||||
|
self.secret = secret
|
||||||
|
self.logger = logger
|
||||||
|
self.encoding_aes_key = encoding_aes_key
|
||||||
|
self.ws_url = ws_url
|
||||||
|
self.heartbeat_interval = heartbeat_interval
|
||||||
|
self.max_reconnect_attempts = max_reconnect_attempts
|
||||||
|
self.reconnect_base_delay = reconnect_base_delay
|
||||||
|
self.reconnect_max_delay = reconnect_max_delay
|
||||||
|
|
||||||
|
self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
|
||||||
|
self._session: Optional[aiohttp.ClientSession] = None
|
||||||
|
self._running = False
|
||||||
|
self._heartbeat_task: Optional[asyncio.Task] = None
|
||||||
|
self._missed_pong_count = 0
|
||||||
|
self._max_missed_pong = 2
|
||||||
|
self._reconnect_attempts = 0
|
||||||
|
|
||||||
|
# Message handler registry (same pattern as WecomBotClient)
|
||||||
|
self._message_handlers: dict[str, list[Callable]] = {}
|
||||||
|
# Message deduplication
|
||||||
|
self._msg_id_map: dict[str, int] = {}
|
||||||
|
|
||||||
|
# Pending ACK futures: req_id -> Future[dict]
|
||||||
|
self._pending_acks: dict[str, asyncio.Future] = {}
|
||||||
|
# Per-req_id serial reply queues
|
||||||
|
self._reply_queues: dict[str, asyncio.Queue] = {}
|
||||||
|
self._reply_workers: dict[str, asyncio.Task] = {}
|
||||||
|
self._reply_ack_timeout = 5.0
|
||||||
|
|
||||||
|
# Stream ID tracking for WebSocket mode
|
||||||
|
self._stream_ids: dict[str, str] = {} # msg_id -> req_id|stream_id
|
||||||
|
# Dedup: skip sending when content hasn't changed
|
||||||
|
self._stream_last_content: dict[str, str] = {} # msg_id -> last content sent
|
||||||
|
|
||||||
|
# ── Public API ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def connect(self):
|
||||||
|
"""Connect to WebSocket server with automatic reconnection.
|
||||||
|
|
||||||
|
This method blocks until disconnect() is called or max reconnect
|
||||||
|
attempts are exhausted.
|
||||||
|
"""
|
||||||
|
self._running = True
|
||||||
|
self._reconnect_attempts = 0
|
||||||
|
|
||||||
|
while self._running:
|
||||||
|
try:
|
||||||
|
await self._connect_once()
|
||||||
|
except Exception:
|
||||||
|
if not self._running:
|
||||||
|
break
|
||||||
|
await self.logger.error(f'WebSocket connection error: {traceback.format_exc()}')
|
||||||
|
|
||||||
|
if not self._running:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Reconnect with exponential backoff
|
||||||
|
if self.max_reconnect_attempts != -1 and self._reconnect_attempts >= self.max_reconnect_attempts:
|
||||||
|
await self.logger.error(f'Max reconnect attempts reached ({self.max_reconnect_attempts}), giving up')
|
||||||
|
break
|
||||||
|
|
||||||
|
self._reconnect_attempts += 1
|
||||||
|
delay = min(
|
||||||
|
self.reconnect_base_delay * (2 ** (self._reconnect_attempts - 1)),
|
||||||
|
self.reconnect_max_delay,
|
||||||
|
)
|
||||||
|
await self.logger.info(f'Reconnecting in {delay:.1f}s (attempt {self._reconnect_attempts})...')
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
|
async def disconnect(self):
|
||||||
|
"""Gracefully disconnect from the WebSocket server."""
|
||||||
|
self._running = False
|
||||||
|
if self._heartbeat_task and not self._heartbeat_task.done():
|
||||||
|
self._heartbeat_task.cancel()
|
||||||
|
for task in self._reply_workers.values():
|
||||||
|
if not task.done():
|
||||||
|
task.cancel()
|
||||||
|
if self._ws and not self._ws.closed:
|
||||||
|
await self._ws.close()
|
||||||
|
self._ws = None
|
||||||
|
if self._session and not self._session.closed:
|
||||||
|
await self._session.close()
|
||||||
|
self._session = None
|
||||||
|
|
||||||
|
def on_message(self, msg_type: str) -> Callable:
|
||||||
|
"""Decorator to register a message handler.
|
||||||
|
|
||||||
|
Same interface as WecomBotClient.on_message for compatibility.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
msg_type: 'single', 'group', or specific message type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(func: Callable[[wecombotevent.WecomBotEvent], Any]):
|
||||||
|
if msg_type not in self._message_handlers:
|
||||||
|
self._message_handlers[msg_type] = []
|
||||||
|
self._message_handlers[msg_type].append(func)
|
||||||
|
return func
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
async def reply_stream(
|
||||||
|
self,
|
||||||
|
req_id: str,
|
||||||
|
stream_id: str,
|
||||||
|
content: str,
|
||||||
|
finish: bool = False,
|
||||||
|
) -> Optional[dict]:
|
||||||
|
"""Send a streaming reply frame.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
req_id: The req_id from the original message frame (must be passed through).
|
||||||
|
stream_id: The stream ID for this streaming session.
|
||||||
|
content: The content to send (supports Markdown).
|
||||||
|
finish: Whether this is the final chunk.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The ACK frame dict, or None on failure.
|
||||||
|
"""
|
||||||
|
body = {
|
||||||
|
'msgtype': 'stream',
|
||||||
|
'stream': {
|
||||||
|
'id': stream_id,
|
||||||
|
'finish': finish,
|
||||||
|
'content': content,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return await self._send_reply(req_id, body)
|
||||||
|
|
||||||
|
async def reply_text(self, req_id: str, content: str) -> Optional[dict]:
|
||||||
|
"""Send a non-streaming text reply.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
req_id: The req_id from the original message frame.
|
||||||
|
content: The text content to reply.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The ACK frame dict, or None on failure.
|
||||||
|
"""
|
||||||
|
body = {
|
||||||
|
'msgtype': 'markdown',
|
||||||
|
'markdown': {
|
||||||
|
'content': content,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return await self._send_reply(req_id, body)
|
||||||
|
|
||||||
|
async def send_message(self, chat_id: str, content: str, msgtype: str = 'markdown') -> Optional[dict]:
|
||||||
|
"""Proactively send a message to a specified chat.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_id: The chat ID (userid for single chat, chatid for group chat).
|
||||||
|
content: The message content.
|
||||||
|
msgtype: Message type, 'markdown' by default.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The ACK frame dict, or None on failure.
|
||||||
|
"""
|
||||||
|
req_id = _generate_req_id(CMD_SEND_MSG)
|
||||||
|
body: dict[str, Any] = {
|
||||||
|
'chatid': chat_id,
|
||||||
|
'msgtype': msgtype,
|
||||||
|
}
|
||||||
|
if msgtype == 'markdown':
|
||||||
|
body['markdown'] = {'content': content}
|
||||||
|
elif msgtype == 'text':
|
||||||
|
body['text'] = {'content': content}
|
||||||
|
return await self._send_reply(req_id, body, cmd=CMD_SEND_MSG)
|
||||||
|
|
||||||
|
async def push_stream_chunk(self, msg_id: str, content: str, is_final: bool = False) -> bool:
|
||||||
|
"""Push a streaming chunk for a given message ID.
|
||||||
|
|
||||||
|
Compatible interface with WecomBotClient.push_stream_chunk.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
msg_id: The original message ID.
|
||||||
|
content: The cumulative content from the pipeline.
|
||||||
|
is_final: Whether this is the final chunk.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the stream session exists and chunk was sent.
|
||||||
|
"""
|
||||||
|
key = self._stream_ids.get(msg_id)
|
||||||
|
if not key:
|
||||||
|
return False
|
||||||
|
req_id, stream_id = key.split('|', 1)
|
||||||
|
try:
|
||||||
|
# Skip sending if content hasn't changed (e.g. during tool call argument streaming)
|
||||||
|
if not is_final and content == self._stream_last_content.get(msg_id):
|
||||||
|
return True
|
||||||
|
await self.reply_stream(req_id, stream_id, content, finish=is_final)
|
||||||
|
self._stream_last_content[msg_id] = content
|
||||||
|
if is_final:
|
||||||
|
self._stream_ids.pop(msg_id, None)
|
||||||
|
self._stream_last_content.pop(msg_id, None)
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
await self.logger.error(f'Failed to push stream chunk: {traceback.format_exc()}')
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def set_message(self, msg_id: str, content: str):
|
||||||
|
"""Fallback: send content as a final stream chunk or direct reply.
|
||||||
|
|
||||||
|
Compatible interface with WecomBotClient.set_message.
|
||||||
|
"""
|
||||||
|
handled = await self.push_stream_chunk(msg_id, content, is_final=True)
|
||||||
|
if not handled:
|
||||||
|
await self.logger.warning(f'No active stream for msg_id={msg_id}, message dropped')
|
||||||
|
|
||||||
|
# ── Connection lifecycle ────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _connect_once(self):
|
||||||
|
"""Establish a single WebSocket connection, authenticate, and listen."""
|
||||||
|
await self.logger.info(f'Connecting to {self.ws_url}...')
|
||||||
|
|
||||||
|
self._session = aiohttp.ClientSession()
|
||||||
|
try:
|
||||||
|
self._ws = await self._session.ws_connect(self.ws_url)
|
||||||
|
self._missed_pong_count = 0
|
||||||
|
self._reconnect_attempts = 0
|
||||||
|
await self.logger.info('WebSocket connected, sending auth...')
|
||||||
|
|
||||||
|
await self._send_auth()
|
||||||
|
|
||||||
|
# Wait for auth response
|
||||||
|
auth_ok = await self._wait_for_auth()
|
||||||
|
if not auth_ok:
|
||||||
|
await self.logger.error('Authentication failed')
|
||||||
|
return
|
||||||
|
|
||||||
|
await self.logger.info('Authenticated successfully')
|
||||||
|
|
||||||
|
# Start heartbeat
|
||||||
|
self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._listen_loop()
|
||||||
|
finally:
|
||||||
|
if self._heartbeat_task and not self._heartbeat_task.done():
|
||||||
|
self._heartbeat_task.cancel()
|
||||||
|
self._clear_pending_acks('Connection closed')
|
||||||
|
finally:
|
||||||
|
if self._ws and not self._ws.closed:
|
||||||
|
await self._ws.close()
|
||||||
|
self._ws = None
|
||||||
|
if self._session and not self._session.closed:
|
||||||
|
await self._session.close()
|
||||||
|
self._session = None
|
||||||
|
|
||||||
|
async def _send_auth(self):
|
||||||
|
"""Send the authentication frame."""
|
||||||
|
frame = {
|
||||||
|
'cmd': CMD_SUBSCRIBE,
|
||||||
|
'headers': {'req_id': _generate_req_id(CMD_SUBSCRIBE)},
|
||||||
|
'body': {
|
||||||
|
'bot_id': self.bot_id,
|
||||||
|
'secret': self.secret,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
await self._send_frame(frame)
|
||||||
|
|
||||||
|
async def _wait_for_auth(self) -> bool:
|
||||||
|
"""Wait for and validate the authentication response."""
|
||||||
|
try:
|
||||||
|
msg = await asyncio.wait_for(self._ws.receive(), timeout=10.0)
|
||||||
|
if msg.type in (aiohttp.WSMsgType.TEXT,):
|
||||||
|
frame = json.loads(msg.data)
|
||||||
|
req_id = frame.get('headers', {}).get('req_id', '')
|
||||||
|
if req_id.startswith(CMD_SUBSCRIBE) and frame.get('errcode') == 0:
|
||||||
|
return True
|
||||||
|
await self.logger.error(f'Auth response: errcode={frame.get("errcode")}, errmsg={frame.get("errmsg")}')
|
||||||
|
return False
|
||||||
|
elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
|
||||||
|
await self.logger.error(f'WebSocket closed during auth: {msg.type}')
|
||||||
|
return False
|
||||||
|
await self.logger.error(f'Unexpected message type during auth: {msg.type}')
|
||||||
|
return False
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
await self.logger.error('Auth response timeout')
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _heartbeat_loop(self):
|
||||||
|
"""Periodically send heartbeat pings."""
|
||||||
|
try:
|
||||||
|
while self._running and self._ws and not self._ws.closed:
|
||||||
|
await asyncio.sleep(self.heartbeat_interval)
|
||||||
|
if not self._running or not self._ws or self._ws.closed:
|
||||||
|
break
|
||||||
|
|
||||||
|
if self._missed_pong_count >= self._max_missed_pong:
|
||||||
|
await self.logger.warning(
|
||||||
|
f'No heartbeat ack for {self._missed_pong_count} consecutive pings, connection considered dead'
|
||||||
|
)
|
||||||
|
await self._ws.close()
|
||||||
|
break
|
||||||
|
|
||||||
|
self._missed_pong_count += 1
|
||||||
|
frame = {
|
||||||
|
'cmd': CMD_HEARTBEAT,
|
||||||
|
'headers': {'req_id': _generate_req_id(CMD_HEARTBEAT)},
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
await self._send_frame(frame)
|
||||||
|
except Exception:
|
||||||
|
break
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _listen_loop(self):
|
||||||
|
"""Listen for incoming WebSocket frames and dispatch them."""
|
||||||
|
async for msg in self._ws:
|
||||||
|
if not self._running:
|
||||||
|
break
|
||||||
|
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||||
|
try:
|
||||||
|
frame = json.loads(msg.data)
|
||||||
|
await self._handle_frame(frame)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
await self.logger.error(f'Failed to parse WebSocket message: {str(msg.data)[:200]}')
|
||||||
|
except Exception:
|
||||||
|
await self.logger.error(f'Error handling frame: {traceback.format_exc()}')
|
||||||
|
elif msg.type == aiohttp.WSMsgType.BINARY:
|
||||||
|
try:
|
||||||
|
frame = json.loads(msg.data)
|
||||||
|
await self._handle_frame(frame)
|
||||||
|
except Exception:
|
||||||
|
await self.logger.error(f'Error handling binary frame: {traceback.format_exc()}')
|
||||||
|
elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
|
||||||
|
await self.logger.warning(f'WebSocket connection closed: {msg.type}')
|
||||||
|
break
|
||||||
|
|
||||||
|
# ── Frame handling ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _handle_frame(self, frame: dict):
|
||||||
|
"""Route an incoming frame to the appropriate handler."""
|
||||||
|
cmd = frame.get('cmd', '')
|
||||||
|
|
||||||
|
# Message push
|
||||||
|
if cmd == CMD_MSG_CALLBACK:
|
||||||
|
asyncio.create_task(self._handle_message_callback(frame))
|
||||||
|
return
|
||||||
|
|
||||||
|
# Event push
|
||||||
|
if cmd == CMD_EVENT_CALLBACK:
|
||||||
|
asyncio.create_task(self._handle_event_callback(frame))
|
||||||
|
return
|
||||||
|
|
||||||
|
# No cmd → response/ACK frame, dispatch by req_id prefix
|
||||||
|
req_id = frame.get('headers', {}).get('req_id', '')
|
||||||
|
|
||||||
|
# Check pending ACKs first
|
||||||
|
if req_id in self._pending_acks:
|
||||||
|
future = self._pending_acks.pop(req_id)
|
||||||
|
if not future.done():
|
||||||
|
future.set_result(frame)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Heartbeat response
|
||||||
|
if req_id.startswith(CMD_HEARTBEAT):
|
||||||
|
if frame.get('errcode') == 0:
|
||||||
|
self._missed_pong_count = 0
|
||||||
|
return
|
||||||
|
|
||||||
|
# Unknown frame
|
||||||
|
await self.logger.warning(f'Unknown frame: {json.dumps(frame, ensure_ascii=False)[:200]}')
|
||||||
|
|
||||||
|
async def _handle_message_callback(self, frame: dict):
|
||||||
|
"""Handle an incoming message callback frame."""
|
||||||
|
try:
|
||||||
|
body = frame.get('body', {})
|
||||||
|
req_id = frame.get('headers', {}).get('req_id', '')
|
||||||
|
|
||||||
|
# Parse message using shared logic
|
||||||
|
message_data = await parse_wecom_bot_message(body, self.encoding_aes_key, self.logger)
|
||||||
|
if not message_data:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Generate stream_id for this message and store the mapping
|
||||||
|
stream_id = _generate_req_id('stream')
|
||||||
|
msg_id = message_data.get('msgid', '')
|
||||||
|
if msg_id:
|
||||||
|
self._stream_ids[msg_id] = f'{req_id}|{stream_id}'
|
||||||
|
message_data['stream_id'] = stream_id
|
||||||
|
message_data['req_id'] = req_id
|
||||||
|
|
||||||
|
event = wecombotevent.WecomBotEvent(message_data)
|
||||||
|
await self._dispatch_event(event)
|
||||||
|
except Exception:
|
||||||
|
await self.logger.error(f'Error in message callback: {traceback.format_exc()}')
|
||||||
|
|
||||||
|
async def _handle_event_callback(self, frame: dict):
|
||||||
|
"""Handle an incoming event callback frame (enter_chat, template_card_event, etc.)."""
|
||||||
|
try:
|
||||||
|
body = frame.get('body', {})
|
||||||
|
req_id = frame.get('headers', {}).get('req_id', '')
|
||||||
|
|
||||||
|
event_info = body.get('event', {})
|
||||||
|
event_type = event_info.get('eventtype', '')
|
||||||
|
|
||||||
|
message_data = {
|
||||||
|
'msgtype': 'event',
|
||||||
|
'type': body.get('chattype', 'single'),
|
||||||
|
'event': event_info,
|
||||||
|
'eventtype': event_type,
|
||||||
|
'msgid': body.get('msgid', ''),
|
||||||
|
'aibotid': body.get('aibotid', ''),
|
||||||
|
'req_id': req_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
from_info = body.get('from', {})
|
||||||
|
message_data['userid'] = from_info.get('userid', '')
|
||||||
|
message_data['username'] = from_info.get('alias', '') or from_info.get('userid', '')
|
||||||
|
|
||||||
|
if body.get('chatid'):
|
||||||
|
message_data['chatid'] = body.get('chatid', '')
|
||||||
|
|
||||||
|
event = wecombotevent.WecomBotEvent(message_data)
|
||||||
|
|
||||||
|
# Dispatch to event-specific handlers
|
||||||
|
if event_type in self._message_handlers:
|
||||||
|
for handler in self._message_handlers[event_type]:
|
||||||
|
await handler(event)
|
||||||
|
|
||||||
|
# Also dispatch to generic 'event' handlers
|
||||||
|
if 'event' in self._message_handlers:
|
||||||
|
for handler in self._message_handlers['event']:
|
||||||
|
await handler(event)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
await self.logger.error(f'Error in event callback: {traceback.format_exc()}')
|
||||||
|
|
||||||
|
async def _dispatch_event(self, event: wecombotevent.WecomBotEvent):
|
||||||
|
"""Dispatch a message event to registered handlers with deduplication."""
|
||||||
|
try:
|
||||||
|
message_id = event.message_id
|
||||||
|
if message_id in self._msg_id_map:
|
||||||
|
self._msg_id_map[message_id] += 1
|
||||||
|
return
|
||||||
|
self._msg_id_map[message_id] = 1
|
||||||
|
|
||||||
|
msg_type = event.type
|
||||||
|
if msg_type in self._message_handlers:
|
||||||
|
for handler in self._message_handlers[msg_type]:
|
||||||
|
await handler(event)
|
||||||
|
except Exception:
|
||||||
|
await self.logger.error(f'Error dispatching event: {traceback.format_exc()}')
|
||||||
|
|
||||||
|
# ── Reply sending with serial queue ─────────────────────────────
|
||||||
|
|
||||||
|
async def _send_reply(
|
||||||
|
self,
|
||||||
|
req_id: str,
|
||||||
|
body: dict,
|
||||||
|
cmd: str = CMD_RESPOND_MSG,
|
||||||
|
) -> Optional[dict]:
|
||||||
|
"""Send a reply frame and wait for ACK.
|
||||||
|
|
||||||
|
Replies with the same req_id are serialized to maintain ordering.
|
||||||
|
"""
|
||||||
|
if not self._ws or self._ws.closed:
|
||||||
|
return None
|
||||||
|
|
||||||
|
frame = {
|
||||||
|
'cmd': cmd,
|
||||||
|
'headers': {'req_id': req_id},
|
||||||
|
'body': body,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Ensure serial delivery per req_id
|
||||||
|
if req_id not in self._reply_queues:
|
||||||
|
self._reply_queues[req_id] = asyncio.Queue()
|
||||||
|
self._reply_workers[req_id] = asyncio.create_task(self._reply_queue_worker(req_id))
|
||||||
|
|
||||||
|
future: asyncio.Future = asyncio.get_event_loop().create_future()
|
||||||
|
await self._reply_queues[req_id].put((frame, future))
|
||||||
|
return await future
|
||||||
|
|
||||||
|
async def _reply_queue_worker(self, req_id: str):
|
||||||
|
"""Process reply queue items serially for a given req_id."""
|
||||||
|
queue = self._reply_queues[req_id]
|
||||||
|
try:
|
||||||
|
while self._running:
|
||||||
|
try:
|
||||||
|
frame, future = await asyncio.wait_for(queue.get(), timeout=60.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# Queue idle, clean up worker
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
ack = await self._send_and_wait_ack(frame)
|
||||||
|
if not future.done():
|
||||||
|
future.set_result(ack)
|
||||||
|
except Exception as e:
|
||||||
|
if not future.done():
|
||||||
|
future.set_exception(e)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
self._reply_queues.pop(req_id, None)
|
||||||
|
self._reply_workers.pop(req_id, None)
|
||||||
|
|
||||||
|
async def _send_and_wait_ack(self, frame: dict) -> Optional[dict]:
|
||||||
|
"""Send a frame and wait for the corresponding ACK."""
|
||||||
|
req_id = frame['headers']['req_id']
|
||||||
|
ack_future: asyncio.Future = asyncio.get_event_loop().create_future()
|
||||||
|
self._pending_acks[req_id] = ack_future
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._send_frame(frame)
|
||||||
|
result = await asyncio.wait_for(ack_future, timeout=self._reply_ack_timeout)
|
||||||
|
if result.get('errcode', 0) != 0:
|
||||||
|
await self.logger.warning(
|
||||||
|
f'Reply ACK error: errcode={result.get("errcode")}, errmsg={result.get("errmsg")}'
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self._pending_acks.pop(req_id, None)
|
||||||
|
await self.logger.warning(f'Reply ACK timeout ({self._reply_ack_timeout}s) for req_id={req_id}')
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _send_frame(self, frame: dict):
|
||||||
|
"""Send a JSON frame over the WebSocket connection."""
|
||||||
|
if self._ws and not self._ws.closed:
|
||||||
|
await self._ws.send_str(json.dumps(frame, ensure_ascii=False))
|
||||||
|
|
||||||
|
def _clear_pending_acks(self, reason: str):
|
||||||
|
"""Reject all pending ACK futures on disconnection."""
|
||||||
|
for req_id, future in self._pending_acks.items():
|
||||||
|
if not future.done():
|
||||||
|
future.set_exception(ConnectionError(reason))
|
||||||
|
self._pending_acks.clear()
|
||||||
@@ -70,12 +70,17 @@ class BotService:
|
|||||||
'lark',
|
'lark',
|
||||||
]:
|
]:
|
||||||
webhook_prefix = self.ap.instance_config.data['api'].get('webhook_prefix', 'http://127.0.0.1:5300')
|
webhook_prefix = self.ap.instance_config.data['api'].get('webhook_prefix', 'http://127.0.0.1:5300')
|
||||||
|
extra_webhook_prefix = self.ap.instance_config.data['api'].get('extra_webhook_prefix', '')
|
||||||
webhook_url = f'/bots/{bot_uuid}'
|
webhook_url = f'/bots/{bot_uuid}'
|
||||||
adapter_runtime_values['webhook_url'] = webhook_url
|
adapter_runtime_values['webhook_url'] = webhook_url
|
||||||
adapter_runtime_values['webhook_full_url'] = f'{webhook_prefix}{webhook_url}'
|
adapter_runtime_values['webhook_full_url'] = f'{webhook_prefix}{webhook_url}'
|
||||||
|
adapter_runtime_values['extra_webhook_full_url'] = (
|
||||||
|
f'{extra_webhook_prefix}{webhook_url}' if extra_webhook_prefix else ''
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
adapter_runtime_values['webhook_url'] = None
|
adapter_runtime_values['webhook_url'] = None
|
||||||
adapter_runtime_values['webhook_full_url'] = None
|
adapter_runtime_values['webhook_full_url'] = None
|
||||||
|
adapter_runtime_values['extra_webhook_full_url'] = None
|
||||||
|
|
||||||
persistence_bot['adapter_runtime_values'] = adapter_runtime_values
|
persistence_bot['adapter_runtime_values'] = adapter_runtime_values
|
||||||
|
|
||||||
|
|||||||
@@ -105,11 +105,16 @@ class LLMModelsService:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
pipeline = result.first()
|
pipeline = result.first()
|
||||||
if pipeline is not None and pipeline.config['ai']['local-agent']['model'] == '':
|
if pipeline is not None:
|
||||||
pipeline_config = pipeline.config
|
model_config = pipeline.config.get('ai', {}).get('local-agent', {}).get('model', {})
|
||||||
pipeline_config['ai']['local-agent']['model'] = model_data['uuid']
|
if not model_config.get('primary', ''):
|
||||||
pipeline_data = {'config': pipeline_config}
|
pipeline_config = pipeline.config
|
||||||
await self.ap.pipeline_service.update_pipeline(pipeline.uuid, pipeline_data)
|
pipeline_config['ai']['local-agent']['model'] = {
|
||||||
|
'primary': model_data['uuid'],
|
||||||
|
'fallbacks': [],
|
||||||
|
}
|
||||||
|
pipeline_data = {'config': pipeline_config}
|
||||||
|
await self.ap.pipeline_service.update_pipeline(pipeline.uuid, pipeline_data)
|
||||||
|
|
||||||
return model_data['uuid']
|
return model_data['uuid']
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,49 @@
|
|||||||
|
from .. import migration
|
||||||
|
|
||||||
|
import sqlalchemy
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
@migration.migration_class(24)
|
||||||
|
class DBMigrateWecomBotWebSocketMode(migration.DBMigration):
|
||||||
|
"""Add enable-webhook field to existing wecombot adapter configs.
|
||||||
|
|
||||||
|
Existing wecombot bots were all using webhook mode, so we set
|
||||||
|
enable-webhook=true to preserve their behavior after the new
|
||||||
|
WebSocket long connection mode is introduced as default.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def upgrade(self):
|
||||||
|
"""Upgrade"""
|
||||||
|
result = await self.ap.persistence_mgr.execute_async(
|
||||||
|
sqlalchemy.text("SELECT uuid, adapter_config FROM bots WHERE adapter = 'wecombot'")
|
||||||
|
)
|
||||||
|
bots = result.fetchall()
|
||||||
|
|
||||||
|
for bot_row in bots:
|
||||||
|
bot_uuid = bot_row[0]
|
||||||
|
adapter_config = json.loads(bot_row[1]) if isinstance(bot_row[1], str) else bot_row[1]
|
||||||
|
|
||||||
|
if 'enable-webhook' in adapter_config:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Determine mode based on existing config: if webhook fields are present, keep webhook mode
|
||||||
|
has_webhook_config = bool(
|
||||||
|
adapter_config.get('Token') and adapter_config.get('EncodingAESKey') and adapter_config.get('Corpid')
|
||||||
|
)
|
||||||
|
adapter_config['enable-webhook'] = has_webhook_config
|
||||||
|
|
||||||
|
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||||
|
await self.ap.persistence_mgr.execute_async(
|
||||||
|
sqlalchemy.text('UPDATE bots SET adapter_config = :config::jsonb WHERE uuid = :uuid'),
|
||||||
|
{'config': json.dumps(adapter_config), 'uuid': bot_uuid},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await self.ap.persistence_mgr.execute_async(
|
||||||
|
sqlalchemy.text('UPDATE bots SET adapter_config = :config WHERE uuid = :uuid'),
|
||||||
|
{'config': json.dumps(adapter_config), 'uuid': bot_uuid},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def downgrade(self):
|
||||||
|
"""Downgrade"""
|
||||||
|
pass
|
||||||
@@ -575,6 +575,127 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
|||||||
|
|
||||||
|
|
||||||
class LarkEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
class LarkEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||||
|
_processed_thread_quote_cache: typing.ClassVar[dict[str, float]] = {}
|
||||||
|
_processed_thread_quote_cache_max_size: typing.ClassVar[int] = 4096
|
||||||
|
_processed_thread_quote_cache_ttl_seconds: typing.ClassVar[int] = 86400
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _prune_processed_thread_quote_cache(cls, now: typing.Optional[float] = None) -> None:
|
||||||
|
if now is None:
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
|
expire_before = now - cls._processed_thread_quote_cache_ttl_seconds
|
||||||
|
while cls._processed_thread_quote_cache:
|
||||||
|
oldest_key, oldest_ts = next(iter(cls._processed_thread_quote_cache.items()))
|
||||||
|
if oldest_ts >= expire_before:
|
||||||
|
break
|
||||||
|
cls._processed_thread_quote_cache.pop(oldest_key, None)
|
||||||
|
|
||||||
|
while len(cls._processed_thread_quote_cache) > cls._processed_thread_quote_cache_max_size:
|
||||||
|
oldest_key = next(iter(cls._processed_thread_quote_cache))
|
||||||
|
cls._processed_thread_quote_cache.pop(oldest_key, None)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _mark_thread_quote_processed(cls, thread_id: str) -> None:
|
||||||
|
now = time.time()
|
||||||
|
cls._prune_processed_thread_quote_cache(now)
|
||||||
|
cls._processed_thread_quote_cache[thread_id] = now
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _extract_quote_message_id(cls, message: EventMessage) -> typing.Optional[str]:
|
||||||
|
"""
|
||||||
|
Extract the message ID to quote from the given message.
|
||||||
|
|
||||||
|
Rules:
|
||||||
|
- First thread reply in a topic: return parent_id and mark topic as processed
|
||||||
|
- Follow-up thread replies in the same topic: return None
|
||||||
|
- Non-thread message: return parent_id if valid (non-empty, different from message_id)
|
||||||
|
|
||||||
|
Thread reply state is kept in a bounded TTL cache to avoid unbounded memory growth.
|
||||||
|
"""
|
||||||
|
parent_id = getattr(message, 'parent_id', None)
|
||||||
|
if not parent_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
message_id = getattr(message, 'message_id', None)
|
||||||
|
if parent_id == message_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
thread_id = getattr(message, 'thread_id', None)
|
||||||
|
if thread_id:
|
||||||
|
cls._prune_processed_thread_quote_cache()
|
||||||
|
if thread_id in cls._processed_thread_quote_cache:
|
||||||
|
return None
|
||||||
|
cls._mark_thread_quote_processed(thread_id)
|
||||||
|
|
||||||
|
return parent_id
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_event_message_from_message_item(message_item: Message) -> typing.Optional[EventMessage]:
|
||||||
|
"""
|
||||||
|
Build EventMessage from SDK typed Message item.
|
||||||
|
|
||||||
|
Returns None if body or content is missing.
|
||||||
|
"""
|
||||||
|
body = getattr(message_item, 'body', None)
|
||||||
|
if not body:
|
||||||
|
return None
|
||||||
|
|
||||||
|
content = getattr(body, 'content', None)
|
||||||
|
if not content:
|
||||||
|
return None
|
||||||
|
|
||||||
|
event_data = {
|
||||||
|
'message_id': message_item.message_id,
|
||||||
|
'message_type': message_item.msg_type,
|
||||||
|
'content': content,
|
||||||
|
'create_time': message_item.create_time,
|
||||||
|
'mentions': getattr(message_item, 'mentions', []) or [],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Preserve thread-related fields
|
||||||
|
if hasattr(message_item, 'parent_id') and message_item.parent_id:
|
||||||
|
event_data['parent_id'] = message_item.parent_id
|
||||||
|
if hasattr(message_item, 'root_id') and message_item.root_id:
|
||||||
|
event_data['root_id'] = message_item.root_id
|
||||||
|
if hasattr(message_item, 'thread_id') and message_item.thread_id:
|
||||||
|
event_data['thread_id'] = message_item.thread_id
|
||||||
|
if hasattr(message_item, 'chat_id') and message_item.chat_id:
|
||||||
|
event_data['chat_id'] = message_item.chat_id
|
||||||
|
|
||||||
|
return EventMessage(event_data)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _fetch_quoted_message(
|
||||||
|
quote_message_id: str,
|
||||||
|
api_client: lark_oapi.Client,
|
||||||
|
) -> typing.Optional[platform_message.MessageChain]:
|
||||||
|
"""
|
||||||
|
Fetch the quoted message and convert to MessageChain.
|
||||||
|
|
||||||
|
Returns None if:
|
||||||
|
- API call fails
|
||||||
|
- Response items is empty
|
||||||
|
- Message item normalization fails
|
||||||
|
"""
|
||||||
|
request = GetMessageRequest.builder().message_id(quote_message_id).build()
|
||||||
|
response = await api_client.im.v1.message.aget(request)
|
||||||
|
|
||||||
|
if not response.success():
|
||||||
|
return None
|
||||||
|
|
||||||
|
items = getattr(response.data, 'items', None)
|
||||||
|
if not items:
|
||||||
|
return None
|
||||||
|
|
||||||
|
message_item = items[0]
|
||||||
|
event_message = LarkEventConverter._build_event_message_from_message_item(message_item)
|
||||||
|
if event_message is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
quote_chain = await LarkMessageConverter.target2yiri(event_message, api_client)
|
||||||
|
return quote_chain
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def yiri2target(
|
async def yiri2target(
|
||||||
event: platform_events.MessageEvent,
|
event: platform_events.MessageEvent,
|
||||||
@@ -587,6 +708,23 @@ class LarkEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
|||||||
) -> platform_events.Event:
|
) -> platform_events.Event:
|
||||||
message_chain = await LarkMessageConverter.target2yiri(event.event.message, api_client)
|
message_chain = await LarkMessageConverter.target2yiri(event.event.message, api_client)
|
||||||
|
|
||||||
|
# Check for quote/reply message
|
||||||
|
quote_message_id = LarkEventConverter._extract_quote_message_id(event.event.message)
|
||||||
|
if quote_message_id:
|
||||||
|
quote_chain = await LarkEventConverter._fetch_quoted_message(quote_message_id, api_client)
|
||||||
|
if quote_chain:
|
||||||
|
# Filter out Source component from quoted chain, keep only content
|
||||||
|
quote_origin = platform_message.MessageChain(
|
||||||
|
[comp for comp in quote_chain if not isinstance(comp, platform_message.Source)]
|
||||||
|
)
|
||||||
|
if quote_origin:
|
||||||
|
message_chain.append(
|
||||||
|
platform_message.Quote(
|
||||||
|
message_id=quote_message_id,
|
||||||
|
origin=quote_origin,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if event.event.message.chat_type == 'p2p':
|
if event.event.message.chat_type == 'p2p':
|
||||||
return platform_events.FriendMessage(
|
return platform_events.FriendMessage(
|
||||||
sender=platform_entities.Friend(
|
sender=platform_entities.Friend(
|
||||||
@@ -770,6 +908,32 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
|||||||
self.request_tenant_access_token(tenant_key)
|
self.request_tenant_access_token(tenant_key)
|
||||||
return self.tenant_access_tokens.get(tenant_key)['token'] if self.tenant_access_tokens.get(tenant_key) else None
|
return self.tenant_access_tokens.get(tenant_key)['token'] if self.tenant_access_tokens.get(tenant_key) else None
|
||||||
|
|
||||||
|
def get_launcher_id(self, event: platform_events.MessageEvent) -> str | None:
|
||||||
|
"""
|
||||||
|
Get topic-scoped launcher_id for thread-aware session isolation.
|
||||||
|
|
||||||
|
For group thread messages, returns "{group_id}_{thread_id}"
|
||||||
|
to ensure conversation context stays stable per topic.
|
||||||
|
|
||||||
|
Returns None for non-thread messages or P2P messages.
|
||||||
|
"""
|
||||||
|
source_event = getattr(event.source_platform_object, 'event', None)
|
||||||
|
if not source_event:
|
||||||
|
return None
|
||||||
|
|
||||||
|
message = getattr(source_event, 'message', None)
|
||||||
|
if not message:
|
||||||
|
return None
|
||||||
|
|
||||||
|
thread_id = getattr(message, 'thread_id', None)
|
||||||
|
if not thread_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if isinstance(event, platform_events.GroupMessage):
|
||||||
|
return f'{event.group.id}_{thread_id}'
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def build_api_client(self, config):
|
def build_api_client(self, config):
|
||||||
app_id = config['app_id']
|
app_id = config['app_id']
|
||||||
app_secret = config['app_secret']
|
app_secret = config['app_secret']
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import langbot_plugin.api.entities.builtin.platform.entities as platform_entitie
|
|||||||
from ..logger import EventLogger
|
from ..logger import EventLogger
|
||||||
from langbot.libs.wecom_ai_bot_api.wecombotevent import WecomBotEvent
|
from langbot.libs.wecom_ai_bot_api.wecombotevent import WecomBotEvent
|
||||||
from langbot.libs.wecom_ai_bot_api.api import WecomBotClient
|
from langbot.libs.wecom_ai_bot_api.api import WecomBotClient
|
||||||
|
from langbot.libs.wecom_ai_bot_api.ws_client import WecomBotWsClient
|
||||||
|
|
||||||
|
|
||||||
class WecomBotMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
class WecomBotMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||||
@@ -176,27 +177,42 @@ class WecomBotEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
|||||||
|
|
||||||
|
|
||||||
class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||||
bot: WecomBotClient
|
bot: typing.Union[WecomBotClient, WecomBotWsClient]
|
||||||
bot_account_id: str
|
bot_account_id: str
|
||||||
message_converter: WecomBotMessageConverter = WecomBotMessageConverter()
|
message_converter: WecomBotMessageConverter = WecomBotMessageConverter()
|
||||||
event_converter: WecomBotEventConverter = WecomBotEventConverter()
|
event_converter: WecomBotEventConverter = WecomBotEventConverter()
|
||||||
config: dict
|
config: dict
|
||||||
bot_uuid: str = None
|
bot_uuid: str = None
|
||||||
|
_ws_mode: bool = False
|
||||||
|
|
||||||
def __init__(self, config: dict, logger: EventLogger):
|
def __init__(self, config: dict, logger: EventLogger):
|
||||||
required_keys = ['Token', 'EncodingAESKey', 'Corpid', 'BotId']
|
enable_webhook = config.get('enable-webhook', False)
|
||||||
missing_keys = [key for key in required_keys if key not in config]
|
|
||||||
if missing_keys:
|
|
||||||
raise Exception(f'WecomBot 缺少配置项: {missing_keys}')
|
|
||||||
|
|
||||||
bot = WecomBotClient(
|
if not enable_webhook:
|
||||||
Token=config['Token'],
|
bot = WecomBotWsClient(
|
||||||
EnCodingAESKey=config['EncodingAESKey'],
|
bot_id=config['BotId'],
|
||||||
Corpid=config['Corpid'],
|
secret=config['Secret'],
|
||||||
logger=logger,
|
logger=logger,
|
||||||
unified_mode=True,
|
encoding_aes_key=config.get('EncodingAESKey', ''),
|
||||||
)
|
)
|
||||||
bot_account_id = config['BotId']
|
ws_mode = True
|
||||||
|
else:
|
||||||
|
# Webhook callback mode
|
||||||
|
required_keys = ['Token', 'EncodingAESKey', 'Corpid']
|
||||||
|
missing_keys = [key for key in required_keys if key not in config or not config[key]]
|
||||||
|
if missing_keys:
|
||||||
|
raise Exception(f'WecomBot webhook mode missing config: {missing_keys}')
|
||||||
|
|
||||||
|
bot = WecomBotClient(
|
||||||
|
Token=config['Token'],
|
||||||
|
EnCodingAESKey=config['EncodingAESKey'],
|
||||||
|
Corpid=config['Corpid'],
|
||||||
|
logger=logger,
|
||||||
|
unified_mode=True,
|
||||||
|
)
|
||||||
|
ws_mode = False
|
||||||
|
|
||||||
|
bot_account_id = config.get('BotId', '')
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
config=config,
|
config=config,
|
||||||
@@ -204,6 +220,7 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
|||||||
bot=bot,
|
bot=bot,
|
||||||
bot_account_id=bot_account_id,
|
bot_account_id=bot_account_id,
|
||||||
)
|
)
|
||||||
|
self._ws_mode = ws_mode
|
||||||
|
|
||||||
async def reply_message(
|
async def reply_message(
|
||||||
self,
|
self,
|
||||||
@@ -212,7 +229,15 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
|||||||
quote_origin: bool = False,
|
quote_origin: bool = False,
|
||||||
):
|
):
|
||||||
content = await self.message_converter.yiri2target(message)
|
content = await self.message_converter.yiri2target(message)
|
||||||
await self.bot.set_message(message_source.source_platform_object.message_id, content)
|
if self._ws_mode:
|
||||||
|
event = message_source.source_platform_object
|
||||||
|
req_id = event.get('req_id', '')
|
||||||
|
if req_id:
|
||||||
|
await self.bot.reply_text(req_id, content)
|
||||||
|
else:
|
||||||
|
await self.bot.set_message(event.message_id, content)
|
||||||
|
else:
|
||||||
|
await self.bot.set_message(message_source.source_platform_object.message_id, content)
|
||||||
|
|
||||||
async def reply_message_chunk(
|
async def reply_message_chunk(
|
||||||
self,
|
self,
|
||||||
@@ -222,31 +247,22 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
|||||||
quote_origin: bool = False,
|
quote_origin: bool = False,
|
||||||
is_final: bool = False,
|
is_final: bool = False,
|
||||||
):
|
):
|
||||||
"""将流水线增量输出写入企业微信 stream 会话。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message_source: 流水线提供的原始消息事件。
|
|
||||||
bot_message: 当前片段对应的模型元信息(未使用)。
|
|
||||||
message: 需要回复的消息链。
|
|
||||||
quote_origin: 是否引用原消息(企业微信暂不支持)。
|
|
||||||
is_final: 标记当前片段是否为最终回复。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: 包含 `stream` 键,标识写入是否成功。
|
|
||||||
|
|
||||||
Example:
|
|
||||||
在流水线 `reply_message_chunk` 调用中自动触发,无需手动调用。
|
|
||||||
"""
|
|
||||||
# 转换为纯文本(智能机器人当前协议仅支持文本流)
|
|
||||||
content = await self.message_converter.yiri2target(message)
|
content = await self.message_converter.yiri2target(message)
|
||||||
msg_id = message_source.source_platform_object.message_id
|
msg_id = message_source.source_platform_object.message_id
|
||||||
|
|
||||||
# 将片段推送到 WecomBotClient 中的队列,返回值用于判断是否走降级逻辑
|
if self._ws_mode:
|
||||||
success = await self.bot.push_stream_chunk(msg_id, content, is_final=is_final)
|
success = await self.bot.push_stream_chunk(msg_id, content, is_final=is_final)
|
||||||
if not success and is_final:
|
if not success and is_final:
|
||||||
# 未命中流式队列时使用旧有 set_message 兜底
|
event = message_source.source_platform_object
|
||||||
await self.bot.set_message(msg_id, content)
|
req_id = event.get('req_id', '')
|
||||||
return {'stream': success}
|
if req_id:
|
||||||
|
await self.bot.reply_text(req_id, content)
|
||||||
|
return {'stream': success}
|
||||||
|
else:
|
||||||
|
success = await self.bot.push_stream_chunk(msg_id, content, is_final=is_final)
|
||||||
|
if not success and is_final:
|
||||||
|
await self.bot.set_message(msg_id, content)
|
||||||
|
return {'stream': success}
|
||||||
|
|
||||||
async def is_stream_output_supported(self) -> bool:
|
async def is_stream_output_supported(self) -> bool:
|
||||||
"""智能机器人侧默认开启流式能力。
|
"""智能机器人侧默认开启流式能力。
|
||||||
@@ -259,7 +275,11 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
async def send_message(self, target_type, target_id, message):
|
async def send_message(self, target_type, target_id, message):
|
||||||
pass
|
if self._ws_mode:
|
||||||
|
content = await self.message_converter.yiri2target(message)
|
||||||
|
await self.bot.send_message(target_id, content)
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
def register_listener(
|
def register_listener(
|
||||||
self,
|
self,
|
||||||
@@ -288,29 +308,25 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
|||||||
self.bot_uuid = bot_uuid
|
self.bot_uuid = bot_uuid
|
||||||
|
|
||||||
async def handle_unified_webhook(self, bot_uuid: str, path: str, request):
|
async def handle_unified_webhook(self, bot_uuid: str, path: str, request):
|
||||||
"""处理统一 webhook 请求。
|
if self._ws_mode:
|
||||||
|
return None
|
||||||
Args:
|
|
||||||
bot_uuid: Bot 的 UUID
|
|
||||||
path: 子路径(如果有的话)
|
|
||||||
request: Quart Request 对象
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
响应数据
|
|
||||||
"""
|
|
||||||
return await self.bot.handle_unified_webhook(request)
|
return await self.bot.handle_unified_webhook(request)
|
||||||
|
|
||||||
async def run_async(self):
|
async def run_async(self):
|
||||||
# 统一 webhook 模式下,不启动独立的 Quart 应用
|
if self._ws_mode:
|
||||||
# 保持运行但不启动独立端口
|
await self.bot.connect()
|
||||||
|
else:
|
||||||
|
|
||||||
async def keep_alive():
|
async def keep_alive():
|
||||||
while True:
|
while True:
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
await keep_alive()
|
await keep_alive()
|
||||||
|
|
||||||
async def kill(self) -> bool:
|
async def kill(self) -> bool:
|
||||||
|
if self._ws_mode:
|
||||||
|
await self.bot.disconnect()
|
||||||
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def unregister_listener(
|
async def unregister_listener(
|
||||||
|
|||||||
@@ -11,35 +11,64 @@ metadata:
|
|||||||
icon: wecombot.png
|
icon: wecombot.png
|
||||||
spec:
|
spec:
|
||||||
config:
|
config:
|
||||||
|
- name: BotId
|
||||||
|
label:
|
||||||
|
en_US: BotId
|
||||||
|
zh_Hans: 机器人ID (BotId)
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
default: ""
|
||||||
|
- name: enable-webhook
|
||||||
|
label:
|
||||||
|
en_US: Enable Webhook Mode
|
||||||
|
zh_Hans: 启用Webhook模式
|
||||||
|
description:
|
||||||
|
en_US: If enabled, the bot will use webhook mode to receive messages. Otherwise, it will use WS long connection mode
|
||||||
|
zh_Hans: 如果启用,机器人将使用 Webhook 模式接收消息。否则,将使用 WS 长连接模式
|
||||||
|
type: boolean
|
||||||
|
required: true
|
||||||
|
default: false
|
||||||
|
- name: Secret
|
||||||
|
label:
|
||||||
|
en_US: Secret
|
||||||
|
zh_Hans: 机器人密钥 (Secret)
|
||||||
|
description:
|
||||||
|
en_US: Required for WebSocket long connection mode
|
||||||
|
zh_Hans: 使用 WS 长连接模式时必填
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
default: ""
|
||||||
- name: Corpid
|
- name: Corpid
|
||||||
label:
|
label:
|
||||||
en_US: Corpid
|
en_US: Corpid
|
||||||
zh_Hans: 企业ID
|
zh_Hans: 企业ID
|
||||||
|
description:
|
||||||
|
en_US: Required for Webhook mode
|
||||||
|
zh_Hans: 使用 Webhook 模式时必填
|
||||||
type: string
|
type: string
|
||||||
required: true
|
required: false
|
||||||
default: ""
|
default: ""
|
||||||
- name: Token
|
- name: Token
|
||||||
label:
|
label:
|
||||||
en_US: Token
|
en_US: Token
|
||||||
zh_Hans: 令牌 (Token)
|
zh_Hans: 令牌 (Token)
|
||||||
|
description:
|
||||||
|
en_US: Required for Webhook mode
|
||||||
|
zh_Hans: 使用 Webhook 模式时必填
|
||||||
type: string
|
type: string
|
||||||
required: true
|
required: false
|
||||||
default: ""
|
default: ""
|
||||||
- name: EncodingAESKey
|
- name: EncodingAESKey
|
||||||
label:
|
label:
|
||||||
en_US: EncodingAESKey
|
en_US: EncodingAESKey
|
||||||
zh_Hans: 消息加解密密钥 (EncodingAESKey)
|
zh_Hans: 消息加解密密钥 (EncodingAESKey)
|
||||||
type: string
|
description:
|
||||||
required: true
|
en_US: Required for Webhook mode. Optional for WebSocket mode (used for file decryption)
|
||||||
default: ""
|
zh_Hans: 使用 Webhook 模式时必填。WebSocket 模式下可选(用于文件解密)
|
||||||
- name: BotId
|
|
||||||
label:
|
|
||||||
en_US: BotId
|
|
||||||
zh_Hans: 机器人ID
|
|
||||||
type: string
|
type: string
|
||||||
required: false
|
required: false
|
||||||
default: ""
|
default: ""
|
||||||
execution:
|
execution:
|
||||||
python:
|
python:
|
||||||
path: ./wecombot.py
|
path: ./wecombot.py
|
||||||
attr: WecomBotAdapter
|
attr: WecomBotAdapter
|
||||||
|
|||||||
@@ -565,6 +565,16 @@ class RuntimeConnectionHandler(handler.Handler):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return _make_rag_error_response(e, 'FileServiceError', storage_path=storage_path)
|
return _make_rag_error_response(e, 'FileServiceError', storage_path=storage_path)
|
||||||
|
|
||||||
|
@self.action(PluginToRuntimeAction.LIST_PARSERS)
|
||||||
|
async def list_parsers(data: dict[str, Any]) -> handler.ActionResponse:
|
||||||
|
"""Plugin requests host to list available parser plugins."""
|
||||||
|
mime_type = data.get('mime_type')
|
||||||
|
try:
|
||||||
|
parsers = await self.ap.knowledge_service.list_parsers(mime_type)
|
||||||
|
return handler.ActionResponse.success(data={'parsers': parsers})
|
||||||
|
except Exception as e:
|
||||||
|
return _make_rag_error_response(e, 'ParserDiscoveryError', mime_type=mime_type)
|
||||||
|
|
||||||
@self.action(PluginToRuntimeAction.INVOKE_PARSER)
|
@self.action(PluginToRuntimeAction.INVOKE_PARSER)
|
||||||
async def invoke_parser(data: dict[str, Any]) -> handler.ActionResponse:
|
async def invoke_parser(data: dict[str, Any]) -> handler.ActionResponse:
|
||||||
"""Plugin requests host to invoke a parser plugin."""
|
"""Plugin requests host to invoke a parser plugin."""
|
||||||
@@ -589,6 +599,94 @@ class RuntimeConnectionHandler(handler.Handler):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return _make_rag_error_response(e, 'ParserError')
|
return _make_rag_error_response(e, 'ParserError')
|
||||||
|
|
||||||
|
# ================= Knowledge Base Query APIs =================
|
||||||
|
|
||||||
|
@self.action(PluginToRuntimeAction.LIST_PIPELINE_KNOWLEDGE_BASES)
|
||||||
|
async def list_pipeline_knowledge_bases(data: dict[str, Any]) -> handler.ActionResponse:
|
||||||
|
"""List knowledge bases configured for the current query's pipeline."""
|
||||||
|
query_id = data['query_id']
|
||||||
|
|
||||||
|
if query_id not in self.ap.query_pool.cached_queries:
|
||||||
|
return handler.ActionResponse.error(
|
||||||
|
message=f'Query with query_id {query_id} not found',
|
||||||
|
)
|
||||||
|
|
||||||
|
query = self.ap.query_pool.cached_queries[query_id]
|
||||||
|
|
||||||
|
kb_uuids = []
|
||||||
|
if query.pipeline_config:
|
||||||
|
local_agent_config = query.pipeline_config.get('ai', {}).get('local-agent', {})
|
||||||
|
kb_uuids = local_agent_config.get('knowledge-bases', [])
|
||||||
|
# Backward compatibility
|
||||||
|
if not kb_uuids:
|
||||||
|
old_kb_uuid = local_agent_config.get('knowledge-base', '')
|
||||||
|
if old_kb_uuid and old_kb_uuid != '__none__':
|
||||||
|
kb_uuids = [old_kb_uuid]
|
||||||
|
|
||||||
|
knowledge_bases = []
|
||||||
|
for kb_uuid in kb_uuids:
|
||||||
|
kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
|
||||||
|
if kb:
|
||||||
|
knowledge_bases.append(
|
||||||
|
{
|
||||||
|
'uuid': kb.get_uuid(),
|
||||||
|
'name': kb.get_name(),
|
||||||
|
'description': kb.knowledge_base_entity.description or '',
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return handler.ActionResponse.success(data={'knowledge_bases': knowledge_bases})
|
||||||
|
|
||||||
|
@self.action(PluginToRuntimeAction.RETRIEVE_KNOWLEDGE_BASE)
|
||||||
|
async def retrieve_knowledge_base(data: dict[str, Any]) -> handler.ActionResponse:
|
||||||
|
"""Retrieve documents from a knowledge base within the pipeline's scope."""
|
||||||
|
query_id = data['query_id']
|
||||||
|
kb_id = data['kb_id']
|
||||||
|
query_text = data['query_text']
|
||||||
|
top_k = data.get('top_k', 5)
|
||||||
|
filters = data.get('filters', {})
|
||||||
|
|
||||||
|
if query_id not in self.ap.query_pool.cached_queries:
|
||||||
|
return handler.ActionResponse.error(
|
||||||
|
message=f'Query with query_id {query_id} not found',
|
||||||
|
)
|
||||||
|
|
||||||
|
query = self.ap.query_pool.cached_queries[query_id]
|
||||||
|
|
||||||
|
# Validate kb_id is in pipeline's allowed list
|
||||||
|
allowed_kb_uuids = []
|
||||||
|
if query.pipeline_config:
|
||||||
|
local_agent_config = query.pipeline_config.get('ai', {}).get('local-agent', {})
|
||||||
|
allowed_kb_uuids = local_agent_config.get('knowledge-bases', [])
|
||||||
|
if not allowed_kb_uuids:
|
||||||
|
old_kb_uuid = local_agent_config.get('knowledge-base', '')
|
||||||
|
if old_kb_uuid and old_kb_uuid != '__none__':
|
||||||
|
allowed_kb_uuids = [old_kb_uuid]
|
||||||
|
|
||||||
|
if kb_id not in allowed_kb_uuids:
|
||||||
|
return handler.ActionResponse.error(
|
||||||
|
message=f'Knowledge base {kb_id} is not configured for this pipeline',
|
||||||
|
)
|
||||||
|
|
||||||
|
kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_id)
|
||||||
|
if not kb:
|
||||||
|
return handler.ActionResponse.error(
|
||||||
|
message=f'Knowledge base {kb_id} not found',
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
entries = await kb.retrieve(
|
||||||
|
query_text,
|
||||||
|
settings={
|
||||||
|
'top_k': top_k,
|
||||||
|
'filters': filters,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
results = [entry.model_dump(mode='json') for entry in entries]
|
||||||
|
return handler.ActionResponse.success(data={'results': results})
|
||||||
|
except Exception as e:
|
||||||
|
return _make_rag_error_response(e, 'RetrievalError', kb_id=kb_id)
|
||||||
|
|
||||||
@self.action(CommonAction.PING)
|
@self.action(CommonAction.PING)
|
||||||
async def ping(data: dict[str, Any]) -> handler.ActionResponse:
|
async def ping(data: dict[str, Any]) -> handler.ActionResponse:
|
||||||
"""Ping"""
|
"""Ping"""
|
||||||
@@ -895,7 +993,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
|||||||
result = await self.call_action(
|
result = await self.call_action(
|
||||||
LangBotToRuntimeAction.RAG_INGEST_DOCUMENT,
|
LangBotToRuntimeAction.RAG_INGEST_DOCUMENT,
|
||||||
{'plugin_author': plugin_author, 'plugin_name': plugin_name, 'context': context_data},
|
{'plugin_author': plugin_author, 'plugin_name': plugin_name, 'context': context_data},
|
||||||
timeout=300, # Ingestion can be slow
|
timeout=1200, # Ingestion can be slow for large documents
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
@@ -168,6 +168,7 @@ class LocalAgentRunner(runner.RequestRunner):
|
|||||||
result = await kb.retrieve(
|
result = await kb.retrieve(
|
||||||
user_message_text,
|
user_message_text,
|
||||||
settings={
|
settings={
|
||||||
|
'bot_uuid': query.bot_uuid or '',
|
||||||
'sender_id': str(query.sender_id),
|
'sender_id': str(query.sender_id),
|
||||||
'session_name': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
'session_name': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import langbot
|
|||||||
|
|
||||||
semantic_version = f'v{langbot.__version__}'
|
semantic_version = f'v{langbot.__version__}'
|
||||||
|
|
||||||
required_database_version = 23
|
required_database_version = 24
|
||||||
"""Tag the version of the database schema, used to check if the database needs to be migrated"""
|
"""Tag the version of the database schema, used to check if the database needs to be migrated"""
|
||||||
|
|
||||||
debug_mode = False
|
debug_mode = False
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ class VectorDBManager:
|
|||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""Proxy: Search vectors.
|
"""Proxy: Search vectors.
|
||||||
|
|
||||||
Returns a list of dicts with keys: 'id', 'score', 'metadata'.
|
Returns a list of dicts with keys: 'id', 'distance', 'metadata'.
|
||||||
The underlying VectorDatabase.search returns Chroma-style format:
|
The underlying VectorDatabase.search returns Chroma-style format:
|
||||||
{ 'ids': [['id1']], 'distances': [[0.1]], 'metadatas': [[{}]] }
|
{ 'ids': [['id1']], 'distances': [[0.1]], 'metadatas': [[{}]] }
|
||||||
"""
|
"""
|
||||||
@@ -130,7 +130,7 @@ class VectorDBManager:
|
|||||||
parsed_results.append(
|
parsed_results.append(
|
||||||
{
|
{
|
||||||
'id': id_val,
|
'id': id_val,
|
||||||
'score': r_dists[i] if r_dists and i < len(r_dists) else 0.0,
|
'distance': r_dists[i] if r_dists and i < len(r_dists) else 0.0,
|
||||||
'metadata': r_metas[i] if r_metas and i < len(r_metas) else {},
|
'metadata': r_metas[i] if r_metas and i < len(r_metas) else {},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ admins: []
|
|||||||
api:
|
api:
|
||||||
port: 5300
|
port: 5300
|
||||||
webhook_prefix: 'http://127.0.0.1:5300'
|
webhook_prefix: 'http://127.0.0.1:5300'
|
||||||
|
extra_webhook_prefix: ''
|
||||||
command:
|
command:
|
||||||
enable: true
|
enable: true
|
||||||
prefix:
|
prefix:
|
||||||
|
|||||||
@@ -41,7 +41,10 @@
|
|||||||
"runner": "local-agent"
|
"runner": "local-agent"
|
||||||
},
|
},
|
||||||
"local-agent": {
|
"local-agent": {
|
||||||
"model": "",
|
"model": {
|
||||||
|
"primary": "",
|
||||||
|
"fallbacks": []
|
||||||
|
},
|
||||||
"max-round": 10,
|
"max-round": 10,
|
||||||
"prompt": [
|
"prompt": [
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -91,14 +91,15 @@ class TestWebhookDisplayPrefix:
|
|||||||
|
|
||||||
def test_default_webhook_prefix(self):
|
def test_default_webhook_prefix(self):
|
||||||
"""Test that the default webhook display prefix is correctly set"""
|
"""Test that the default webhook display prefix is correctly set"""
|
||||||
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300'}}
|
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300', 'extra_webhook_prefix': ''}}
|
||||||
|
|
||||||
# Should have the default value
|
# Should have the default value
|
||||||
assert cfg['api']['webhook_prefix'] == 'http://127.0.0.1:5300'
|
assert cfg['api']['webhook_prefix'] == 'http://127.0.0.1:5300'
|
||||||
|
assert cfg['api']['extra_webhook_prefix'] == ''
|
||||||
|
|
||||||
def test_webhook_prefix_env_override(self):
|
def test_webhook_prefix_env_override(self):
|
||||||
"""Test overriding webhook_prefix via environment variable"""
|
"""Test overriding webhook_prefix via environment variable"""
|
||||||
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300'}}
|
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300', 'extra_webhook_prefix': ''}}
|
||||||
|
|
||||||
# Set environment variable
|
# Set environment variable
|
||||||
os.environ['API__WEBHOOK_PREFIX'] = 'https://example.com:8080'
|
os.environ['API__WEBHOOK_PREFIX'] = 'https://example.com:8080'
|
||||||
@@ -112,7 +113,7 @@ class TestWebhookDisplayPrefix:
|
|||||||
|
|
||||||
def test_webhook_prefix_with_custom_domain(self):
|
def test_webhook_prefix_with_custom_domain(self):
|
||||||
"""Test webhook_prefix with custom domain"""
|
"""Test webhook_prefix with custom domain"""
|
||||||
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300'}}
|
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300', 'extra_webhook_prefix': ''}}
|
||||||
|
|
||||||
# Set to a custom domain
|
# Set to a custom domain
|
||||||
os.environ['API__WEBHOOK_PREFIX'] = 'https://bot.mycompany.com'
|
os.environ['API__WEBHOOK_PREFIX'] = 'https://bot.mycompany.com'
|
||||||
@@ -126,7 +127,7 @@ class TestWebhookDisplayPrefix:
|
|||||||
|
|
||||||
def test_webhook_prefix_with_subdirectory(self):
|
def test_webhook_prefix_with_subdirectory(self):
|
||||||
"""Test webhook_prefix with subdirectory path"""
|
"""Test webhook_prefix with subdirectory path"""
|
||||||
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300'}}
|
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300', 'extra_webhook_prefix': ''}}
|
||||||
|
|
||||||
# Set to a URL with subdirectory
|
# Set to a URL with subdirectory
|
||||||
os.environ['API__WEBHOOK_PREFIX'] = 'https://example.com/langbot'
|
os.environ['API__WEBHOOK_PREFIX'] = 'https://example.com/langbot'
|
||||||
@@ -138,6 +139,37 @@ class TestWebhookDisplayPrefix:
|
|||||||
# Cleanup
|
# Cleanup
|
||||||
del os.environ['API__WEBHOOK_PREFIX']
|
del os.environ['API__WEBHOOK_PREFIX']
|
||||||
|
|
||||||
|
def test_extra_webhook_prefix_default_empty(self):
|
||||||
|
"""Test that extra_webhook_prefix defaults to empty string"""
|
||||||
|
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300', 'extra_webhook_prefix': ''}}
|
||||||
|
|
||||||
|
bot_uuid = 'test-bot-uuid'
|
||||||
|
webhook_prefix = cfg['api'].get('webhook_prefix', 'http://127.0.0.1:5300')
|
||||||
|
extra_webhook_prefix = cfg['api'].get('extra_webhook_prefix', '')
|
||||||
|
webhook_url = f'/bots/{bot_uuid}'
|
||||||
|
|
||||||
|
assert f'{webhook_prefix}{webhook_url}' == 'http://127.0.0.1:5300/bots/test-bot-uuid'
|
||||||
|
# extra should be empty when not configured
|
||||||
|
assert extra_webhook_prefix == ''
|
||||||
|
|
||||||
|
def test_extra_webhook_prefix_env_override(self):
|
||||||
|
"""Test overriding extra_webhook_prefix via environment variable"""
|
||||||
|
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300', 'extra_webhook_prefix': ''}}
|
||||||
|
|
||||||
|
os.environ['API__EXTRA_WEBHOOK_PREFIX'] = 'https://extra.example.com'
|
||||||
|
|
||||||
|
result = _apply_env_overrides_to_config(cfg)
|
||||||
|
|
||||||
|
assert result['api']['extra_webhook_prefix'] == 'https://extra.example.com'
|
||||||
|
|
||||||
|
bot_uuid = 'test-bot-uuid'
|
||||||
|
extra_prefix = result['api']['extra_webhook_prefix']
|
||||||
|
webhook_url = f'/bots/{bot_uuid}'
|
||||||
|
assert f'{extra_prefix}{webhook_url}' == 'https://extra.example.com/bots/test-bot-uuid'
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
del os.environ['API__EXTRA_WEBHOOK_PREFIX']
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pytest.main([__file__, '-v'])
|
pytest.main([__file__, '-v'])
|
||||||
|
|||||||
@@ -194,7 +194,7 @@ def sample_query(sample_message_chain, sample_message_event, mock_adapter):
|
|||||||
pipeline_config={
|
pipeline_config={
|
||||||
'ai': {
|
'ai': {
|
||||||
'runner': {'runner': 'local-agent'},
|
'runner': {'runner': 'local-agent'},
|
||||||
'local-agent': {'model': 'test-model-uuid', 'prompt': 'test-prompt'},
|
'local-agent': {'model': {'primary': 'test-model-uuid', 'fallbacks': []}, 'prompt': 'test-prompt'},
|
||||||
},
|
},
|
||||||
'output': {'misc': {'at-sender': False, 'quote-origin': False}},
|
'output': {'misc': {'at-sender': False, 'quote-origin': False}},
|
||||||
'trigger': {'misc': {'combine-quote-message': False}},
|
'trigger': {'misc': {'combine-quote-message': False}},
|
||||||
@@ -219,7 +219,7 @@ def sample_pipeline_config():
|
|||||||
return {
|
return {
|
||||||
'ai': {
|
'ai': {
|
||||||
'runner': {'runner': 'local-agent'},
|
'runner': {'runner': 'local-agent'},
|
||||||
'local-agent': {'model': 'test-model-uuid', 'prompt': 'test-prompt'},
|
'local-agent': {'model': {'primary': 'test-model-uuid', 'fallbacks': []}, 'prompt': 'test-prompt'},
|
||||||
},
|
},
|
||||||
'output': {'misc': {'at-sender': False, 'quote-origin': False}},
|
'output': {'misc': {'at-sender': False, 'quote-origin': False}},
|
||||||
'trigger': {'misc': {'combine-quote-message': False}},
|
'trigger': {'misc': {'combine-quote-message': False}},
|
||||||
|
|||||||
@@ -102,5 +102,10 @@
|
|||||||
"typescript": "^5.8.3",
|
"typescript": "^5.8.3",
|
||||||
"typescript-eslint": "^8.31.1"
|
"typescript-eslint": "^8.31.1"
|
||||||
},
|
},
|
||||||
"packageManager": "pnpm@8.9.2+sha512.b9d35fe91b2a5854dadc43034a3e7b2e675fa4b56e20e8e09ef078fa553c18f8aed44051e7b36e8b8dd435f97eb0c44c4ff3b44fc7c6fa7d21e1fac17bbe661e"
|
"packageManager": "pnpm@8.9.2+sha512.b9d35fe91b2a5854dadc43034a3e7b2e675fa4b56e20e8e09ef078fa553c18f8aed44051e7b36e8b8dd435f97eb0c44c4ff3b44fc7c6fa7d21e1fac17bbe661e",
|
||||||
}
|
"pnpm": {
|
||||||
|
"overrides": {
|
||||||
|
"minimatch": "3.1.3"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
34
web/pnpm-lock.yaml
generated
34
web/pnpm-lock.yaml
generated
@@ -4,6 +4,9 @@ settings:
|
|||||||
autoInstallPeers: true
|
autoInstallPeers: true
|
||||||
excludeLinksFromLockfile: false
|
excludeLinksFromLockfile: false
|
||||||
|
|
||||||
|
overrides:
|
||||||
|
minimatch: 3.1.3
|
||||||
|
|
||||||
dependencies:
|
dependencies:
|
||||||
'@dnd-kit/core':
|
'@dnd-kit/core':
|
||||||
specifier: ^6.3.1
|
specifier: ^6.3.1
|
||||||
@@ -345,7 +348,7 @@ packages:
|
|||||||
dependencies:
|
dependencies:
|
||||||
'@eslint/object-schema': 2.1.7
|
'@eslint/object-schema': 2.1.7
|
||||||
debug: 4.4.3
|
debug: 4.4.3
|
||||||
minimatch: 3.1.2
|
minimatch: 3.1.3
|
||||||
transitivePeerDependencies:
|
transitivePeerDependencies:
|
||||||
- supports-color
|
- supports-color
|
||||||
dev: true
|
dev: true
|
||||||
@@ -375,7 +378,7 @@ packages:
|
|||||||
ignore: 5.3.2
|
ignore: 5.3.2
|
||||||
import-fresh: 3.3.1
|
import-fresh: 3.3.1
|
||||||
js-yaml: 4.1.1
|
js-yaml: 4.1.1
|
||||||
minimatch: 3.1.2
|
minimatch: 3.1.3
|
||||||
strip-json-comments: 3.1.1
|
strip-json-comments: 3.1.1
|
||||||
transitivePeerDependencies:
|
transitivePeerDependencies:
|
||||||
- supports-color
|
- supports-color
|
||||||
@@ -2260,7 +2263,7 @@ packages:
|
|||||||
'@typescript-eslint/types': 8.54.0
|
'@typescript-eslint/types': 8.54.0
|
||||||
'@typescript-eslint/visitor-keys': 8.54.0
|
'@typescript-eslint/visitor-keys': 8.54.0
|
||||||
debug: 4.4.3
|
debug: 4.4.3
|
||||||
minimatch: 9.0.5
|
minimatch: 3.1.3
|
||||||
semver: 7.7.3
|
semver: 7.7.3
|
||||||
tinyglobby: 0.2.15
|
tinyglobby: 0.2.15
|
||||||
ts-api-utils: 2.4.0(typescript@5.9.3)
|
ts-api-utils: 2.4.0(typescript@5.9.3)
|
||||||
@@ -2678,12 +2681,6 @@ packages:
|
|||||||
concat-map: 0.0.1
|
concat-map: 0.0.1
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
/brace-expansion@2.0.2:
|
|
||||||
resolution: {integrity: sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==}
|
|
||||||
dependencies:
|
|
||||||
balanced-match: 1.0.2
|
|
||||||
dev: true
|
|
||||||
|
|
||||||
/braces@3.0.3:
|
/braces@3.0.3:
|
||||||
resolution: {integrity: sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==}
|
resolution: {integrity: sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==}
|
||||||
engines: {node: '>=8'}
|
engines: {node: '>=8'}
|
||||||
@@ -3345,7 +3342,7 @@ packages:
|
|||||||
hasown: 2.0.2
|
hasown: 2.0.2
|
||||||
is-core-module: 2.16.1
|
is-core-module: 2.16.1
|
||||||
is-glob: 4.0.3
|
is-glob: 4.0.3
|
||||||
minimatch: 3.1.2
|
minimatch: 3.1.3
|
||||||
object.fromentries: 2.0.8
|
object.fromentries: 2.0.8
|
||||||
object.groupby: 1.0.3
|
object.groupby: 1.0.3
|
||||||
object.values: 1.2.1
|
object.values: 1.2.1
|
||||||
@@ -3376,7 +3373,7 @@ packages:
|
|||||||
hasown: 2.0.2
|
hasown: 2.0.2
|
||||||
jsx-ast-utils: 3.3.5
|
jsx-ast-utils: 3.3.5
|
||||||
language-tags: 1.0.9
|
language-tags: 1.0.9
|
||||||
minimatch: 3.1.2
|
minimatch: 3.1.3
|
||||||
object.fromentries: 2.0.8
|
object.fromentries: 2.0.8
|
||||||
safe-regex-test: 1.1.0
|
safe-regex-test: 1.1.0
|
||||||
string.prototype.includes: 2.0.1
|
string.prototype.includes: 2.0.1
|
||||||
@@ -3428,7 +3425,7 @@ packages:
|
|||||||
estraverse: 5.3.0
|
estraverse: 5.3.0
|
||||||
hasown: 2.0.2
|
hasown: 2.0.2
|
||||||
jsx-ast-utils: 3.3.5
|
jsx-ast-utils: 3.3.5
|
||||||
minimatch: 3.1.2
|
minimatch: 3.1.3
|
||||||
object.entries: 1.1.9
|
object.entries: 1.1.9
|
||||||
object.fromentries: 2.0.8
|
object.fromentries: 2.0.8
|
||||||
object.values: 1.2.1
|
object.values: 1.2.1
|
||||||
@@ -3498,7 +3495,7 @@ packages:
|
|||||||
is-glob: 4.0.3
|
is-glob: 4.0.3
|
||||||
json-stable-stringify-without-jsonify: 1.0.1
|
json-stable-stringify-without-jsonify: 1.0.1
|
||||||
lodash.merge: 4.6.2
|
lodash.merge: 4.6.2
|
||||||
minimatch: 3.1.2
|
minimatch: 3.1.3
|
||||||
natural-compare: 1.4.0
|
natural-compare: 1.4.0
|
||||||
optionator: 0.9.4
|
optionator: 0.9.4
|
||||||
transitivePeerDependencies:
|
transitivePeerDependencies:
|
||||||
@@ -5113,19 +5110,12 @@ packages:
|
|||||||
engines: {node: '>=18'}
|
engines: {node: '>=18'}
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
/minimatch@3.1.2:
|
/minimatch@3.1.3:
|
||||||
resolution: {integrity: sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==}
|
resolution: {integrity: sha512-M2GCs7Vk83NxkUyQV1bkABc4yxgz9kILhHImZiBPAZ9ybuvCb0/H7lEl5XvIg3g+9d4eNotkZA5IWwYl0tibaA==}
|
||||||
dependencies:
|
dependencies:
|
||||||
brace-expansion: 1.1.12
|
brace-expansion: 1.1.12
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
/minimatch@9.0.5:
|
|
||||||
resolution: {integrity: sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==}
|
|
||||||
engines: {node: '>=16 || 14 >=14.17'}
|
|
||||||
dependencies:
|
|
||||||
brace-expansion: 2.0.2
|
|
||||||
dev: true
|
|
||||||
|
|
||||||
/minimist@1.2.8:
|
/minimist@1.2.8:
|
||||||
resolution: {integrity: sha512-2yyAR8qBkN3YuheJanUpWC5U3bb5osDywNB8RzDVlDwDHbocAJveqqj1u8+SVD7jkWT4yvsHCpWqqWqAxb0zCA==}
|
resolution: {integrity: sha512-2yyAR8qBkN3YuheJanUpWC5U3bb5osDywNB8RzDVlDwDHbocAJveqqj1u8+SVD7jkWT4yvsHCpWqqWqAxb0zCA==}
|
||||||
dev: true
|
dev: true
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import React, { useEffect, useState } from 'react';
|
import React, { useEffect, useMemo, useState } from 'react';
|
||||||
import {
|
import {
|
||||||
IChooseAdapterEntity,
|
IChooseAdapterEntity,
|
||||||
IPipelineEntity,
|
IPipelineEntity,
|
||||||
@@ -113,109 +113,73 @@ export default function BotForm({
|
|||||||
const [dynamicFormConfigList, setDynamicFormConfigList] = useState<
|
const [dynamicFormConfigList, setDynamicFormConfigList] = useState<
|
||||||
IDynamicFormItemSchema[]
|
IDynamicFormItemSchema[]
|
||||||
>([]);
|
>([]);
|
||||||
const [filteredDynamicFormConfigList, setFilteredDynamicFormConfigList] =
|
|
||||||
useState<IDynamicFormItemSchema[]>([]);
|
|
||||||
const [, setIsLoading] = useState<boolean>(false);
|
const [, setIsLoading] = useState<boolean>(false);
|
||||||
const [webhookUrl, setWebhookUrl] = useState<string>('');
|
const [webhookUrl, setWebhookUrl] = useState<string>('');
|
||||||
const webhookInputRef = React.useRef<HTMLInputElement>(null);
|
const [extraWebhookUrl, setExtraWebhookUrl] = useState<string>('');
|
||||||
const [copied, setCopied] = useState<boolean>(false);
|
const [copied, setCopied] = useState<boolean>(false);
|
||||||
|
const [extraCopied, setExtraCopied] = useState<boolean>(false);
|
||||||
|
|
||||||
// Watch adapter and adapter_config for filtering
|
// Watch adapter and adapter_config for filtering
|
||||||
const currentAdapter = form.watch('adapter');
|
const currentAdapter = form.watch('adapter');
|
||||||
const currentAdapterConfig = form.watch('adapter_config');
|
const currentAdapterConfig = form.watch('adapter_config');
|
||||||
|
|
||||||
|
// Derive the filtered config list via useMemo instead of useEffect+setState
|
||||||
|
// to avoid creating new array references that would cause DynamicFormComponent
|
||||||
|
// to re-subscribe its form.watch, re-emit values, and trigger an infinite loop.
|
||||||
|
// Only depend on the specific field we care about (enable-webhook) rather than
|
||||||
|
// the entire currentAdapterConfig object, which changes on every emission.
|
||||||
|
const enableWebhook = currentAdapterConfig?.['enable-webhook'];
|
||||||
|
const filteredDynamicFormConfigList = useMemo(() => {
|
||||||
|
if (currentAdapter === 'lark' && enableWebhook === false) {
|
||||||
|
// Hide encrypt-key field when webhook is disabled
|
||||||
|
return dynamicFormConfigList.filter(
|
||||||
|
(config) => config.name !== 'encrypt-key',
|
||||||
|
);
|
||||||
|
}
|
||||||
|
// For non-Lark adapters or when webhook is enabled/undefined, show all fields
|
||||||
|
return dynamicFormConfigList;
|
||||||
|
}, [currentAdapter, enableWebhook, dynamicFormConfigList]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
setBotFormValues();
|
setBotFormValues();
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
// Filter dynamic form config list based on enable-webhook status for Lark adapter
|
// 复制到剪贴板的辅助函数
|
||||||
useEffect(() => {
|
const copyToClipboard = (
|
||||||
if (currentAdapter === 'lark') {
|
text: string,
|
||||||
const enableWebhook = currentAdapterConfig?.['enable-webhook'];
|
setStatus: React.Dispatch<React.SetStateAction<boolean>>,
|
||||||
if (enableWebhook === false) {
|
) => {
|
||||||
// Hide encrypt-key field when webhook is disabled
|
if (navigator.clipboard && navigator.clipboard.writeText) {
|
||||||
setFilteredDynamicFormConfigList(
|
navigator.clipboard
|
||||||
dynamicFormConfigList.filter(
|
.writeText(text)
|
||||||
(config) => config.name !== 'encrypt-key',
|
.then(() => {
|
||||||
),
|
setStatus(true);
|
||||||
);
|
setTimeout(() => setStatus(false), 2000);
|
||||||
} else {
|
})
|
||||||
// Show all fields when webhook is enabled or undefined
|
.catch(() => {
|
||||||
setFilteredDynamicFormConfigList(dynamicFormConfigList);
|
// 降级:创建临时textarea复制
|
||||||
}
|
fallbackCopy(text, setStatus);
|
||||||
|
});
|
||||||
} else {
|
} else {
|
||||||
// For non-Lark adapters, show all fields
|
fallbackCopy(text, setStatus);
|
||||||
setFilteredDynamicFormConfigList(dynamicFormConfigList);
|
|
||||||
}
|
}
|
||||||
}, [currentAdapter, currentAdapterConfig, dynamicFormConfigList]);
|
};
|
||||||
|
|
||||||
// 复制到剪贴板的辅助函数 - 使用页面上的真实input元素
|
const fallbackCopy = (
|
||||||
const copyToClipboard = () => {
|
text: string,
|
||||||
console.log('[Copy] Attempting to copy from input element');
|
setStatus: React.Dispatch<React.SetStateAction<boolean>>,
|
||||||
|
) => {
|
||||||
const inputElement = webhookInputRef.current;
|
const textarea = document.createElement('textarea');
|
||||||
if (!inputElement) {
|
textarea.value = text;
|
||||||
console.error('[Copy] Input element not found');
|
textarea.style.position = 'fixed';
|
||||||
return;
|
textarea.style.opacity = '0';
|
||||||
}
|
document.body.appendChild(textarea);
|
||||||
|
textarea.select();
|
||||||
try {
|
const successful = document.execCommand('copy');
|
||||||
// 确保input元素可见且未被禁用
|
document.body.removeChild(textarea);
|
||||||
inputElement.disabled = false;
|
if (successful) {
|
||||||
inputElement.readOnly = false;
|
setStatus(true);
|
||||||
|
setTimeout(() => setStatus(false), 2000);
|
||||||
// 聚焦并选中所有文本
|
|
||||||
inputElement.focus();
|
|
||||||
inputElement.select();
|
|
||||||
|
|
||||||
// 尝试使用现代API
|
|
||||||
if (navigator.clipboard && navigator.clipboard.writeText) {
|
|
||||||
console.log(
|
|
||||||
'[Copy] Using Clipboard API with input value:',
|
|
||||||
inputElement.value,
|
|
||||||
);
|
|
||||||
navigator.clipboard
|
|
||||||
.writeText(inputElement.value)
|
|
||||||
.then(() => {
|
|
||||||
console.log('[Copy] Clipboard API success');
|
|
||||||
inputElement.blur(); // 取消选中
|
|
||||||
inputElement.readOnly = true;
|
|
||||||
setCopied(true);
|
|
||||||
setTimeout(() => setCopied(false), 2000);
|
|
||||||
})
|
|
||||||
.catch((err) => {
|
|
||||||
console.error(
|
|
||||||
'[Copy] Clipboard API failed, trying execCommand:',
|
|
||||||
err,
|
|
||||||
);
|
|
||||||
// 降级到execCommand
|
|
||||||
const successful = document.execCommand('copy');
|
|
||||||
console.log('[Copy] execCommand result:', successful);
|
|
||||||
inputElement.blur();
|
|
||||||
inputElement.readOnly = true;
|
|
||||||
if (successful) {
|
|
||||||
setCopied(true);
|
|
||||||
setTimeout(() => setCopied(false), 2000);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
// 直接使用execCommand
|
|
||||||
console.log(
|
|
||||||
'[Copy] Using execCommand with input value:',
|
|
||||||
inputElement.value,
|
|
||||||
);
|
|
||||||
const successful = document.execCommand('copy');
|
|
||||||
console.log('[Copy] execCommand result:', successful);
|
|
||||||
inputElement.blur();
|
|
||||||
inputElement.readOnly = true;
|
|
||||||
if (successful) {
|
|
||||||
setCopied(true);
|
|
||||||
setTimeout(() => setCopied(false), 2000);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (err) {
|
|
||||||
console.error('[Copy] Copy failed:', err);
|
|
||||||
inputElement.readOnly = true;
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -240,6 +204,7 @@ export default function BotForm({
|
|||||||
} else {
|
} else {
|
||||||
setWebhookUrl('');
|
setWebhookUrl('');
|
||||||
}
|
}
|
||||||
|
setExtraWebhookUrl(val.extra_webhook_full_url || '');
|
||||||
})
|
})
|
||||||
.catch((err) => {
|
.catch((err) => {
|
||||||
toast.error(
|
toast.error(
|
||||||
@@ -249,6 +214,7 @@ export default function BotForm({
|
|||||||
} else {
|
} else {
|
||||||
form.reset();
|
form.reset();
|
||||||
setWebhookUrl('');
|
setWebhookUrl('');
|
||||||
|
setExtraWebhookUrl('');
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -321,14 +287,20 @@ export default function BotForm({
|
|||||||
setAdapterNameToDynamicConfigMap(adapterNameToDynamicConfigMap);
|
setAdapterNameToDynamicConfigMap(adapterNameToDynamicConfigMap);
|
||||||
}
|
}
|
||||||
|
|
||||||
async function getBotConfig(
|
async function getBotConfig(botId: string): Promise<
|
||||||
botId: string,
|
z.infer<typeof formSchema> & {
|
||||||
): Promise<z.infer<typeof formSchema> & { webhook_full_url?: string }> {
|
webhook_full_url?: string;
|
||||||
|
extra_webhook_full_url?: string;
|
||||||
|
}
|
||||||
|
> {
|
||||||
return new Promise((resolve, reject) => {
|
return new Promise((resolve, reject) => {
|
||||||
httpClient
|
httpClient
|
||||||
.getBot(botId)
|
.getBot(botId)
|
||||||
.then((res) => {
|
.then((res) => {
|
||||||
const bot = res.bot;
|
const bot = res.bot;
|
||||||
|
const runtimeValues = bot.adapter_runtime_values as
|
||||||
|
| Record<string, unknown>
|
||||||
|
| undefined;
|
||||||
resolve({
|
resolve({
|
||||||
adapter: bot.adapter,
|
adapter: bot.adapter,
|
||||||
description: bot.description,
|
description: bot.description,
|
||||||
@@ -336,10 +308,12 @@ export default function BotForm({
|
|||||||
adapter_config: bot.adapter_config,
|
adapter_config: bot.adapter_config,
|
||||||
enable: bot.enable ?? true,
|
enable: bot.enable ?? true,
|
||||||
use_pipeline_uuid: bot.use_pipeline_uuid ?? '',
|
use_pipeline_uuid: bot.use_pipeline_uuid ?? '',
|
||||||
webhook_full_url: bot.adapter_runtime_values
|
webhook_full_url: runtimeValues?.webhook_full_url as
|
||||||
? ((bot.adapter_runtime_values as Record<string, unknown>)
|
| string
|
||||||
.webhook_full_url as string)
|
| undefined,
|
||||||
: undefined,
|
extra_webhook_full_url: runtimeValues?.extra_webhook_full_url as
|
||||||
|
| string
|
||||||
|
| undefined,
|
||||||
});
|
});
|
||||||
})
|
})
|
||||||
.catch((err) => {
|
.catch((err) => {
|
||||||
@@ -530,13 +504,11 @@ export default function BotForm({
|
|||||||
|
|
||||||
{/* Webhook 地址显示(统一 Webhook 模式) */}
|
{/* Webhook 地址显示(统一 Webhook 模式) */}
|
||||||
{webhookUrl &&
|
{webhookUrl &&
|
||||||
(currentAdapter !== 'lark' ||
|
(currentAdapter !== 'lark' || enableWebhook !== false) && (
|
||||||
currentAdapterConfig?.['enable-webhook'] !== false) && (
|
|
||||||
<FormItem>
|
<FormItem>
|
||||||
<FormLabel>{t('bots.webhookUrl')}</FormLabel>
|
<FormLabel>{t('bots.webhookUrl')}</FormLabel>
|
||||||
<div className="flex items-center gap-2">
|
<div className="flex items-center gap-2">
|
||||||
<Input
|
<Input
|
||||||
ref={webhookInputRef}
|
|
||||||
value={webhookUrl}
|
value={webhookUrl}
|
||||||
readOnly
|
readOnly
|
||||||
className="flex-1 bg-gray-50 dark:bg-gray-900"
|
className="flex-1 bg-gray-50 dark:bg-gray-900"
|
||||||
@@ -549,7 +521,7 @@ export default function BotForm({
|
|||||||
type="button"
|
type="button"
|
||||||
variant="outline"
|
variant="outline"
|
||||||
size="sm"
|
size="sm"
|
||||||
onClick={copyToClipboard}
|
onClick={() => copyToClipboard(webhookUrl, setCopied)}
|
||||||
>
|
>
|
||||||
{copied ? (
|
{copied ? (
|
||||||
<Check className="h-4 w-4 text-green-600 mr-2" />
|
<Check className="h-4 w-4 text-green-600 mr-2" />
|
||||||
@@ -559,8 +531,37 @@ export default function BotForm({
|
|||||||
{t('common.copy')}
|
{t('common.copy')}
|
||||||
</Button>
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
|
{extraWebhookUrl && (
|
||||||
|
<div className="flex items-center gap-2 mt-2">
|
||||||
|
<Input
|
||||||
|
value={extraWebhookUrl}
|
||||||
|
readOnly
|
||||||
|
className="flex-1 bg-gray-50 dark:bg-gray-900"
|
||||||
|
onClick={(e) => {
|
||||||
|
(e.target as HTMLInputElement).select();
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
<Button
|
||||||
|
type="button"
|
||||||
|
variant="outline"
|
||||||
|
size="sm"
|
||||||
|
onClick={() =>
|
||||||
|
copyToClipboard(extraWebhookUrl, setExtraCopied)
|
||||||
|
}
|
||||||
|
>
|
||||||
|
{extraCopied ? (
|
||||||
|
<Check className="h-4 w-4 text-green-600 mr-2" />
|
||||||
|
) : (
|
||||||
|
<Copy className="h-4 w-4 mr-2" />
|
||||||
|
)}
|
||||||
|
{t('common.copy')}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
<p className="text-sm text-gray-500 mt-1">
|
<p className="text-sm text-gray-500 mt-1">
|
||||||
{t('bots.webhookUrlHint')}
|
{extraWebhookUrl
|
||||||
|
? t('bots.webhookUrlHintEither')
|
||||||
|
: t('bots.webhookUrlHint')}
|
||||||
</p>
|
</p>
|
||||||
</FormItem>
|
</FormItem>
|
||||||
)}
|
)}
|
||||||
@@ -667,7 +668,7 @@ export default function BotForm({
|
|||||||
</div>
|
</div>
|
||||||
<DynamicFormComponent
|
<DynamicFormComponent
|
||||||
itemConfigList={filteredDynamicFormConfigList}
|
itemConfigList={filteredDynamicFormConfigList}
|
||||||
initialValues={form.watch('adapter_config')}
|
initialValues={currentAdapterConfig}
|
||||||
onSubmit={(values) => {
|
onSubmit={(values) => {
|
||||||
form.setValue('adapter_config', values);
|
form.setValue('adapter_config', values);
|
||||||
}}
|
}}
|
||||||
|
|||||||
@@ -34,6 +34,35 @@ export default function DynamicFormComponent({
|
|||||||
const previousInitialValues = useRef(initialValues);
|
const previousInitialValues = useRef(initialValues);
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
// Normalize a form value according to its field type.
|
||||||
|
// This ensures legacy/malformed data (e.g. a plain string for
|
||||||
|
// model-fallback-selector) is coerced to the expected shape
|
||||||
|
// so that downstream components never crash.
|
||||||
|
const normalizeFieldValue = (
|
||||||
|
item: IDynamicFormItemSchema,
|
||||||
|
value: unknown,
|
||||||
|
): unknown => {
|
||||||
|
if (item.type === 'model-fallback-selector') {
|
||||||
|
if (value != null && typeof value === 'object' && !Array.isArray(value)) {
|
||||||
|
const obj = value as Record<string, unknown>;
|
||||||
|
return {
|
||||||
|
primary: typeof obj.primary === 'string' ? obj.primary : '',
|
||||||
|
fallbacks: Array.isArray(obj.fallbacks)
|
||||||
|
? (obj.fallbacks as unknown[]).filter(
|
||||||
|
(v): v is string => typeof v === 'string',
|
||||||
|
)
|
||||||
|
: [],
|
||||||
|
};
|
||||||
|
}
|
||||||
|
// Legacy string format or any other unexpected type
|
||||||
|
return {
|
||||||
|
primary: typeof value === 'string' ? value : '',
|
||||||
|
fallbacks: [],
|
||||||
|
};
|
||||||
|
}
|
||||||
|
return value;
|
||||||
|
};
|
||||||
|
|
||||||
// 根据 itemConfigList 动态生成 zod schema
|
// 根据 itemConfigList 动态生成 zod schema
|
||||||
const formSchema = z.object(
|
const formSchema = z.object(
|
||||||
itemConfigList.reduce(
|
itemConfigList.reduce(
|
||||||
@@ -116,10 +145,10 @@ export default function DynamicFormComponent({
|
|||||||
resolver: zodResolver(formSchema),
|
resolver: zodResolver(formSchema),
|
||||||
defaultValues: itemConfigList.reduce((acc, item) => {
|
defaultValues: itemConfigList.reduce((acc, item) => {
|
||||||
// 优先使用 initialValues,如果没有则使用默认值
|
// 优先使用 initialValues,如果没有则使用默认值
|
||||||
const value = initialValues?.[item.name] ?? item.default;
|
const rawValue = initialValues?.[item.name] ?? item.default;
|
||||||
return {
|
return {
|
||||||
...acc,
|
...acc,
|
||||||
[item.name]: value,
|
[item.name]: normalizeFieldValue(item, rawValue),
|
||||||
};
|
};
|
||||||
}, {} as FormValues),
|
}, {} as FormValues),
|
||||||
});
|
});
|
||||||
@@ -144,7 +173,8 @@ export default function DynamicFormComponent({
|
|||||||
// 合并默认值和初始值
|
// 合并默认值和初始值
|
||||||
const mergedValues = itemConfigList.reduce(
|
const mergedValues = itemConfigList.reduce(
|
||||||
(acc, item) => {
|
(acc, item) => {
|
||||||
acc[item.name] = initialValues[item.name] ?? item.default;
|
const rawValue = initialValues[item.name] ?? item.default;
|
||||||
|
acc[item.name] = normalizeFieldValue(item, rawValue) as object;
|
||||||
return acc;
|
return acc;
|
||||||
},
|
},
|
||||||
{} as Record<string, object>,
|
{} as Record<string, object>,
|
||||||
@@ -181,6 +211,15 @@ export default function DynamicFormComponent({
|
|||||||
);
|
);
|
||||||
onSubmitRef.current?.(initialFinalValues);
|
onSubmitRef.current?.(initialFinalValues);
|
||||||
|
|
||||||
|
// Update previousInitialValues to the emitted snapshot so that if the
|
||||||
|
// parent writes these values back as new initialValues, the deep
|
||||||
|
// comparison in the initialValues-sync useEffect won't detect a change
|
||||||
|
// and won't trigger an infinite update loop.
|
||||||
|
previousInitialValues.current = initialFinalValues as Record<
|
||||||
|
string,
|
||||||
|
object
|
||||||
|
>;
|
||||||
|
|
||||||
const subscription = form.watch(() => {
|
const subscription = form.watch(() => {
|
||||||
const formValues = form.getValues();
|
const formValues = form.getValues();
|
||||||
const finalValues = itemConfigList.reduce(
|
const finalValues = itemConfigList.reduce(
|
||||||
@@ -191,6 +230,7 @@ export default function DynamicFormComponent({
|
|||||||
{} as Record<string, object>,
|
{} as Record<string, object>,
|
||||||
);
|
);
|
||||||
onSubmitRef.current?.(finalValues);
|
onSubmitRef.current?.(finalValues);
|
||||||
|
previousInitialValues.current = finalValues as Record<string, object>;
|
||||||
});
|
});
|
||||||
return () => subscription.unsubscribe();
|
return () => subscription.unsubscribe();
|
||||||
}, [form, itemConfigList]);
|
}, [form, itemConfigList]);
|
||||||
|
|||||||
@@ -348,10 +348,31 @@ export default function DynamicFormItemComponent({
|
|||||||
{} as Record<string, LLMModel[]>,
|
{} as Record<string, LLMModel[]>,
|
||||||
);
|
);
|
||||||
|
|
||||||
const modelValue = field.value as {
|
const rawModelValue = field.value;
|
||||||
primary: string;
|
const modelValue: { primary: string; fallbacks: string[] } =
|
||||||
fallbacks: string[];
|
rawModelValue != null &&
|
||||||
};
|
typeof rawModelValue === 'object' &&
|
||||||
|
!Array.isArray(rawModelValue)
|
||||||
|
? {
|
||||||
|
primary:
|
||||||
|
typeof (rawModelValue as Record<string, unknown>).primary ===
|
||||||
|
'string'
|
||||||
|
? ((rawModelValue as Record<string, unknown>)
|
||||||
|
.primary as string)
|
||||||
|
: '',
|
||||||
|
fallbacks: Array.isArray(
|
||||||
|
(rawModelValue as Record<string, unknown>).fallbacks,
|
||||||
|
)
|
||||||
|
? (
|
||||||
|
(rawModelValue as Record<string, unknown>)
|
||||||
|
.fallbacks as unknown[]
|
||||||
|
).filter((v): v is string => typeof v === 'string')
|
||||||
|
: [],
|
||||||
|
}
|
||||||
|
: {
|
||||||
|
primary: typeof rawModelValue === 'string' ? rawModelValue : '',
|
||||||
|
fallbacks: [],
|
||||||
|
};
|
||||||
|
|
||||||
const renderModelSelect = (
|
const renderModelSelect = (
|
||||||
value: string,
|
value: string,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { useEffect, useState } from 'react';
|
import { useEffect, useMemo, useState } from 'react';
|
||||||
import Link from 'next/link';
|
import Link from 'next/link';
|
||||||
import { useForm } from 'react-hook-form';
|
import { useForm } from 'react-hook-form';
|
||||||
import { zodResolver } from '@hookform/resolvers/zod';
|
import { zodResolver } from '@hookform/resolvers/zod';
|
||||||
@@ -242,11 +242,17 @@ export default function KBForm({
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Convert creation schema to dynamic form items (same as ExternalKBForm)
|
// Convert creation schema to dynamic form items (same as ExternalKBForm)
|
||||||
const configFormItems = parseCreationSchema(selectedEngine?.creation_schema);
|
// Memoize to avoid regenerating UUIDs on every render, which would cause
|
||||||
|
// DynamicFormComponent's useEffect to re-fire and trigger an infinite loop.
|
||||||
|
const configFormItems = useMemo(
|
||||||
|
() => parseCreationSchema(selectedEngine?.creation_schema),
|
||||||
|
[selectedEngine?.creation_schema],
|
||||||
|
);
|
||||||
|
|
||||||
// Convert retrieval schema to dynamic form items
|
// Convert retrieval schema to dynamic form items
|
||||||
const retrievalFormItems = parseCreationSchema(
|
const retrievalFormItems = useMemo(
|
||||||
selectedEngine?.retrieval_schema,
|
() => parseCreationSchema(selectedEngine?.retrieval_schema),
|
||||||
|
[selectedEngine?.retrieval_schema],
|
||||||
);
|
);
|
||||||
|
|
||||||
// Show loading state
|
// Show loading state
|
||||||
|
|||||||
@@ -284,6 +284,8 @@ const enUS = {
|
|||||||
webhookUrlCopied: 'Webhook URL copied',
|
webhookUrlCopied: 'Webhook URL copied',
|
||||||
webhookUrlHint:
|
webhookUrlHint:
|
||||||
'Click the input to select all, then press Ctrl+C (Mac: Cmd+C) to copy, or click the button',
|
'Click the input to select all, then press Ctrl+C (Mac: Cmd+C) to copy, or click the button',
|
||||||
|
webhookUrlHintEither:
|
||||||
|
'Use either of the two URLs above in your platform configuration',
|
||||||
logLevel: 'Log Level',
|
logLevel: 'Log Level',
|
||||||
allLevels: 'All Levels',
|
allLevels: 'All Levels',
|
||||||
selectLevel: 'Select Level',
|
selectLevel: 'Select Level',
|
||||||
|
|||||||
@@ -289,6 +289,8 @@
|
|||||||
webhookUrlCopied: 'Webhook URL をコピーしました',
|
webhookUrlCopied: 'Webhook URL をコピーしました',
|
||||||
webhookUrlHint:
|
webhookUrlHint:
|
||||||
'入力ボックスをクリックして全選択し、Ctrl+C (Mac: Cmd+C) でコピーするか、右側のボタンをクリックしてください',
|
'入力ボックスをクリックして全選択し、Ctrl+C (Mac: Cmd+C) でコピーするか、右側のボタンをクリックしてください',
|
||||||
|
webhookUrlHintEither:
|
||||||
|
'上記の2つのURLのいずれかをプラットフォーム設定に使用してください',
|
||||||
logLevel: 'ログレベル',
|
logLevel: 'ログレベル',
|
||||||
allLevels: 'すべてのレベル',
|
allLevels: 'すべてのレベル',
|
||||||
selectLevel: 'レベルを選択',
|
selectLevel: 'レベルを選択',
|
||||||
|
|||||||
@@ -273,6 +273,7 @@ const zhHans = {
|
|||||||
webhookUrlCopied: 'Webhook 地址已复制',
|
webhookUrlCopied: 'Webhook 地址已复制',
|
||||||
webhookUrlHint:
|
webhookUrlHint:
|
||||||
'点击输入框自动全选,然后按 Ctrl+C (Mac: Cmd+C) 复制,或点击右侧按钮',
|
'点击输入框自动全选,然后按 Ctrl+C (Mac: Cmd+C) 复制,或点击右侧按钮',
|
||||||
|
webhookUrlHintEither: '以上两个地址任选其一填入平台配置即可',
|
||||||
logLevel: '日志级别',
|
logLevel: '日志级别',
|
||||||
allLevels: '全部级别',
|
allLevels: '全部级别',
|
||||||
selectLevel: '选择级别',
|
selectLevel: '选择级别',
|
||||||
|
|||||||
@@ -272,6 +272,7 @@ const zhHant = {
|
|||||||
webhookUrlCopied: 'Webhook 位址已複製',
|
webhookUrlCopied: 'Webhook 位址已複製',
|
||||||
webhookUrlHint:
|
webhookUrlHint:
|
||||||
'點擊輸入框自動全選,然後按 Ctrl+C (Mac: Cmd+C) 複製,或點擊右側按鈕',
|
'點擊輸入框自動全選,然後按 Ctrl+C (Mac: Cmd+C) 複製,或點擊右側按鈕',
|
||||||
|
webhookUrlHintEither: '以上兩個地址任選其一填入平台配置即可',
|
||||||
logLevel: '日誌級別',
|
logLevel: '日誌級別',
|
||||||
allLevels: '全部級別',
|
allLevels: '全部級別',
|
||||||
selectLevel: '選擇級別',
|
selectLevel: '選擇級別',
|
||||||
|
|||||||
Reference in New Issue
Block a user