Compare commits

..

1 Commits

Author SHA1 Message Date
RockChinQ
c3152e10c8 feat: add agent loop protection - max iterations and tool result truncation
- Add max-tool-iterations config (default 16) to prevent runaway agent loops
- Add max-tool-result-chars config (default 8000) to truncate oversized tool results
- Both settings are configurable in pipeline UI under Local Agent settings
- Logs warnings when limits are hit for debugging

Closes #2051
2026-03-12 03:14:36 -04:00
30 changed files with 766 additions and 1824 deletions

View File

@@ -1,6 +1,6 @@
[project]
name = "langbot"
version = "4.9.2"
version = "4.9.0"
description = "Production-grade platform for building agentic IM bots"
readme = "README.md"
license-files = ["LICENSE"]
@@ -64,7 +64,7 @@ dependencies = [
"chromadb>=1.0.0,<2.0.0",
"qdrant-client (>=1.15.1,<2.0.0)",
"pyseekdb==1.1.0.post3",
"langbot-plugin==0.3.1",
"langbot-plugin==0.3.0",
"asyncpg>=0.30.0",
"line-bot-sdk>=3.19.0",
"tboxsdk>=0.0.10",

View File

@@ -1,3 +1,3 @@
"""LangBot - Production-grade platform for building agentic IM bots"""
__version__ = '4.9.2'
__version__ = '4.9.0'

View File

@@ -199,253 +199,6 @@ class StreamSessionManager:
self._msg_index.pop(msg_id, None)
async def download_encrypted_file(download_url: str, encoding_aes_key: str, logger: EventLogger) -> Optional[str]:
"""Download an AES-encrypted file from WeChat Work and return as data URI.
Args:
download_url: The encrypted file download URL.
encoding_aes_key: The AES key used for decryption (base64-encoded, without trailing '=').
logger: Logger instance.
Returns:
A data URI string (e.g. 'data:image/jpeg;base64,...') or None on failure.
"""
if not download_url:
return None
async with httpx.AsyncClient() as client:
response = await client.get(download_url)
if response.status_code != 200:
await logger.error(f'failed to get file: {response.text}')
return None
encrypted_bytes = response.content
aes_key = base64.b64decode(encoding_aes_key + '=')
iv = aes_key[:16]
cipher = AES.new(aes_key, AES.MODE_CBC, iv)
decrypted = cipher.decrypt(encrypted_bytes)
pad_len = decrypted[-1]
decrypted = decrypted[:-pad_len]
if decrypted.startswith(b'\xff\xd8'):
mime_type = 'image/jpeg'
elif decrypted.startswith(b'\x89PNG'):
mime_type = 'image/png'
elif decrypted.startswith((b'GIF87a', b'GIF89a')):
mime_type = 'image/gif'
elif decrypted.startswith(b'BM'):
mime_type = 'image/bmp'
elif decrypted.startswith(b'II*\x00') or decrypted.startswith(b'MM\x00*'):
mime_type = 'image/tiff'
else:
mime_type = 'application/octet-stream'
base64_str = base64.b64encode(decrypted).decode('utf-8')
return f'data:{mime_type};base64,{base64_str}'
async def parse_wecom_bot_message(
msg_json: dict[str, Any], encoding_aes_key: str, logger: EventLogger
) -> dict[str, Any]:
"""Parse a decrypted WeChat Work AI Bot message JSON into a unified message dict.
This is the shared message parsing logic used by both webhook and WebSocket modes.
Args:
msg_json: The decrypted message JSON from WeChat Work.
encoding_aes_key: AES key for file decryption.
logger: Logger instance.
Returns:
A dict suitable for constructing a WecomBotEvent.
"""
message_data: dict[str, Any] = {}
msg_type = msg_json.get('msgtype', '')
if msg_type:
message_data['msgtype'] = msg_type
if msg_json.get('chattype', '') == 'single':
message_data['type'] = 'single'
elif msg_json.get('chattype', '') == 'group':
message_data['type'] = 'group'
max_inline_file_size = 5 * 1024 * 1024
async def _safe_download(url: str):
if not url:
return None
return await download_encrypted_file(url, encoding_aes_key, logger)
if msg_type == 'text':
message_data['content'] = msg_json.get('text', {}).get('content')
elif msg_type == 'markdown':
message_data['content'] = msg_json.get('markdown', {}).get('content') or msg_json.get('text', {}).get(
'content', ''
)
elif msg_type == 'image':
picurl = msg_json.get('image', {}).get('url', '')
base64_data = await _safe_download(picurl)
if base64_data:
message_data['picurl'] = base64_data
message_data['images'] = [base64_data]
elif msg_type == 'voice':
voice_info = msg_json.get('voice', {}) or {}
download_url = voice_info.get('url')
message_data['voice'] = {
'url': download_url,
'md5sum': voice_info.get('md5sum') or voice_info.get('md5'),
'filesize': voice_info.get('filesize') or voice_info.get('size'),
'sdkfileid': voice_info.get('sdkfileid') or voice_info.get('fileid'),
}
if voice_info.get('content'):
message_data['content'] = voice_info.get('content')
if (message_data['voice'].get('filesize') or 0) <= max_inline_file_size:
voice_base64 = await _safe_download(download_url)
if voice_base64:
message_data['voice']['base64'] = voice_base64
elif msg_type == 'video':
video_info = msg_json.get('video', {}) or {}
download_url = video_info.get('url')
video_data = {
'url': download_url,
'filesize': video_info.get('filesize') or video_info.get('size'),
'sdkfileid': video_info.get('sdkfileid') or video_info.get('fileid'),
'md5sum': video_info.get('md5sum') or video_info.get('md5'),
'filename': video_info.get('filename') or video_info.get('name'),
}
if (video_data.get('filesize') or 0) <= max_inline_file_size:
video_base64 = await _safe_download(download_url)
if video_base64:
video_data['base64'] = video_base64
message_data['video'] = video_data
elif msg_type == 'file':
file_info = msg_json.get('file', {}) or {}
download_url = file_info.get('url') or file_info.get('fileurl')
file_data = {
'filename': file_info.get('filename') or file_info.get('name'),
'filesize': file_info.get('filesize') or file_info.get('size'),
'md5sum': file_info.get('md5sum') or file_info.get('md5'),
'sdkfileid': file_info.get('sdkfileid') or file_info.get('fileid'),
'download_url': download_url,
'extra': file_info,
}
if (file_data.get('filesize') or 0) <= max_inline_file_size:
file_base64 = await _safe_download(download_url)
if file_base64:
file_data['base64'] = file_base64
message_data['file'] = file_data
elif msg_type == 'link':
message_data['link'] = msg_json.get('link', {})
if not message_data.get('content'):
title = message_data['link'].get('title', '')
desc = message_data['link'].get('description') or message_data['link'].get('digest', '')
message_data['content'] = '\n'.join(filter(None, [title, desc]))
elif msg_type == 'mixed':
items = msg_json.get('mixed', {}).get('msg_item', [])
texts = []
images = []
files = []
voices = []
videos = []
links = []
for item in items:
item_type = item.get('msgtype')
if item_type == 'text':
texts.append(item.get('text', {}).get('content', ''))
elif item_type == 'image':
img_url = item.get('image', {}).get('url')
base64_data = await _safe_download(img_url)
if base64_data:
images.append(base64_data)
elif item_type == 'file':
file_info = item.get('file', {}) or {}
download_url = file_info.get('url') or file_info.get('fileurl')
file_data = {
'filename': file_info.get('filename') or file_info.get('name'),
'filesize': file_info.get('filesize') or file_info.get('size'),
'md5sum': file_info.get('md5sum') or file_info.get('md5'),
'sdkfileid': file_info.get('sdkfileid') or file_info.get('fileid'),
'download_url': download_url,
'extra': file_info,
}
if (file_data.get('filesize') or 0) <= max_inline_file_size:
file_base64 = await _safe_download(download_url)
if file_base64:
file_data['base64'] = file_base64
files.append(file_data)
elif item_type == 'voice':
voice_info = item.get('voice', {}) or {}
download_url = voice_info.get('url')
voice_data = {
'url': download_url,
'md5sum': voice_info.get('md5sum') or voice_info.get('md5'),
'filesize': voice_info.get('filesize') or voice_info.get('size'),
'sdkfileid': voice_info.get('sdkfileid') or voice_info.get('fileid'),
}
if voice_info.get('content'):
texts.append(voice_info.get('content'))
if (voice_data.get('filesize') or 0) <= max_inline_file_size:
voice_base64 = await _safe_download(download_url)
if voice_base64:
voice_data['base64'] = voice_base64
voices.append(voice_data)
elif item_type == 'video':
video_info = item.get('video', {}) or {}
download_url = video_info.get('url')
video_data = {
'url': download_url,
'filesize': video_info.get('filesize') or video_info.get('size'),
'sdkfileid': video_info.get('sdkfileid') or video_info.get('fileid'),
'md5sum': video_info.get('md5sum') or video_info.get('md5'),
'filename': video_info.get('filename') or video_info.get('name'),
}
if (video_data.get('filesize') or 0) <= max_inline_file_size:
video_base64 = await _safe_download(download_url)
if video_base64:
video_data['base64'] = video_base64
videos.append(video_data)
elif item_type == 'link':
links.append(item.get('link', {}))
if texts:
message_data['content'] = ' '.join(texts)
if images:
message_data['images'] = images
message_data['picurl'] = images[0]
if files:
message_data['files'] = files
message_data['file'] = files[0]
if voices:
message_data['voices'] = voices
message_data['voice'] = voices[0]
if videos:
message_data['videos'] = videos
message_data['video'] = videos[0]
if links:
message_data['link'] = links[0]
if items:
message_data['attachments'] = items
else:
message_data['raw_msg'] = msg_json
from_info = msg_json.get('from', {})
message_data['userid'] = from_info.get('userid', '')
message_data['username'] = from_info.get('alias', '') or from_info.get('name', '') or from_info.get('userid', '')
if msg_json.get('chattype', '') == 'group':
message_data['chatid'] = msg_json.get('chatid', '')
message_data['chatname'] = msg_json.get('chatname', '') or msg_json.get('chatid', '')
message_data['msgid'] = msg_json.get('msgid', '')
if msg_json.get('aibotid'):
message_data['aibotid'] = msg_json.get('aibotid', '')
return message_data
class WecomBotClient:
def __init__(self, Token: str, EnCodingAESKey: str, Corpid: str, logger: EventLogger, unified_mode: bool = False):
"""企业微信智能机器人客户端。
@@ -702,7 +455,196 @@ class WecomBotClient:
return await self._handle_post_initial_response(msg_json, nonce)
async def get_message(self, msg_json):
return await parse_wecom_bot_message(msg_json, self.EnCodingAESKey, self.logger)
message_data = {}
msg_type = msg_json.get('msgtype', '')
if msg_type:
message_data['msgtype'] = msg_type
if msg_json.get('chattype', '') == 'single':
message_data['type'] = 'single'
elif msg_json.get('chattype', '') == 'group':
message_data['type'] = 'group'
max_inline_file_size = 5 * 1024 * 1024 # avoid decoding very large payloads by default
async def _safe_download(url: str):
if not url:
return None
return await self.download_url_to_base64(url, self.EnCodingAESKey)
if msg_type == 'text':
message_data['content'] = msg_json.get('text', {}).get('content')
elif msg_type == 'markdown':
message_data['content'] = msg_json.get('markdown', {}).get('content') or msg_json.get('text', {}).get(
'content', ''
)
elif msg_type == 'image':
picurl = msg_json.get('image', {}).get('url', '')
base64_data = await _safe_download(picurl)
if base64_data:
message_data['picurl'] = base64_data
message_data['images'] = [base64_data]
elif msg_type == 'voice':
voice_info = msg_json.get('voice', {}) or {}
download_url = voice_info.get('url')
message_data['voice'] = {
'url': download_url,
'md5sum': voice_info.get('md5sum') or voice_info.get('md5'),
'filesize': voice_info.get('filesize') or voice_info.get('size'),
'sdkfileid': voice_info.get('sdkfileid') or voice_info.get('fileid'),
}
# 企业微信智能转写文本(如果已有)直接复用,避免重复转写
if voice_info.get('content'):
message_data['content'] = voice_info.get('content')
if (message_data['voice'].get('filesize') or 0) <= max_inline_file_size:
voice_base64 = await _safe_download(download_url)
if voice_base64:
message_data['voice']['base64'] = voice_base64
elif msg_type == 'video':
video_info = msg_json.get('video', {}) or {}
download_url = video_info.get('url')
video_data = {
'url': download_url,
'filesize': video_info.get('filesize') or video_info.get('size'),
'sdkfileid': video_info.get('sdkfileid') or video_info.get('fileid'),
'md5sum': video_info.get('md5sum') or video_info.get('md5'),
'filename': video_info.get('filename') or video_info.get('name'),
}
if (video_data.get('filesize') or 0) <= max_inline_file_size:
video_base64 = await _safe_download(download_url)
if video_base64:
video_data['base64'] = video_base64
message_data['video'] = video_data
elif msg_type == 'file':
file_info = msg_json.get('file', {}) or {}
download_url = file_info.get('url') or file_info.get('fileurl')
file_data = {
'filename': file_info.get('filename') or file_info.get('name'),
'filesize': file_info.get('filesize') or file_info.get('size'),
'md5sum': file_info.get('md5sum') or file_info.get('md5'),
'sdkfileid': file_info.get('sdkfileid') or file_info.get('fileid'),
'download_url': download_url,
'extra': file_info,
}
if (file_data.get('filesize') or 0) <= max_inline_file_size:
file_base64 = await _safe_download(download_url)
if file_base64:
file_data['base64'] = file_base64
message_data['file'] = file_data
elif msg_type == 'link':
message_data['link'] = msg_json.get('link', {})
if not message_data.get('content'):
title = message_data['link'].get('title', '')
desc = message_data['link'].get('description') or message_data['link'].get('digest', '')
message_data['content'] = '\n'.join(filter(None, [title, desc]))
elif msg_type == 'mixed':
items = msg_json.get('mixed', {}).get('msg_item', [])
texts = []
images = []
files = []
voices = []
videos = []
links = []
for item in items:
item_type = item.get('msgtype')
if item_type == 'text':
texts.append(item.get('text', {}).get('content', ''))
elif item_type == 'image':
img_url = item.get('image', {}).get('url')
base64_data = await _safe_download(img_url)
if base64_data:
images.append(base64_data)
elif item_type == 'file':
file_info = item.get('file', {}) or {}
download_url = file_info.get('url') or file_info.get('fileurl')
file_data = {
'filename': file_info.get('filename') or file_info.get('name'),
'filesize': file_info.get('filesize') or file_info.get('size'),
'md5sum': file_info.get('md5sum') or file_info.get('md5'),
'sdkfileid': file_info.get('sdkfileid') or file_info.get('fileid'),
'download_url': download_url,
'extra': file_info,
}
if (file_data.get('filesize') or 0) <= max_inline_file_size:
file_base64 = await _safe_download(download_url)
if file_base64:
file_data['base64'] = file_base64
files.append(file_data)
elif item_type == 'voice':
voice_info = item.get('voice', {}) or {}
download_url = voice_info.get('url')
voice_data = {
'url': download_url,
'md5sum': voice_info.get('md5sum') or voice_info.get('md5'),
'filesize': voice_info.get('filesize') or voice_info.get('size'),
'sdkfileid': voice_info.get('sdkfileid') or voice_info.get('fileid'),
}
if voice_info.get('content'):
texts.append(voice_info.get('content'))
if (voice_data.get('filesize') or 0) <= max_inline_file_size:
voice_base64 = await _safe_download(download_url)
if voice_base64:
voice_data['base64'] = voice_base64
voices.append(voice_data)
elif item_type == 'video':
video_info = item.get('video', {}) or {}
download_url = video_info.get('url')
video_data = {
'url': download_url,
'filesize': video_info.get('filesize') or video_info.get('size'),
'sdkfileid': video_info.get('sdkfileid') or video_info.get('fileid'),
'md5sum': video_info.get('md5sum') or video_info.get('md5'),
'filename': video_info.get('filename') or video_info.get('name'),
}
if (video_data.get('filesize') or 0) <= max_inline_file_size:
video_base64 = await _safe_download(download_url)
if video_base64:
video_data['base64'] = video_base64
videos.append(video_data)
elif item_type == 'link':
links.append(item.get('link', {}))
if texts:
message_data['content'] = ' '.join(texts) # 拼接所有 text
if images:
message_data['images'] = images
message_data['picurl'] = images[0] # 只保留第一个 image
if files:
message_data['files'] = files
message_data['file'] = files[0]
if voices:
message_data['voices'] = voices
message_data['voice'] = voices[0]
if videos:
message_data['videos'] = videos
message_data['video'] = videos[0]
if links:
message_data['link'] = links[0]
if items:
message_data['attachments'] = items
else:
message_data['raw_msg'] = msg_json
# Extract user information
from_info = msg_json.get('from', {})
message_data['userid'] = from_info.get('userid', '')
message_data['username'] = (
from_info.get('alias', '') or from_info.get('name', '') or from_info.get('userid', '')
)
# Extract chat/group information
if msg_json.get('chattype', '') == 'group':
message_data['chatid'] = msg_json.get('chatid', '')
# Try to get group name if available
message_data['chatname'] = msg_json.get('chatname', '') or msg_json.get('chatid', '')
message_data['msgid'] = msg_json.get('msgid', '')
if msg_json.get('aibotid'):
message_data['aibotid'] = msg_json.get('aibotid', '')
return message_data
async def _handle_message(self, event: wecombotevent.WecomBotEvent):
"""
@@ -770,7 +712,39 @@ class WecomBotClient:
return decorator
async def download_url_to_base64(self, download_url, encoding_aes_key):
return await download_encrypted_file(download_url, encoding_aes_key, self.logger)
async with httpx.AsyncClient() as client:
response = await client.get(download_url)
if response.status_code != 200:
await self.logger.error(f'failed to get file: {response.text}')
return None
encrypted_bytes = response.content
aes_key = base64.b64decode(encoding_aes_key + '=') # base64 补齐
iv = aes_key[:16]
cipher = AES.new(aes_key, AES.MODE_CBC, iv)
decrypted = cipher.decrypt(encrypted_bytes)
pad_len = decrypted[-1]
decrypted = decrypted[:-pad_len]
if decrypted.startswith(b'\xff\xd8'): # JPEG
mime_type = 'image/jpeg'
elif decrypted.startswith(b'\x89PNG'): # PNG
mime_type = 'image/png'
elif decrypted.startswith((b'GIF87a', b'GIF89a')): # GIF
mime_type = 'image/gif'
elif decrypted.startswith(b'BM'): # BMP
mime_type = 'image/bmp'
elif decrypted.startswith(b'II*\x00') or decrypted.startswith(b'MM\x00*'): # TIFF
mime_type = 'image/tiff'
else:
mime_type = 'application/octet-stream'
# 转 base64
base64_str = base64.b64encode(decrypted).decode('utf-8')
return f'data:{mime_type};base64,{base64_str}'
async def run_task(self, host: str, port: int, *args, **kwargs):
"""

View File

@@ -1,596 +0,0 @@
"""WeChat Work AI Bot WebSocket long connection client.
Implements the WebSocket protocol for receiving messages and sending replies
via a persistent connection to wss://openws.work.weixin.qq.com, as an
alternative to the HTTP callback (webhook) mode.
Protocol reference: https://developer.work.weixin.qq.com/document/path/101463
Official Node.js SDK: https://github.com/WecomTeam/aibot-node-sdk
"""
from __future__ import annotations
import asyncio
import json
import secrets
import time
import traceback
from typing import Any, Callable, Optional
import aiohttp
from langbot.libs.wecom_ai_bot_api import wecombotevent
from langbot.libs.wecom_ai_bot_api.api import parse_wecom_bot_message
from langbot.pkg.platform.logger import EventLogger
DEFAULT_WS_URL = 'wss://openws.work.weixin.qq.com'
# WebSocket frame command constants
CMD_SUBSCRIBE = 'aibot_subscribe'
CMD_HEARTBEAT = 'ping'
CMD_MSG_CALLBACK = 'aibot_msg_callback'
CMD_EVENT_CALLBACK = 'aibot_event_callback'
CMD_RESPOND_MSG = 'aibot_respond_msg'
CMD_RESPOND_WELCOME = 'aibot_respond_welcome_msg'
CMD_RESPOND_UPDATE = 'aibot_respond_update_msg'
CMD_SEND_MSG = 'aibot_send_msg'
def _generate_req_id(prefix: str) -> str:
"""Generate a unique request ID in the format: {prefix}_{timestamp}_{random}."""
ts = int(time.time() * 1000)
rand = secrets.token_hex(4)
return f'{prefix}_{ts}_{rand}'
class WecomBotWsClient:
"""WeChat Work AI Bot WebSocket long connection client.
Provides message receiving, streaming reply, proactive message sending,
and event callback handling over a persistent WebSocket connection.
"""
def __init__(
self,
bot_id: str,
secret: str,
logger: EventLogger,
encoding_aes_key: str = '',
ws_url: str = DEFAULT_WS_URL,
heartbeat_interval: float = 30.0,
max_reconnect_attempts: int = -1,
reconnect_base_delay: float = 1.0,
reconnect_max_delay: float = 30.0,
):
self.bot_id = bot_id
self.secret = secret
self.logger = logger
self.encoding_aes_key = encoding_aes_key
self.ws_url = ws_url
self.heartbeat_interval = heartbeat_interval
self.max_reconnect_attempts = max_reconnect_attempts
self.reconnect_base_delay = reconnect_base_delay
self.reconnect_max_delay = reconnect_max_delay
self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
self._session: Optional[aiohttp.ClientSession] = None
self._running = False
self._heartbeat_task: Optional[asyncio.Task] = None
self._missed_pong_count = 0
self._max_missed_pong = 2
self._reconnect_attempts = 0
# Message handler registry (same pattern as WecomBotClient)
self._message_handlers: dict[str, list[Callable]] = {}
# Message deduplication
self._msg_id_map: dict[str, int] = {}
# Pending ACK futures: req_id -> Future[dict]
self._pending_acks: dict[str, asyncio.Future] = {}
# Per-req_id serial reply queues
self._reply_queues: dict[str, asyncio.Queue] = {}
self._reply_workers: dict[str, asyncio.Task] = {}
self._reply_ack_timeout = 5.0
# Stream ID tracking for WebSocket mode
self._stream_ids: dict[str, str] = {} # msg_id -> req_id|stream_id
# Dedup: skip sending when content hasn't changed
self._stream_last_content: dict[str, str] = {} # msg_id -> last content sent
# ── Public API ──────────────────────────────────────────────────
async def connect(self):
"""Connect to WebSocket server with automatic reconnection.
This method blocks until disconnect() is called or max reconnect
attempts are exhausted.
"""
self._running = True
self._reconnect_attempts = 0
while self._running:
try:
await self._connect_once()
except Exception:
if not self._running:
break
await self.logger.error(f'WebSocket connection error: {traceback.format_exc()}')
if not self._running:
break
# Reconnect with exponential backoff
if self.max_reconnect_attempts != -1 and self._reconnect_attempts >= self.max_reconnect_attempts:
await self.logger.error(f'Max reconnect attempts reached ({self.max_reconnect_attempts}), giving up')
break
self._reconnect_attempts += 1
delay = min(
self.reconnect_base_delay * (2 ** (self._reconnect_attempts - 1)),
self.reconnect_max_delay,
)
await self.logger.info(f'Reconnecting in {delay:.1f}s (attempt {self._reconnect_attempts})...')
await asyncio.sleep(delay)
async def disconnect(self):
"""Gracefully disconnect from the WebSocket server."""
self._running = False
if self._heartbeat_task and not self._heartbeat_task.done():
self._heartbeat_task.cancel()
for task in self._reply_workers.values():
if not task.done():
task.cancel()
if self._ws and not self._ws.closed:
await self._ws.close()
self._ws = None
if self._session and not self._session.closed:
await self._session.close()
self._session = None
def on_message(self, msg_type: str) -> Callable:
"""Decorator to register a message handler.
Same interface as WecomBotClient.on_message for compatibility.
Args:
msg_type: 'single', 'group', or specific message type.
"""
def decorator(func: Callable[[wecombotevent.WecomBotEvent], Any]):
if msg_type not in self._message_handlers:
self._message_handlers[msg_type] = []
self._message_handlers[msg_type].append(func)
return func
return decorator
async def reply_stream(
self,
req_id: str,
stream_id: str,
content: str,
finish: bool = False,
) -> Optional[dict]:
"""Send a streaming reply frame.
Args:
req_id: The req_id from the original message frame (must be passed through).
stream_id: The stream ID for this streaming session.
content: The content to send (supports Markdown).
finish: Whether this is the final chunk.
Returns:
The ACK frame dict, or None on failure.
"""
body = {
'msgtype': 'stream',
'stream': {
'id': stream_id,
'finish': finish,
'content': content,
},
}
return await self._send_reply(req_id, body)
async def reply_text(self, req_id: str, content: str) -> Optional[dict]:
"""Send a non-streaming text reply.
Args:
req_id: The req_id from the original message frame.
content: The text content to reply.
Returns:
The ACK frame dict, or None on failure.
"""
body = {
'msgtype': 'markdown',
'markdown': {
'content': content,
},
}
return await self._send_reply(req_id, body)
async def send_message(self, chat_id: str, content: str, msgtype: str = 'markdown') -> Optional[dict]:
"""Proactively send a message to a specified chat.
Args:
chat_id: The chat ID (userid for single chat, chatid for group chat).
content: The message content.
msgtype: Message type, 'markdown' by default.
Returns:
The ACK frame dict, or None on failure.
"""
req_id = _generate_req_id(CMD_SEND_MSG)
body: dict[str, Any] = {
'chatid': chat_id,
'msgtype': msgtype,
}
if msgtype == 'markdown':
body['markdown'] = {'content': content}
elif msgtype == 'text':
body['text'] = {'content': content}
return await self._send_reply(req_id, body, cmd=CMD_SEND_MSG)
async def push_stream_chunk(self, msg_id: str, content: str, is_final: bool = False) -> bool:
"""Push a streaming chunk for a given message ID.
Compatible interface with WecomBotClient.push_stream_chunk.
Args:
msg_id: The original message ID.
content: The cumulative content from the pipeline.
is_final: Whether this is the final chunk.
Returns:
True if the stream session exists and chunk was sent.
"""
key = self._stream_ids.get(msg_id)
if not key:
return False
req_id, stream_id = key.split('|', 1)
try:
# Skip sending if content hasn't changed (e.g. during tool call argument streaming)
if not is_final and content == self._stream_last_content.get(msg_id):
return True
await self.reply_stream(req_id, stream_id, content, finish=is_final)
self._stream_last_content[msg_id] = content
if is_final:
self._stream_ids.pop(msg_id, None)
self._stream_last_content.pop(msg_id, None)
return True
except Exception:
await self.logger.error(f'Failed to push stream chunk: {traceback.format_exc()}')
return False
async def set_message(self, msg_id: str, content: str):
"""Fallback: send content as a final stream chunk or direct reply.
Compatible interface with WecomBotClient.set_message.
"""
handled = await self.push_stream_chunk(msg_id, content, is_final=True)
if not handled:
await self.logger.warning(f'No active stream for msg_id={msg_id}, message dropped')
# ── Connection lifecycle ────────────────────────────────────────
async def _connect_once(self):
"""Establish a single WebSocket connection, authenticate, and listen."""
await self.logger.info(f'Connecting to {self.ws_url}...')
self._session = aiohttp.ClientSession()
try:
self._ws = await self._session.ws_connect(self.ws_url)
self._missed_pong_count = 0
self._reconnect_attempts = 0
await self.logger.info('WebSocket connected, sending auth...')
await self._send_auth()
# Wait for auth response
auth_ok = await self._wait_for_auth()
if not auth_ok:
await self.logger.error('Authentication failed')
return
await self.logger.info('Authenticated successfully')
# Start heartbeat
self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
try:
await self._listen_loop()
finally:
if self._heartbeat_task and not self._heartbeat_task.done():
self._heartbeat_task.cancel()
self._clear_pending_acks('Connection closed')
finally:
if self._ws and not self._ws.closed:
await self._ws.close()
self._ws = None
if self._session and not self._session.closed:
await self._session.close()
self._session = None
async def _send_auth(self):
"""Send the authentication frame."""
frame = {
'cmd': CMD_SUBSCRIBE,
'headers': {'req_id': _generate_req_id(CMD_SUBSCRIBE)},
'body': {
'bot_id': self.bot_id,
'secret': self.secret,
},
}
await self._send_frame(frame)
async def _wait_for_auth(self) -> bool:
"""Wait for and validate the authentication response."""
try:
msg = await asyncio.wait_for(self._ws.receive(), timeout=10.0)
if msg.type in (aiohttp.WSMsgType.TEXT,):
frame = json.loads(msg.data)
req_id = frame.get('headers', {}).get('req_id', '')
if req_id.startswith(CMD_SUBSCRIBE) and frame.get('errcode') == 0:
return True
await self.logger.error(f'Auth response: errcode={frame.get("errcode")}, errmsg={frame.get("errmsg")}')
return False
elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
await self.logger.error(f'WebSocket closed during auth: {msg.type}')
return False
await self.logger.error(f'Unexpected message type during auth: {msg.type}')
return False
except asyncio.TimeoutError:
await self.logger.error('Auth response timeout')
return False
async def _heartbeat_loop(self):
"""Periodically send heartbeat pings."""
try:
while self._running and self._ws and not self._ws.closed:
await asyncio.sleep(self.heartbeat_interval)
if not self._running or not self._ws or self._ws.closed:
break
if self._missed_pong_count >= self._max_missed_pong:
await self.logger.warning(
f'No heartbeat ack for {self._missed_pong_count} consecutive pings, connection considered dead'
)
await self._ws.close()
break
self._missed_pong_count += 1
frame = {
'cmd': CMD_HEARTBEAT,
'headers': {'req_id': _generate_req_id(CMD_HEARTBEAT)},
}
try:
await self._send_frame(frame)
except Exception:
break
except asyncio.CancelledError:
pass
async def _listen_loop(self):
"""Listen for incoming WebSocket frames and dispatch them."""
async for msg in self._ws:
if not self._running:
break
if msg.type == aiohttp.WSMsgType.TEXT:
try:
frame = json.loads(msg.data)
await self._handle_frame(frame)
except json.JSONDecodeError:
await self.logger.error(f'Failed to parse WebSocket message: {str(msg.data)[:200]}')
except Exception:
await self.logger.error(f'Error handling frame: {traceback.format_exc()}')
elif msg.type == aiohttp.WSMsgType.BINARY:
try:
frame = json.loads(msg.data)
await self._handle_frame(frame)
except Exception:
await self.logger.error(f'Error handling binary frame: {traceback.format_exc()}')
elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
await self.logger.warning(f'WebSocket connection closed: {msg.type}')
break
# ── Frame handling ──────────────────────────────────────────────
async def _handle_frame(self, frame: dict):
"""Route an incoming frame to the appropriate handler."""
cmd = frame.get('cmd', '')
# Message push
if cmd == CMD_MSG_CALLBACK:
asyncio.create_task(self._handle_message_callback(frame))
return
# Event push
if cmd == CMD_EVENT_CALLBACK:
asyncio.create_task(self._handle_event_callback(frame))
return
# No cmd → response/ACK frame, dispatch by req_id prefix
req_id = frame.get('headers', {}).get('req_id', '')
# Check pending ACKs first
if req_id in self._pending_acks:
future = self._pending_acks.pop(req_id)
if not future.done():
future.set_result(frame)
return
# Heartbeat response
if req_id.startswith(CMD_HEARTBEAT):
if frame.get('errcode') == 0:
self._missed_pong_count = 0
return
# Unknown frame
await self.logger.warning(f'Unknown frame: {json.dumps(frame, ensure_ascii=False)[:200]}')
async def _handle_message_callback(self, frame: dict):
"""Handle an incoming message callback frame."""
try:
body = frame.get('body', {})
req_id = frame.get('headers', {}).get('req_id', '')
# Parse message using shared logic
message_data = await parse_wecom_bot_message(body, self.encoding_aes_key, self.logger)
if not message_data:
return
# Generate stream_id for this message and store the mapping
stream_id = _generate_req_id('stream')
msg_id = message_data.get('msgid', '')
if msg_id:
self._stream_ids[msg_id] = f'{req_id}|{stream_id}'
message_data['stream_id'] = stream_id
message_data['req_id'] = req_id
event = wecombotevent.WecomBotEvent(message_data)
await self._dispatch_event(event)
except Exception:
await self.logger.error(f'Error in message callback: {traceback.format_exc()}')
async def _handle_event_callback(self, frame: dict):
"""Handle an incoming event callback frame (enter_chat, template_card_event, etc.)."""
try:
body = frame.get('body', {})
req_id = frame.get('headers', {}).get('req_id', '')
event_info = body.get('event', {})
event_type = event_info.get('eventtype', '')
message_data = {
'msgtype': 'event',
'type': body.get('chattype', 'single'),
'event': event_info,
'eventtype': event_type,
'msgid': body.get('msgid', ''),
'aibotid': body.get('aibotid', ''),
'req_id': req_id,
}
from_info = body.get('from', {})
message_data['userid'] = from_info.get('userid', '')
message_data['username'] = from_info.get('alias', '') or from_info.get('userid', '')
if body.get('chatid'):
message_data['chatid'] = body.get('chatid', '')
event = wecombotevent.WecomBotEvent(message_data)
# Dispatch to event-specific handlers
if event_type in self._message_handlers:
for handler in self._message_handlers[event_type]:
await handler(event)
# Also dispatch to generic 'event' handlers
if 'event' in self._message_handlers:
for handler in self._message_handlers['event']:
await handler(event)
except Exception:
await self.logger.error(f'Error in event callback: {traceback.format_exc()}')
async def _dispatch_event(self, event: wecombotevent.WecomBotEvent):
"""Dispatch a message event to registered handlers with deduplication."""
try:
message_id = event.message_id
if message_id in self._msg_id_map:
self._msg_id_map[message_id] += 1
return
self._msg_id_map[message_id] = 1
msg_type = event.type
if msg_type in self._message_handlers:
for handler in self._message_handlers[msg_type]:
await handler(event)
except Exception:
await self.logger.error(f'Error dispatching event: {traceback.format_exc()}')
# ── Reply sending with serial queue ─────────────────────────────
async def _send_reply(
self,
req_id: str,
body: dict,
cmd: str = CMD_RESPOND_MSG,
) -> Optional[dict]:
"""Send a reply frame and wait for ACK.
Replies with the same req_id are serialized to maintain ordering.
"""
if not self._ws or self._ws.closed:
return None
frame = {
'cmd': cmd,
'headers': {'req_id': req_id},
'body': body,
}
# Ensure serial delivery per req_id
if req_id not in self._reply_queues:
self._reply_queues[req_id] = asyncio.Queue()
self._reply_workers[req_id] = asyncio.create_task(self._reply_queue_worker(req_id))
future: asyncio.Future = asyncio.get_event_loop().create_future()
await self._reply_queues[req_id].put((frame, future))
return await future
async def _reply_queue_worker(self, req_id: str):
"""Process reply queue items serially for a given req_id."""
queue = self._reply_queues[req_id]
try:
while self._running:
try:
frame, future = await asyncio.wait_for(queue.get(), timeout=60.0)
except asyncio.TimeoutError:
# Queue idle, clean up worker
break
try:
ack = await self._send_and_wait_ack(frame)
if not future.done():
future.set_result(ack)
except Exception as e:
if not future.done():
future.set_exception(e)
except asyncio.CancelledError:
pass
finally:
self._reply_queues.pop(req_id, None)
self._reply_workers.pop(req_id, None)
async def _send_and_wait_ack(self, frame: dict) -> Optional[dict]:
"""Send a frame and wait for the corresponding ACK."""
req_id = frame['headers']['req_id']
ack_future: asyncio.Future = asyncio.get_event_loop().create_future()
self._pending_acks[req_id] = ack_future
try:
await self._send_frame(frame)
result = await asyncio.wait_for(ack_future, timeout=self._reply_ack_timeout)
if result.get('errcode', 0) != 0:
await self.logger.warning(
f'Reply ACK error: errcode={result.get("errcode")}, errmsg={result.get("errmsg")}'
)
return result
except asyncio.TimeoutError:
self._pending_acks.pop(req_id, None)
await self.logger.warning(f'Reply ACK timeout ({self._reply_ack_timeout}s) for req_id={req_id}')
return None
async def _send_frame(self, frame: dict):
"""Send a JSON frame over the WebSocket connection."""
if self._ws and not self._ws.closed:
await self._ws.send_str(json.dumps(frame, ensure_ascii=False))
def _clear_pending_acks(self, reason: str):
"""Reject all pending ACK futures on disconnection."""
for req_id, future in self._pending_acks.items():
if not future.done():
future.set_exception(ConnectionError(reason))
self._pending_acks.clear()

View File

@@ -70,17 +70,12 @@ class BotService:
'lark',
]:
webhook_prefix = self.ap.instance_config.data['api'].get('webhook_prefix', 'http://127.0.0.1:5300')
extra_webhook_prefix = self.ap.instance_config.data['api'].get('extra_webhook_prefix', '')
webhook_url = f'/bots/{bot_uuid}'
adapter_runtime_values['webhook_url'] = webhook_url
adapter_runtime_values['webhook_full_url'] = f'{webhook_prefix}{webhook_url}'
adapter_runtime_values['extra_webhook_full_url'] = (
f'{extra_webhook_prefix}{webhook_url}' if extra_webhook_prefix else ''
)
else:
adapter_runtime_values['webhook_url'] = None
adapter_runtime_values['webhook_full_url'] = None
adapter_runtime_values['extra_webhook_full_url'] = None
persistence_bot['adapter_runtime_values'] = adapter_runtime_values

View File

@@ -105,16 +105,11 @@ class LLMModelsService:
)
)
pipeline = result.first()
if pipeline is not None:
model_config = pipeline.config.get('ai', {}).get('local-agent', {}).get('model', {})
if not model_config.get('primary', ''):
pipeline_config = pipeline.config
pipeline_config['ai']['local-agent']['model'] = {
'primary': model_data['uuid'],
'fallbacks': [],
}
pipeline_data = {'config': pipeline_config}
await self.ap.pipeline_service.update_pipeline(pipeline.uuid, pipeline_data)
if pipeline is not None and pipeline.config['ai']['local-agent']['model'] == '':
pipeline_config = pipeline.config
pipeline_config['ai']['local-agent']['model'] = model_data['uuid']
pipeline_data = {'config': pipeline_config}
await self.ap.pipeline_service.update_pipeline(pipeline.uuid, pipeline_data)
return model_data['uuid']

View File

@@ -1,49 +0,0 @@
from .. import migration
import sqlalchemy
import json
@migration.migration_class(24)
class DBMigrateWecomBotWebSocketMode(migration.DBMigration):
"""Add enable-webhook field to existing wecombot adapter configs.
Existing wecombot bots were all using webhook mode, so we set
enable-webhook=true to preserve their behavior after the new
WebSocket long connection mode is introduced as default.
"""
async def upgrade(self):
"""Upgrade"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.text("SELECT uuid, adapter_config FROM bots WHERE adapter = 'wecombot'")
)
bots = result.fetchall()
for bot_row in bots:
bot_uuid = bot_row[0]
adapter_config = json.loads(bot_row[1]) if isinstance(bot_row[1], str) else bot_row[1]
if 'enable-webhook' in adapter_config:
continue
# Determine mode based on existing config: if webhook fields are present, keep webhook mode
has_webhook_config = bool(
adapter_config.get('Token') and adapter_config.get('EncodingAESKey') and adapter_config.get('Corpid')
)
adapter_config['enable-webhook'] = has_webhook_config
if self.ap.persistence_mgr.db.name == 'postgresql':
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text('UPDATE bots SET adapter_config = :config::jsonb WHERE uuid = :uuid'),
{'config': json.dumps(adapter_config), 'uuid': bot_uuid},
)
else:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text('UPDATE bots SET adapter_config = :config WHERE uuid = :uuid'),
{'config': json.dumps(adapter_config), 'uuid': bot_uuid},
)
async def downgrade(self):
"""Downgrade"""
pass

View File

@@ -575,127 +575,6 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
class LarkEventConverter(abstract_platform_adapter.AbstractEventConverter):
_processed_thread_quote_cache: typing.ClassVar[dict[str, float]] = {}
_processed_thread_quote_cache_max_size: typing.ClassVar[int] = 4096
_processed_thread_quote_cache_ttl_seconds: typing.ClassVar[int] = 86400
@classmethod
def _prune_processed_thread_quote_cache(cls, now: typing.Optional[float] = None) -> None:
if now is None:
now = time.time()
expire_before = now - cls._processed_thread_quote_cache_ttl_seconds
while cls._processed_thread_quote_cache:
oldest_key, oldest_ts = next(iter(cls._processed_thread_quote_cache.items()))
if oldest_ts >= expire_before:
break
cls._processed_thread_quote_cache.pop(oldest_key, None)
while len(cls._processed_thread_quote_cache) > cls._processed_thread_quote_cache_max_size:
oldest_key = next(iter(cls._processed_thread_quote_cache))
cls._processed_thread_quote_cache.pop(oldest_key, None)
@classmethod
def _mark_thread_quote_processed(cls, thread_id: str) -> None:
now = time.time()
cls._prune_processed_thread_quote_cache(now)
cls._processed_thread_quote_cache[thread_id] = now
@classmethod
def _extract_quote_message_id(cls, message: EventMessage) -> typing.Optional[str]:
"""
Extract the message ID to quote from the given message.
Rules:
- First thread reply in a topic: return parent_id and mark topic as processed
- Follow-up thread replies in the same topic: return None
- Non-thread message: return parent_id if valid (non-empty, different from message_id)
Thread reply state is kept in a bounded TTL cache to avoid unbounded memory growth.
"""
parent_id = getattr(message, 'parent_id', None)
if not parent_id:
return None
message_id = getattr(message, 'message_id', None)
if parent_id == message_id:
return None
thread_id = getattr(message, 'thread_id', None)
if thread_id:
cls._prune_processed_thread_quote_cache()
if thread_id in cls._processed_thread_quote_cache:
return None
cls._mark_thread_quote_processed(thread_id)
return parent_id
@staticmethod
def _build_event_message_from_message_item(message_item: Message) -> typing.Optional[EventMessage]:
"""
Build EventMessage from SDK typed Message item.
Returns None if body or content is missing.
"""
body = getattr(message_item, 'body', None)
if not body:
return None
content = getattr(body, 'content', None)
if not content:
return None
event_data = {
'message_id': message_item.message_id,
'message_type': message_item.msg_type,
'content': content,
'create_time': message_item.create_time,
'mentions': getattr(message_item, 'mentions', []) or [],
}
# Preserve thread-related fields
if hasattr(message_item, 'parent_id') and message_item.parent_id:
event_data['parent_id'] = message_item.parent_id
if hasattr(message_item, 'root_id') and message_item.root_id:
event_data['root_id'] = message_item.root_id
if hasattr(message_item, 'thread_id') and message_item.thread_id:
event_data['thread_id'] = message_item.thread_id
if hasattr(message_item, 'chat_id') and message_item.chat_id:
event_data['chat_id'] = message_item.chat_id
return EventMessage(event_data)
@staticmethod
async def _fetch_quoted_message(
quote_message_id: str,
api_client: lark_oapi.Client,
) -> typing.Optional[platform_message.MessageChain]:
"""
Fetch the quoted message and convert to MessageChain.
Returns None if:
- API call fails
- Response items is empty
- Message item normalization fails
"""
request = GetMessageRequest.builder().message_id(quote_message_id).build()
response = await api_client.im.v1.message.aget(request)
if not response.success():
return None
items = getattr(response.data, 'items', None)
if not items:
return None
message_item = items[0]
event_message = LarkEventConverter._build_event_message_from_message_item(message_item)
if event_message is None:
return None
quote_chain = await LarkMessageConverter.target2yiri(event_message, api_client)
return quote_chain
@staticmethod
async def yiri2target(
event: platform_events.MessageEvent,
@@ -708,23 +587,6 @@ class LarkEventConverter(abstract_platform_adapter.AbstractEventConverter):
) -> platform_events.Event:
message_chain = await LarkMessageConverter.target2yiri(event.event.message, api_client)
# Check for quote/reply message
quote_message_id = LarkEventConverter._extract_quote_message_id(event.event.message)
if quote_message_id:
quote_chain = await LarkEventConverter._fetch_quoted_message(quote_message_id, api_client)
if quote_chain:
# Filter out Source component from quoted chain, keep only content
quote_origin = platform_message.MessageChain(
[comp for comp in quote_chain if not isinstance(comp, platform_message.Source)]
)
if quote_origin:
message_chain.append(
platform_message.Quote(
message_id=quote_message_id,
origin=quote_origin,
)
)
if event.event.message.chat_type == 'p2p':
return platform_events.FriendMessage(
sender=platform_entities.Friend(
@@ -908,32 +770,6 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
self.request_tenant_access_token(tenant_key)
return self.tenant_access_tokens.get(tenant_key)['token'] if self.tenant_access_tokens.get(tenant_key) else None
def get_launcher_id(self, event: platform_events.MessageEvent) -> str | None:
"""
Get topic-scoped launcher_id for thread-aware session isolation.
For group thread messages, returns "{group_id}_{thread_id}"
to ensure conversation context stays stable per topic.
Returns None for non-thread messages or P2P messages.
"""
source_event = getattr(event.source_platform_object, 'event', None)
if not source_event:
return None
message = getattr(source_event, 'message', None)
if not message:
return None
thread_id = getattr(message, 'thread_id', None)
if not thread_id:
return None
if isinstance(event, platform_events.GroupMessage):
return f'{event.group.id}_{thread_id}'
return None
def build_api_client(self, config):
app_id = config['app_id']
app_secret = config['app_secret']

View File

@@ -11,7 +11,6 @@ import langbot_plugin.api.entities.builtin.platform.entities as platform_entitie
from ..logger import EventLogger
from langbot.libs.wecom_ai_bot_api.wecombotevent import WecomBotEvent
from langbot.libs.wecom_ai_bot_api.api import WecomBotClient
from langbot.libs.wecom_ai_bot_api.ws_client import WecomBotWsClient
class WecomBotMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
@@ -177,42 +176,27 @@ class WecomBotEventConverter(abstract_platform_adapter.AbstractEventConverter):
class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
bot: typing.Union[WecomBotClient, WecomBotWsClient]
bot: WecomBotClient
bot_account_id: str
message_converter: WecomBotMessageConverter = WecomBotMessageConverter()
event_converter: WecomBotEventConverter = WecomBotEventConverter()
config: dict
bot_uuid: str = None
_ws_mode: bool = False
def __init__(self, config: dict, logger: EventLogger):
enable_webhook = config.get('enable-webhook', False)
required_keys = ['Token', 'EncodingAESKey', 'Corpid', 'BotId']
missing_keys = [key for key in required_keys if key not in config]
if missing_keys:
raise Exception(f'WecomBot 缺少配置项: {missing_keys}')
if not enable_webhook:
bot = WecomBotWsClient(
bot_id=config['BotId'],
secret=config['Secret'],
logger=logger,
encoding_aes_key=config.get('EncodingAESKey', ''),
)
ws_mode = True
else:
# Webhook callback mode
required_keys = ['Token', 'EncodingAESKey', 'Corpid']
missing_keys = [key for key in required_keys if key not in config or not config[key]]
if missing_keys:
raise Exception(f'WecomBot webhook mode missing config: {missing_keys}')
bot = WecomBotClient(
Token=config['Token'],
EnCodingAESKey=config['EncodingAESKey'],
Corpid=config['Corpid'],
logger=logger,
unified_mode=True,
)
ws_mode = False
bot_account_id = config.get('BotId', '')
bot = WecomBotClient(
Token=config['Token'],
EnCodingAESKey=config['EncodingAESKey'],
Corpid=config['Corpid'],
logger=logger,
unified_mode=True,
)
bot_account_id = config['BotId']
super().__init__(
config=config,
@@ -220,7 +204,6 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
bot=bot,
bot_account_id=bot_account_id,
)
self._ws_mode = ws_mode
async def reply_message(
self,
@@ -229,15 +212,7 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
quote_origin: bool = False,
):
content = await self.message_converter.yiri2target(message)
if self._ws_mode:
event = message_source.source_platform_object
req_id = event.get('req_id', '')
if req_id:
await self.bot.reply_text(req_id, content)
else:
await self.bot.set_message(event.message_id, content)
else:
await self.bot.set_message(message_source.source_platform_object.message_id, content)
await self.bot.set_message(message_source.source_platform_object.message_id, content)
async def reply_message_chunk(
self,
@@ -247,22 +222,31 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
quote_origin: bool = False,
is_final: bool = False,
):
"""将流水线增量输出写入企业微信 stream 会话。
Args:
message_source: 流水线提供的原始消息事件。
bot_message: 当前片段对应的模型元信息(未使用)。
message: 需要回复的消息链。
quote_origin: 是否引用原消息(企业微信暂不支持)。
is_final: 标记当前片段是否为最终回复。
Returns:
dict: 包含 `stream` 键,标识写入是否成功。
Example:
在流水线 `reply_message_chunk` 调用中自动触发,无需手动调用。
"""
# 转换为纯文本(智能机器人当前协议仅支持文本流)
content = await self.message_converter.yiri2target(message)
msg_id = message_source.source_platform_object.message_id
if self._ws_mode:
success = await self.bot.push_stream_chunk(msg_id, content, is_final=is_final)
if not success and is_final:
event = message_source.source_platform_object
req_id = event.get('req_id', '')
if req_id:
await self.bot.reply_text(req_id, content)
return {'stream': success}
else:
success = await self.bot.push_stream_chunk(msg_id, content, is_final=is_final)
if not success and is_final:
await self.bot.set_message(msg_id, content)
return {'stream': success}
# 将片段推送到 WecomBotClient 中的队列,返回值用于判断是否走降级逻辑
success = await self.bot.push_stream_chunk(msg_id, content, is_final=is_final)
if not success and is_final:
# 未命中流式队列时使用旧有 set_message 兜底
await self.bot.set_message(msg_id, content)
return {'stream': success}
async def is_stream_output_supported(self) -> bool:
"""智能机器人侧默认开启流式能力。
@@ -275,11 +259,7 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
return True
async def send_message(self, target_type, target_id, message):
if self._ws_mode:
content = await self.message_converter.yiri2target(message)
await self.bot.send_message(target_id, content)
else:
pass
pass
def register_listener(
self,
@@ -308,25 +288,29 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
self.bot_uuid = bot_uuid
async def handle_unified_webhook(self, bot_uuid: str, path: str, request):
if self._ws_mode:
return None
"""处理统一 webhook 请求。
Args:
bot_uuid: Bot 的 UUID
path: 子路径(如果有的话)
request: Quart Request 对象
Returns:
响应数据
"""
return await self.bot.handle_unified_webhook(request)
async def run_async(self):
if self._ws_mode:
await self.bot.connect()
else:
# 统一 webhook 模式下,不启动独立的 Quart 应用
# 保持运行但不启动独立端口
async def keep_alive():
while True:
await asyncio.sleep(1)
async def keep_alive():
while True:
await asyncio.sleep(1)
await keep_alive()
await keep_alive()
async def kill(self) -> bool:
if self._ws_mode:
await self.bot.disconnect()
return True
return False
async def unregister_listener(

View File

@@ -11,64 +11,35 @@ metadata:
icon: wecombot.png
spec:
config:
- name: BotId
label:
en_US: BotId
zh_Hans: 机器人ID (BotId)
type: string
required: true
default: ""
- name: enable-webhook
label:
en_US: Enable Webhook Mode
zh_Hans: 启用Webhook模式
description:
en_US: If enabled, the bot will use webhook mode to receive messages. Otherwise, it will use WS long connection mode
zh_Hans: 如果启用,机器人将使用 Webhook 模式接收消息。否则,将使用 WS 长连接模式
type: boolean
required: true
default: false
- name: Secret
label:
en_US: Secret
zh_Hans: 机器人密钥 (Secret)
description:
en_US: Required for WebSocket long connection mode
zh_Hans: 使用 WS 长连接模式时必填
type: string
required: false
default: ""
- name: Corpid
label:
en_US: Corpid
zh_Hans: 企业ID
description:
en_US: Required for Webhook mode
zh_Hans: 使用 Webhook 模式时必填
type: string
required: false
required: true
default: ""
- name: Token
label:
en_US: Token
zh_Hans: 令牌 (Token)
description:
en_US: Required for Webhook mode
zh_Hans: 使用 Webhook 模式时必填
type: string
required: false
required: true
default: ""
- name: EncodingAESKey
label:
en_US: EncodingAESKey
zh_Hans: 消息加解密密钥 (EncodingAESKey)
description:
en_US: Required for Webhook mode. Optional for WebSocket mode (used for file decryption)
zh_Hans: 使用 Webhook 模式时必填。WebSocket 模式下可选(用于文件解密)
type: string
required: true
default: ""
- name: BotId
label:
en_US: BotId
zh_Hans: 机器人ID
type: string
required: false
default: ""
execution:
python:
path: ./wecombot.py
attr: WecomBotAdapter
attr: WecomBotAdapter

View File

@@ -565,16 +565,6 @@ class RuntimeConnectionHandler(handler.Handler):
except Exception as e:
return _make_rag_error_response(e, 'FileServiceError', storage_path=storage_path)
@self.action(PluginToRuntimeAction.LIST_PARSERS)
async def list_parsers(data: dict[str, Any]) -> handler.ActionResponse:
"""Plugin requests host to list available parser plugins."""
mime_type = data.get('mime_type')
try:
parsers = await self.ap.knowledge_service.list_parsers(mime_type)
return handler.ActionResponse.success(data={'parsers': parsers})
except Exception as e:
return _make_rag_error_response(e, 'ParserDiscoveryError', mime_type=mime_type)
@self.action(PluginToRuntimeAction.INVOKE_PARSER)
async def invoke_parser(data: dict[str, Any]) -> handler.ActionResponse:
"""Plugin requests host to invoke a parser plugin."""
@@ -599,94 +589,6 @@ class RuntimeConnectionHandler(handler.Handler):
except Exception as e:
return _make_rag_error_response(e, 'ParserError')
# ================= Knowledge Base Query APIs =================
@self.action(PluginToRuntimeAction.LIST_PIPELINE_KNOWLEDGE_BASES)
async def list_pipeline_knowledge_bases(data: dict[str, Any]) -> handler.ActionResponse:
"""List knowledge bases configured for the current query's pipeline."""
query_id = data['query_id']
if query_id not in self.ap.query_pool.cached_queries:
return handler.ActionResponse.error(
message=f'Query with query_id {query_id} not found',
)
query = self.ap.query_pool.cached_queries[query_id]
kb_uuids = []
if query.pipeline_config:
local_agent_config = query.pipeline_config.get('ai', {}).get('local-agent', {})
kb_uuids = local_agent_config.get('knowledge-bases', [])
# Backward compatibility
if not kb_uuids:
old_kb_uuid = local_agent_config.get('knowledge-base', '')
if old_kb_uuid and old_kb_uuid != '__none__':
kb_uuids = [old_kb_uuid]
knowledge_bases = []
for kb_uuid in kb_uuids:
kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
if kb:
knowledge_bases.append(
{
'uuid': kb.get_uuid(),
'name': kb.get_name(),
'description': kb.knowledge_base_entity.description or '',
}
)
return handler.ActionResponse.success(data={'knowledge_bases': knowledge_bases})
@self.action(PluginToRuntimeAction.RETRIEVE_KNOWLEDGE_BASE)
async def retrieve_knowledge_base(data: dict[str, Any]) -> handler.ActionResponse:
"""Retrieve documents from a knowledge base within the pipeline's scope."""
query_id = data['query_id']
kb_id = data['kb_id']
query_text = data['query_text']
top_k = data.get('top_k', 5)
filters = data.get('filters', {})
if query_id not in self.ap.query_pool.cached_queries:
return handler.ActionResponse.error(
message=f'Query with query_id {query_id} not found',
)
query = self.ap.query_pool.cached_queries[query_id]
# Validate kb_id is in pipeline's allowed list
allowed_kb_uuids = []
if query.pipeline_config:
local_agent_config = query.pipeline_config.get('ai', {}).get('local-agent', {})
allowed_kb_uuids = local_agent_config.get('knowledge-bases', [])
if not allowed_kb_uuids:
old_kb_uuid = local_agent_config.get('knowledge-base', '')
if old_kb_uuid and old_kb_uuid != '__none__':
allowed_kb_uuids = [old_kb_uuid]
if kb_id not in allowed_kb_uuids:
return handler.ActionResponse.error(
message=f'Knowledge base {kb_id} is not configured for this pipeline',
)
kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_id)
if not kb:
return handler.ActionResponse.error(
message=f'Knowledge base {kb_id} not found',
)
try:
entries = await kb.retrieve(
query_text,
settings={
'top_k': top_k,
'filters': filters,
},
)
results = [entry.model_dump(mode='json') for entry in entries]
return handler.ActionResponse.success(data={'results': results})
except Exception as e:
return _make_rag_error_response(e, 'RetrievalError', kb_id=kb_id)
@self.action(CommonAction.PING)
async def ping(data: dict[str, Any]) -> handler.ActionResponse:
"""Ping"""
@@ -993,7 +895,7 @@ class RuntimeConnectionHandler(handler.Handler):
result = await self.call_action(
LangBotToRuntimeAction.RAG_INGEST_DOCUMENT,
{'plugin_author': plugin_author, 'plugin_name': plugin_name, 'context': context_data},
timeout=1200, # Ingestion can be slow for large documents
timeout=300, # Ingestion can be slow
)
return result

View File

@@ -132,6 +132,12 @@ class LocalAgentRunner(runner.RequestRunner):
"""Run request"""
pending_tool_calls = []
# Agent loop protection config
agent_config = query.pipeline_config['ai']['local-agent']
max_tool_iterations = agent_config.get('max-tool-iterations', 16)
max_tool_result_chars = agent_config.get('max-tool-result-chars', 8000)
iteration_count = 0
# Get knowledge bases list (new field)
kb_uuids = query.pipeline_config['ai']['local-agent'].get('knowledge-bases', [])
@@ -168,7 +174,6 @@ class LocalAgentRunner(runner.RequestRunner):
result = await kb.retrieve(
user_message_text,
settings={
'bot_uuid': query.bot_uuid or '',
'sender_id': str(query.sender_id),
'session_name': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
},
@@ -296,6 +301,14 @@ class LocalAgentRunner(runner.RequestRunner):
# Once a model succeeds, commit to it for the tool call loop
# (no fallback mid-conversation — different models may interpret tool results differently)
while pending_tool_calls:
iteration_count += 1
if iteration_count > max_tool_iterations:
self.ap.logger.warning(
f'localagent: query={query.query_id} agent loop exceeded max iterations ({max_tool_iterations}), '
f'forcing termination'
)
break
for tool_call in pending_tool_calls:
try:
func = tool_call.function
@@ -318,6 +331,14 @@ class LocalAgentRunner(runner.RequestRunner):
else:
tool_content = json.dumps(func_ret, ensure_ascii=False)
# Truncate oversized tool results to prevent context overflow
if isinstance(tool_content, str) and len(tool_content) > max_tool_result_chars:
self.ap.logger.warning(
f'localagent: tool {func.name} returned {len(tool_content)} chars, '
f'truncating to {max_tool_result_chars}'
)
tool_content = tool_content[:max_tool_result_chars] + '\n...[result truncated]'
if is_stream:
msg = provider_message.MessageChunk(
role='tool',

View File

@@ -2,7 +2,7 @@ import langbot
semantic_version = f'v{langbot.__version__}'
required_database_version = 24
required_database_version = 23
"""Tag the version of the database schema, used to check if the database needs to be migrated"""
debug_mode = False

View File

@@ -100,7 +100,7 @@ class VectorDBManager:
) -> list[dict]:
"""Proxy: Search vectors.
Returns a list of dicts with keys: 'id', 'distance', 'metadata'.
Returns a list of dicts with keys: 'id', 'score', 'metadata'.
The underlying VectorDatabase.search returns Chroma-style format:
{ 'ids': [['id1']], 'distances': [[0.1]], 'metadatas': [[{}]] }
"""
@@ -130,7 +130,7 @@ class VectorDBManager:
parsed_results.append(
{
'id': id_val,
'distance': r_dists[i] if r_dists and i < len(r_dists) else 0.0,
'score': r_dists[i] if r_dists and i < len(r_dists) else 0.0,
'metadata': r_metas[i] if r_metas and i < len(r_metas) else {},
}
)

View File

@@ -2,7 +2,6 @@ admins: []
api:
port: 5300
webhook_prefix: 'http://127.0.0.1:5300'
extra_webhook_prefix: ''
command:
enable: true
prefix:

View File

@@ -41,10 +41,7 @@
"runner": "local-agent"
},
"local-agent": {
"model": {
"primary": "",
"fallbacks": []
},
"model": "",
"max-round": 10,
"prompt": [
{

View File

@@ -93,6 +93,26 @@ stages:
type: knowledge-base-multi-selector
required: false
default: []
- name: max-tool-iterations
label:
en_US: Max Tool Iterations
zh_Hans: 最大工具调用轮次
description:
en_US: Maximum number of tool call iterations in a single agent loop to prevent runaway loops
zh_Hans: 单次 Agent 循环中工具调用的最大轮次,防止无限循环
type: integer
required: false
default: 16
- name: max-tool-result-chars
label:
en_US: Max Tool Result Length
zh_Hans: 工具返回最大字符数
description:
en_US: Maximum character length of a single tool call result, longer results will be truncated
zh_Hans: 单次工具调用返回结果的最大字符数,超出部分将被截断
type: integer
required: false
default: 8000
- name: tbox-app-api
label:
en_US: Tbox App API

View File

@@ -91,15 +91,14 @@ class TestWebhookDisplayPrefix:
def test_default_webhook_prefix(self):
"""Test that the default webhook display prefix is correctly set"""
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300', 'extra_webhook_prefix': ''}}
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300'}}
# Should have the default value
assert cfg['api']['webhook_prefix'] == 'http://127.0.0.1:5300'
assert cfg['api']['extra_webhook_prefix'] == ''
def test_webhook_prefix_env_override(self):
"""Test overriding webhook_prefix via environment variable"""
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300', 'extra_webhook_prefix': ''}}
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300'}}
# Set environment variable
os.environ['API__WEBHOOK_PREFIX'] = 'https://example.com:8080'
@@ -113,7 +112,7 @@ class TestWebhookDisplayPrefix:
def test_webhook_prefix_with_custom_domain(self):
"""Test webhook_prefix with custom domain"""
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300', 'extra_webhook_prefix': ''}}
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300'}}
# Set to a custom domain
os.environ['API__WEBHOOK_PREFIX'] = 'https://bot.mycompany.com'
@@ -127,7 +126,7 @@ class TestWebhookDisplayPrefix:
def test_webhook_prefix_with_subdirectory(self):
"""Test webhook_prefix with subdirectory path"""
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300', 'extra_webhook_prefix': ''}}
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300'}}
# Set to a URL with subdirectory
os.environ['API__WEBHOOK_PREFIX'] = 'https://example.com/langbot'
@@ -139,37 +138,6 @@ class TestWebhookDisplayPrefix:
# Cleanup
del os.environ['API__WEBHOOK_PREFIX']
def test_extra_webhook_prefix_default_empty(self):
"""Test that extra_webhook_prefix defaults to empty string"""
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300', 'extra_webhook_prefix': ''}}
bot_uuid = 'test-bot-uuid'
webhook_prefix = cfg['api'].get('webhook_prefix', 'http://127.0.0.1:5300')
extra_webhook_prefix = cfg['api'].get('extra_webhook_prefix', '')
webhook_url = f'/bots/{bot_uuid}'
assert f'{webhook_prefix}{webhook_url}' == 'http://127.0.0.1:5300/bots/test-bot-uuid'
# extra should be empty when not configured
assert extra_webhook_prefix == ''
def test_extra_webhook_prefix_env_override(self):
"""Test overriding extra_webhook_prefix via environment variable"""
cfg = {'api': {'port': 5300, 'webhook_prefix': 'http://127.0.0.1:5300', 'extra_webhook_prefix': ''}}
os.environ['API__EXTRA_WEBHOOK_PREFIX'] = 'https://extra.example.com'
result = _apply_env_overrides_to_config(cfg)
assert result['api']['extra_webhook_prefix'] == 'https://extra.example.com'
bot_uuid = 'test-bot-uuid'
extra_prefix = result['api']['extra_webhook_prefix']
webhook_url = f'/bots/{bot_uuid}'
assert f'{extra_prefix}{webhook_url}' == 'https://extra.example.com/bots/test-bot-uuid'
# Cleanup
del os.environ['API__EXTRA_WEBHOOK_PREFIX']
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

@@ -194,7 +194,7 @@ def sample_query(sample_message_chain, sample_message_event, mock_adapter):
pipeline_config={
'ai': {
'runner': {'runner': 'local-agent'},
'local-agent': {'model': {'primary': 'test-model-uuid', 'fallbacks': []}, 'prompt': 'test-prompt'},
'local-agent': {'model': 'test-model-uuid', 'prompt': 'test-prompt'},
},
'output': {'misc': {'at-sender': False, 'quote-origin': False}},
'trigger': {'misc': {'combine-quote-message': False}},
@@ -219,7 +219,7 @@ def sample_pipeline_config():
return {
'ai': {
'runner': {'runner': 'local-agent'},
'local-agent': {'model': {'primary': 'test-model-uuid', 'fallbacks': []}, 'prompt': 'test-prompt'},
'local-agent': {'model': 'test-model-uuid', 'prompt': 'test-prompt'},
},
'output': {'misc': {'at-sender': False, 'quote-origin': False}},
'trigger': {'misc': {'combine-quote-message': False}},

570
uv.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -102,10 +102,5 @@
"typescript": "^5.8.3",
"typescript-eslint": "^8.31.1"
},
"packageManager": "pnpm@8.9.2+sha512.b9d35fe91b2a5854dadc43034a3e7b2e675fa4b56e20e8e09ef078fa553c18f8aed44051e7b36e8b8dd435f97eb0c44c4ff3b44fc7c6fa7d21e1fac17bbe661e",
"pnpm": {
"overrides": {
"minimatch": "3.1.3"
}
}
}
"packageManager": "pnpm@8.9.2+sha512.b9d35fe91b2a5854dadc43034a3e7b2e675fa4b56e20e8e09ef078fa553c18f8aed44051e7b36e8b8dd435f97eb0c44c4ff3b44fc7c6fa7d21e1fac17bbe661e"
}

34
web/pnpm-lock.yaml generated
View File

@@ -4,9 +4,6 @@ settings:
autoInstallPeers: true
excludeLinksFromLockfile: false
overrides:
minimatch: 3.1.3
dependencies:
'@dnd-kit/core':
specifier: ^6.3.1
@@ -348,7 +345,7 @@ packages:
dependencies:
'@eslint/object-schema': 2.1.7
debug: 4.4.3
minimatch: 3.1.3
minimatch: 3.1.2
transitivePeerDependencies:
- supports-color
dev: true
@@ -378,7 +375,7 @@ packages:
ignore: 5.3.2
import-fresh: 3.3.1
js-yaml: 4.1.1
minimatch: 3.1.3
minimatch: 3.1.2
strip-json-comments: 3.1.1
transitivePeerDependencies:
- supports-color
@@ -2263,7 +2260,7 @@ packages:
'@typescript-eslint/types': 8.54.0
'@typescript-eslint/visitor-keys': 8.54.0
debug: 4.4.3
minimatch: 3.1.3
minimatch: 9.0.5
semver: 7.7.3
tinyglobby: 0.2.15
ts-api-utils: 2.4.0(typescript@5.9.3)
@@ -2681,6 +2678,12 @@ packages:
concat-map: 0.0.1
dev: true
/brace-expansion@2.0.2:
resolution: {integrity: sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==}
dependencies:
balanced-match: 1.0.2
dev: true
/braces@3.0.3:
resolution: {integrity: sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==}
engines: {node: '>=8'}
@@ -3342,7 +3345,7 @@ packages:
hasown: 2.0.2
is-core-module: 2.16.1
is-glob: 4.0.3
minimatch: 3.1.3
minimatch: 3.1.2
object.fromentries: 2.0.8
object.groupby: 1.0.3
object.values: 1.2.1
@@ -3373,7 +3376,7 @@ packages:
hasown: 2.0.2
jsx-ast-utils: 3.3.5
language-tags: 1.0.9
minimatch: 3.1.3
minimatch: 3.1.2
object.fromentries: 2.0.8
safe-regex-test: 1.1.0
string.prototype.includes: 2.0.1
@@ -3425,7 +3428,7 @@ packages:
estraverse: 5.3.0
hasown: 2.0.2
jsx-ast-utils: 3.3.5
minimatch: 3.1.3
minimatch: 3.1.2
object.entries: 1.1.9
object.fromentries: 2.0.8
object.values: 1.2.1
@@ -3495,7 +3498,7 @@ packages:
is-glob: 4.0.3
json-stable-stringify-without-jsonify: 1.0.1
lodash.merge: 4.6.2
minimatch: 3.1.3
minimatch: 3.1.2
natural-compare: 1.4.0
optionator: 0.9.4
transitivePeerDependencies:
@@ -5110,12 +5113,19 @@ packages:
engines: {node: '>=18'}
dev: true
/minimatch@3.1.3:
resolution: {integrity: sha512-M2GCs7Vk83NxkUyQV1bkABc4yxgz9kILhHImZiBPAZ9ybuvCb0/H7lEl5XvIg3g+9d4eNotkZA5IWwYl0tibaA==}
/minimatch@3.1.2:
resolution: {integrity: sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==}
dependencies:
brace-expansion: 1.1.12
dev: true
/minimatch@9.0.5:
resolution: {integrity: sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==}
engines: {node: '>=16 || 14 >=14.17'}
dependencies:
brace-expansion: 2.0.2
dev: true
/minimist@1.2.8:
resolution: {integrity: sha512-2yyAR8qBkN3YuheJanUpWC5U3bb5osDywNB8RzDVlDwDHbocAJveqqj1u8+SVD7jkWT4yvsHCpWqqWqAxb0zCA==}
dev: true

View File

@@ -1,4 +1,4 @@
import React, { useEffect, useMemo, useState } from 'react';
import React, { useEffect, useState } from 'react';
import {
IChooseAdapterEntity,
IPipelineEntity,
@@ -113,73 +113,109 @@ export default function BotForm({
const [dynamicFormConfigList, setDynamicFormConfigList] = useState<
IDynamicFormItemSchema[]
>([]);
const [filteredDynamicFormConfigList, setFilteredDynamicFormConfigList] =
useState<IDynamicFormItemSchema[]>([]);
const [, setIsLoading] = useState<boolean>(false);
const [webhookUrl, setWebhookUrl] = useState<string>('');
const [extraWebhookUrl, setExtraWebhookUrl] = useState<string>('');
const webhookInputRef = React.useRef<HTMLInputElement>(null);
const [copied, setCopied] = useState<boolean>(false);
const [extraCopied, setExtraCopied] = useState<boolean>(false);
// Watch adapter and adapter_config for filtering
const currentAdapter = form.watch('adapter');
const currentAdapterConfig = form.watch('adapter_config');
// Derive the filtered config list via useMemo instead of useEffect+setState
// to avoid creating new array references that would cause DynamicFormComponent
// to re-subscribe its form.watch, re-emit values, and trigger an infinite loop.
// Only depend on the specific field we care about (enable-webhook) rather than
// the entire currentAdapterConfig object, which changes on every emission.
const enableWebhook = currentAdapterConfig?.['enable-webhook'];
const filteredDynamicFormConfigList = useMemo(() => {
if (currentAdapter === 'lark' && enableWebhook === false) {
// Hide encrypt-key field when webhook is disabled
return dynamicFormConfigList.filter(
(config) => config.name !== 'encrypt-key',
);
}
// For non-Lark adapters or when webhook is enabled/undefined, show all fields
return dynamicFormConfigList;
}, [currentAdapter, enableWebhook, dynamicFormConfigList]);
useEffect(() => {
setBotFormValues();
}, []);
// 复制到剪贴板的辅助函数
const copyToClipboard = (
text: string,
setStatus: React.Dispatch<React.SetStateAction<boolean>>,
) => {
if (navigator.clipboard && navigator.clipboard.writeText) {
navigator.clipboard
.writeText(text)
.then(() => {
setStatus(true);
setTimeout(() => setStatus(false), 2000);
})
.catch(() => {
// 降级创建临时textarea复制
fallbackCopy(text, setStatus);
});
// Filter dynamic form config list based on enable-webhook status for Lark adapter
useEffect(() => {
if (currentAdapter === 'lark') {
const enableWebhook = currentAdapterConfig?.['enable-webhook'];
if (enableWebhook === false) {
// Hide encrypt-key field when webhook is disabled
setFilteredDynamicFormConfigList(
dynamicFormConfigList.filter(
(config) => config.name !== 'encrypt-key',
),
);
} else {
// Show all fields when webhook is enabled or undefined
setFilteredDynamicFormConfigList(dynamicFormConfigList);
}
} else {
fallbackCopy(text, setStatus);
// For non-Lark adapters, show all fields
setFilteredDynamicFormConfigList(dynamicFormConfigList);
}
};
}, [currentAdapter, currentAdapterConfig, dynamicFormConfigList]);
const fallbackCopy = (
text: string,
setStatus: React.Dispatch<React.SetStateAction<boolean>>,
) => {
const textarea = document.createElement('textarea');
textarea.value = text;
textarea.style.position = 'fixed';
textarea.style.opacity = '0';
document.body.appendChild(textarea);
textarea.select();
const successful = document.execCommand('copy');
document.body.removeChild(textarea);
if (successful) {
setStatus(true);
setTimeout(() => setStatus(false), 2000);
// 复制到剪贴板的辅助函数 - 使用页面上的真实input元素
const copyToClipboard = () => {
console.log('[Copy] Attempting to copy from input element');
const inputElement = webhookInputRef.current;
if (!inputElement) {
console.error('[Copy] Input element not found');
return;
}
try {
// 确保input元素可见且未被禁用
inputElement.disabled = false;
inputElement.readOnly = false;
// 聚焦并选中所有文本
inputElement.focus();
inputElement.select();
// 尝试使用现代API
if (navigator.clipboard && navigator.clipboard.writeText) {
console.log(
'[Copy] Using Clipboard API with input value:',
inputElement.value,
);
navigator.clipboard
.writeText(inputElement.value)
.then(() => {
console.log('[Copy] Clipboard API success');
inputElement.blur(); // 取消选中
inputElement.readOnly = true;
setCopied(true);
setTimeout(() => setCopied(false), 2000);
})
.catch((err) => {
console.error(
'[Copy] Clipboard API failed, trying execCommand:',
err,
);
// 降级到execCommand
const successful = document.execCommand('copy');
console.log('[Copy] execCommand result:', successful);
inputElement.blur();
inputElement.readOnly = true;
if (successful) {
setCopied(true);
setTimeout(() => setCopied(false), 2000);
}
});
} else {
// 直接使用execCommand
console.log(
'[Copy] Using execCommand with input value:',
inputElement.value,
);
const successful = document.execCommand('copy');
console.log('[Copy] execCommand result:', successful);
inputElement.blur();
inputElement.readOnly = true;
if (successful) {
setCopied(true);
setTimeout(() => setCopied(false), 2000);
}
}
} catch (err) {
console.error('[Copy] Copy failed:', err);
inputElement.readOnly = true;
}
};
@@ -204,7 +240,6 @@ export default function BotForm({
} else {
setWebhookUrl('');
}
setExtraWebhookUrl(val.extra_webhook_full_url || '');
})
.catch((err) => {
toast.error(
@@ -214,7 +249,6 @@ export default function BotForm({
} else {
form.reset();
setWebhookUrl('');
setExtraWebhookUrl('');
}
});
}
@@ -287,20 +321,14 @@ export default function BotForm({
setAdapterNameToDynamicConfigMap(adapterNameToDynamicConfigMap);
}
async function getBotConfig(botId: string): Promise<
z.infer<typeof formSchema> & {
webhook_full_url?: string;
extra_webhook_full_url?: string;
}
> {
async function getBotConfig(
botId: string,
): Promise<z.infer<typeof formSchema> & { webhook_full_url?: string }> {
return new Promise((resolve, reject) => {
httpClient
.getBot(botId)
.then((res) => {
const bot = res.bot;
const runtimeValues = bot.adapter_runtime_values as
| Record<string, unknown>
| undefined;
resolve({
adapter: bot.adapter,
description: bot.description,
@@ -308,12 +336,10 @@ export default function BotForm({
adapter_config: bot.adapter_config,
enable: bot.enable ?? true,
use_pipeline_uuid: bot.use_pipeline_uuid ?? '',
webhook_full_url: runtimeValues?.webhook_full_url as
| string
| undefined,
extra_webhook_full_url: runtimeValues?.extra_webhook_full_url as
| string
| undefined,
webhook_full_url: bot.adapter_runtime_values
? ((bot.adapter_runtime_values as Record<string, unknown>)
.webhook_full_url as string)
: undefined,
});
})
.catch((err) => {
@@ -504,11 +530,13 @@ export default function BotForm({
{/* Webhook 地址显示(统一 Webhook 模式) */}
{webhookUrl &&
(currentAdapter !== 'lark' || enableWebhook !== false) && (
(currentAdapter !== 'lark' ||
currentAdapterConfig?.['enable-webhook'] !== false) && (
<FormItem>
<FormLabel>{t('bots.webhookUrl')}</FormLabel>
<div className="flex items-center gap-2">
<Input
ref={webhookInputRef}
value={webhookUrl}
readOnly
className="flex-1 bg-gray-50 dark:bg-gray-900"
@@ -521,7 +549,7 @@ export default function BotForm({
type="button"
variant="outline"
size="sm"
onClick={() => copyToClipboard(webhookUrl, setCopied)}
onClick={copyToClipboard}
>
{copied ? (
<Check className="h-4 w-4 text-green-600 mr-2" />
@@ -531,37 +559,8 @@ export default function BotForm({
{t('common.copy')}
</Button>
</div>
{extraWebhookUrl && (
<div className="flex items-center gap-2 mt-2">
<Input
value={extraWebhookUrl}
readOnly
className="flex-1 bg-gray-50 dark:bg-gray-900"
onClick={(e) => {
(e.target as HTMLInputElement).select();
}}
/>
<Button
type="button"
variant="outline"
size="sm"
onClick={() =>
copyToClipboard(extraWebhookUrl, setExtraCopied)
}
>
{extraCopied ? (
<Check className="h-4 w-4 text-green-600 mr-2" />
) : (
<Copy className="h-4 w-4 mr-2" />
)}
{t('common.copy')}
</Button>
</div>
)}
<p className="text-sm text-gray-500 mt-1">
{extraWebhookUrl
? t('bots.webhookUrlHintEither')
: t('bots.webhookUrlHint')}
{t('bots.webhookUrlHint')}
</p>
</FormItem>
)}
@@ -668,7 +667,7 @@ export default function BotForm({
</div>
<DynamicFormComponent
itemConfigList={filteredDynamicFormConfigList}
initialValues={currentAdapterConfig}
initialValues={form.watch('adapter_config')}
onSubmit={(values) => {
form.setValue('adapter_config', values);
}}

View File

@@ -34,35 +34,6 @@ export default function DynamicFormComponent({
const previousInitialValues = useRef(initialValues);
const { t } = useTranslation();
// Normalize a form value according to its field type.
// This ensures legacy/malformed data (e.g. a plain string for
// model-fallback-selector) is coerced to the expected shape
// so that downstream components never crash.
const normalizeFieldValue = (
item: IDynamicFormItemSchema,
value: unknown,
): unknown => {
if (item.type === 'model-fallback-selector') {
if (value != null && typeof value === 'object' && !Array.isArray(value)) {
const obj = value as Record<string, unknown>;
return {
primary: typeof obj.primary === 'string' ? obj.primary : '',
fallbacks: Array.isArray(obj.fallbacks)
? (obj.fallbacks as unknown[]).filter(
(v): v is string => typeof v === 'string',
)
: [],
};
}
// Legacy string format or any other unexpected type
return {
primary: typeof value === 'string' ? value : '',
fallbacks: [],
};
}
return value;
};
// 根据 itemConfigList 动态生成 zod schema
const formSchema = z.object(
itemConfigList.reduce(
@@ -145,10 +116,10 @@ export default function DynamicFormComponent({
resolver: zodResolver(formSchema),
defaultValues: itemConfigList.reduce((acc, item) => {
// 优先使用 initialValues如果没有则使用默认值
const rawValue = initialValues?.[item.name] ?? item.default;
const value = initialValues?.[item.name] ?? item.default;
return {
...acc,
[item.name]: normalizeFieldValue(item, rawValue),
[item.name]: value,
};
}, {} as FormValues),
});
@@ -173,8 +144,7 @@ export default function DynamicFormComponent({
// 合并默认值和初始值
const mergedValues = itemConfigList.reduce(
(acc, item) => {
const rawValue = initialValues[item.name] ?? item.default;
acc[item.name] = normalizeFieldValue(item, rawValue) as object;
acc[item.name] = initialValues[item.name] ?? item.default;
return acc;
},
{} as Record<string, object>,
@@ -211,15 +181,6 @@ export default function DynamicFormComponent({
);
onSubmitRef.current?.(initialFinalValues);
// Update previousInitialValues to the emitted snapshot so that if the
// parent writes these values back as new initialValues, the deep
// comparison in the initialValues-sync useEffect won't detect a change
// and won't trigger an infinite update loop.
previousInitialValues.current = initialFinalValues as Record<
string,
object
>;
const subscription = form.watch(() => {
const formValues = form.getValues();
const finalValues = itemConfigList.reduce(
@@ -230,7 +191,6 @@ export default function DynamicFormComponent({
{} as Record<string, object>,
);
onSubmitRef.current?.(finalValues);
previousInitialValues.current = finalValues as Record<string, object>;
});
return () => subscription.unsubscribe();
}, [form, itemConfigList]);

View File

@@ -348,31 +348,10 @@ export default function DynamicFormItemComponent({
{} as Record<string, LLMModel[]>,
);
const rawModelValue = field.value;
const modelValue: { primary: string; fallbacks: string[] } =
rawModelValue != null &&
typeof rawModelValue === 'object' &&
!Array.isArray(rawModelValue)
? {
primary:
typeof (rawModelValue as Record<string, unknown>).primary ===
'string'
? ((rawModelValue as Record<string, unknown>)
.primary as string)
: '',
fallbacks: Array.isArray(
(rawModelValue as Record<string, unknown>).fallbacks,
)
? (
(rawModelValue as Record<string, unknown>)
.fallbacks as unknown[]
).filter((v): v is string => typeof v === 'string')
: [],
}
: {
primary: typeof rawModelValue === 'string' ? rawModelValue : '',
fallbacks: [],
};
const modelValue = field.value as {
primary: string;
fallbacks: string[];
};
const renderModelSelect = (
value: string,

View File

@@ -1,4 +1,4 @@
import { useEffect, useMemo, useState } from 'react';
import { useEffect, useState } from 'react';
import Link from 'next/link';
import { useForm } from 'react-hook-form';
import { zodResolver } from '@hookform/resolvers/zod';
@@ -242,17 +242,11 @@ export default function KBForm({
};
// Convert creation schema to dynamic form items (same as ExternalKBForm)
// Memoize to avoid regenerating UUIDs on every render, which would cause
// DynamicFormComponent's useEffect to re-fire and trigger an infinite loop.
const configFormItems = useMemo(
() => parseCreationSchema(selectedEngine?.creation_schema),
[selectedEngine?.creation_schema],
);
const configFormItems = parseCreationSchema(selectedEngine?.creation_schema);
// Convert retrieval schema to dynamic form items
const retrievalFormItems = useMemo(
() => parseCreationSchema(selectedEngine?.retrieval_schema),
[selectedEngine?.retrieval_schema],
const retrievalFormItems = parseCreationSchema(
selectedEngine?.retrieval_schema,
);
// Show loading state

View File

@@ -284,8 +284,6 @@ const enUS = {
webhookUrlCopied: 'Webhook URL copied',
webhookUrlHint:
'Click the input to select all, then press Ctrl+C (Mac: Cmd+C) to copy, or click the button',
webhookUrlHintEither:
'Use either of the two URLs above in your platform configuration',
logLevel: 'Log Level',
allLevels: 'All Levels',
selectLevel: 'Select Level',

View File

@@ -289,8 +289,6 @@
webhookUrlCopied: 'Webhook URL をコピーしました',
webhookUrlHint:
'入力ボックスをクリックして全選択し、Ctrl+C (Mac: Cmd+C) でコピーするか、右側のボタンをクリックしてください',
webhookUrlHintEither:
'上記の2つのURLのいずれかをプラットフォーム設定に使用してください',
logLevel: 'ログレベル',
allLevels: 'すべてのレベル',
selectLevel: 'レベルを選択',

View File

@@ -273,7 +273,6 @@ const zhHans = {
webhookUrlCopied: 'Webhook 地址已复制',
webhookUrlHint:
'点击输入框自动全选,然后按 Ctrl+C (Mac: Cmd+C) 复制,或点击右侧按钮',
webhookUrlHintEither: '以上两个地址任选其一填入平台配置即可',
logLevel: '日志级别',
allLevels: '全部级别',
selectLevel: '选择级别',

View File

@@ -272,7 +272,6 @@ const zhHant = {
webhookUrlCopied: 'Webhook 位址已複製',
webhookUrlHint:
'點擊輸入框自動全選,然後按 Ctrl+C (Mac: Cmd+C) 複製,或點擊右側按鈕',
webhookUrlHintEither: '以上兩個地址任選其一填入平台配置即可',
logLevel: '日誌級別',
allLevels: '全部級別',
selectLevel: '選擇級別',