mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-18 03:34:20 +00:00
f8010a20eb
* feat(monitoring): link feedback to LangBot message ID and add feedback export - Add pipeline→adapter notification hook so monitoring message ID is passed back to WecomBotAdapter after creation - Store stream_id→monitoring_message_id mapping with 10-min TTL cleanup - Replace feedback record stream_id with LangBot monitoring message ID so feedback can be linked to actual message records - Rename streamId label to "Related Query ID" in all 7 i18n locales - Remove non-functional message ID jump button from FeedbackList - Add feedback export option to ExportDropdown (backend already implemented) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * feat(monitoring): add combined refresh handler for monitoring and feedback data * fix(wecombot): improve stream ID mapping and error logging in WecomBotAdapter * feat(lark): add monitoring message ID mapping for feedback correlation * feat(lark): rename monitoring message ID mappings for clarity and consistency feat(feedback): add button to view conversation for feedback items * feat(bot-session-monitor): add feedback handling for bot messages with visual indicators * feat(bot-session-monitor): enhance feedback display with hover content for like/dislike indicators * fix(dingtalk): use voice recognition text instead of raw audio binary When DingTalk sends a voice message to the bot, the callback JSON contains a 'recognition' field with the speech-to-text result (powered by Qwen). Previously, LangBot only extracted the 'downloadCode' to download the raw audio binary and passed it as 'file_base64' to LLM APIs, which caused 400 errors since most models don't support this content type. 
This patch: - Extracts the 'recognition' field from DingTalk audio message content - Uses it as plain text input to the LLM instead of raw audio - Falls back to audio binary only when no recognition text is available - Fixes duplicate text issue for audio messages with recognition Fixes voice messages returning 'Request failed' on all LLM models. * fix: add filereader for dingtalk,lark (#2122) * fix: add filereader for dingtalk * feat: add lark * feat: update uv.lock * chore: update version to 4.9.6 in pyproject.toml, __init__.py, and uv.lock * fix: update langbot-plugin version to 0.3.8 * fix: update langbot-plugin version to 0.3.8 * fix(wecombot): extend StreamSession TTL for feedback sessions to prevent context data loss StreamSessionManager.cleanup() removes sessions after 60s TTL, but feedback events (like → cancel → dislike) can arrive later. When the session expires before the dislike event, all context fields (session_id, user_id, message_id, stream_id) are lost because get_session_by_feedback_id() returns None. Fix: Sessions with registered feedback_ids now use a 10-minute TTL, aligned with the adapter's _stream_to_monitoring_msg TTL in wecombot.py. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: 6mvp6 <13727783693@163.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: fdc310 <2213070223@qq.com> Co-authored-by: haiyangbg <zhouhaiyangaa@gmail.com> Co-authored-by: Guanchao Wang <wangcham233@gmail.com> Co-authored-by: Rock Chin <1010553892@qq.com>
1160 lines
46 KiB
Python
1160 lines
46 KiB
Python
import asyncio
|
|
import base64
|
|
import json
|
|
import time
|
|
import traceback
|
|
import uuid
|
|
import xml.etree.ElementTree as ET
|
|
from dataclasses import dataclass, field
|
|
import re
|
|
from typing import Any, Callable, Optional, Tuple
|
|
from urllib.parse import unquote
|
|
|
|
import httpx
|
|
from Crypto.Cipher import AES
|
|
from quart import Quart, request, Response, jsonify
|
|
|
|
from langbot.libs.wecom_ai_bot_api import wecombotevent
|
|
from langbot.libs.wecom_ai_bot_api.WXBizMsgCrypt3 import WXBizMsgCrypt
|
|
from langbot.pkg.platform.logger import EventLogger
|
|
|
|
|
|
@dataclass
class StreamChunk:
    """A single streaming fragment that is pushed back to WeCom."""

    # Text content to return to WeCom.
    content: str

    # Whether this is the last fragment; corresponds to the `finish` field
    # of the WeCom streaming protocol.
    is_final: bool = field(default=False)

    # Reserved for extra metadata (e.g. future multimodal extensions).
    meta: dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
@dataclass
class StreamSession:
    """Context kept for one WeCom streaming conversation."""

    # stream_id required by WeCom; identifies the follow-up refresh requests.
    stream_id: str

    # msgid of the original message, used to correlate with the pipeline message.
    msg_id: str

    # Group chat identifier (empty for one-to-one chats).
    chat_id: Optional[str]

    # Sender of the message that triggered the session.
    user_id: Optional[str]

    # Time the session was created.
    created_at: float = field(default_factory=time.time)

    # Time of the most recent access; cleanup uses it to decide expiry.
    last_access: float = field(default_factory=time.time)

    # Incremental pipeline results are buffered here and consumed one by one
    # by the refresh requests.
    queue: asyncio.Queue = field(default_factory=asyncio.Queue)

    # Whether the stream is complete (the final fragment has been received).
    finished: bool = field(default=False)

    # Most recent fragment, kept as a fallback for retries and timeouts.
    last_chunk: Optional[StreamChunk] = field(default=None)

    # Feedback id used to receive the user's like/dislike feedback.
    feedback_id: Optional[str] = field(default=None)
|
|
|
|
|
|
class StreamSessionManager:
    """Manage the lifecycle of stream sessions and their producer/consumer queues.

    Three in-memory indexes are maintained:
    ``stream_id -> StreamSession``, ``msgid -> stream_id`` and
    ``feedback_id -> stream_id``.
    """

    # Sessions with registered feedback_ids use a longer TTL to survive the
    # full like → cancel → dislike feedback flow. Must align with the adapter's
    # _stream_to_monitoring_msg TTL (wecombot.py).
    _FEEDBACK_SESSION_TTL = 600  # 10 minutes

    def __init__(self, logger: EventLogger, ttl: int = 60) -> None:
        """Create an empty manager.

        Args:
            logger: Logger instance stored on the manager.
            ttl: Idle timeout in seconds; sessions without a feedback_id that
                were not accessed for longer than this are removed by `cleanup`.
        """
        self.logger = logger

        self.ttl = ttl
        self._sessions: dict[str, StreamSession] = {}  # stream_id -> StreamSession
        self._msg_index: dict[str, str] = {}  # msgid -> stream_id, lets the pipeline find a session by message id
        self._feedback_index: dict[str, str] = {}  # feedback_id -> stream_id

    def get_stream_id_by_msg(self, msg_id: str) -> Optional[str]:
        """Return the stream_id recorded for ``msg_id``, or None if unknown or empty."""
        if not msg_id:
            return None
        return self._msg_index.get(msg_id)

    def get_session(self, stream_id: str) -> Optional[StreamSession]:
        """Return the session stored under ``stream_id``, or None."""
        return self._sessions.get(stream_id)

    def get_session_by_feedback_id(self, feedback_id: str) -> Optional[StreamSession]:
        """Look up a session by feedback_id.

        Args:
            feedback_id: Feedback id carried by a WeCom feedback event.

        Returns:
            Optional[StreamSession]: The matching session, or None when the id
            is empty, unknown, or its session no longer exists.
        """
        if not feedback_id:
            return None
        stream_id = self._feedback_index.get(feedback_id)
        if stream_id:
            return self._sessions.get(stream_id)
        return None

    def register_feedback_id(self, stream_id: str, feedback_id: str) -> None:
        """Record the ``feedback_id -> stream_id`` mapping.

        Nothing is recorded when either argument is empty.

        Args:
            stream_id: WeCom stream session id.
            feedback_id: Feedback id.
        """
        if feedback_id and stream_id:
            self._feedback_index[feedback_id] = stream_id

    def create_or_get(self, msg_json: dict[str, Any]) -> tuple[StreamSession, bool]:
        """Create a session for a WeCom callback, or return the existing one.

        Args:
            msg_json: Decrypted WeCom callback JSON.

        Returns:
            Tuple[StreamSession, bool]: The session and a flag that is True
            when the session was newly created.

        Example:
            Call on the first callback; trigger the pipeline when ``is_new`` is True.
        """
        msg_id = msg_json.get('msgid', '')
        if msg_id and msg_id in self._msg_index:
            stream_id = self._msg_index[msg_id]
            session = self._sessions.get(stream_id)
            if session:
                session.last_access = time.time()
                return session, False

        stream_id = str(uuid.uuid4())
        session = StreamSession(
            stream_id=stream_id,
            msg_id=msg_id,
            chat_id=msg_json.get('chatid'),
            # `or {}` also covers an explicit null `from` field.
            user_id=(msg_json.get('from') or {}).get('userid'),
        )

        if msg_id:
            self._msg_index[msg_id] = stream_id
        self._sessions[stream_id] = session
        return session, True

    async def publish(self, stream_id: str, chunk: StreamChunk) -> bool:
        """Append a new incremental fragment to the stream queue.

        Args:
            stream_id: Stream session id.
            chunk: Fragment to enqueue.

        Returns:
            bool: True when the session exists and the chunk was enqueued.

        Example:
            After receiving a model delta call
            `await manager.publish('sid', StreamChunk('hello'))`.
        """
        session = self._sessions.get(stream_id)
        if not session:
            return False

        session.last_access = time.time()
        session.last_chunk = chunk

        try:
            session.queue.put_nowait(chunk)
        except asyncio.QueueFull:
            # The queue is unbounded by default; this is a defensive fallback.
            await session.queue.put(chunk)

        if chunk.is_final:
            session.finished = True

        return True

    async def consume(self, stream_id: str, timeout: float = 0.5) -> Optional[StreamChunk]:
        """Take one fragment from the queue, returning None on timeout.

        Args:
            stream_id: Stream session id.
            timeout: Maximum time in seconds to wait for a fragment.

        Returns:
            Optional[StreamChunk]: The fragment on success. On timeout the last
            published chunk is returned if the session is already finished,
            otherwise None. None is also returned when the session does not exist.

        Example:
            Called when a WeCom refresh arrives; returns immediately if data is queued.
        """
        session = self._sessions.get(stream_id)
        if not session:
            return None

        session.last_access = time.time()

        try:
            chunk = await asyncio.wait_for(session.queue.get(), timeout)
            session.last_access = time.time()
            if chunk.is_final:
                session.finished = True
            return chunk
        except asyncio.TimeoutError:
            if session.finished and session.last_chunk:
                return session.last_chunk
            return None

    def mark_finished(self, stream_id: str) -> None:
        """Flag the session as finished and refresh its last-access time, if it exists."""
        session = self._sessions.get(stream_id)
        if session:
            session.finished = True
            session.last_access = time.time()

    def cleanup(self) -> None:
        """Remove expired sessions so queues and indexes do not grow without bound.

        Sessions that carry a feedback_id use the longer `_FEEDBACK_SESSION_TTL`
        so the context is not lost during a like/cancel/dislike flow.

        Every feedback id that points at an expired stream is dropped, not only
        the one currently stored on the session: a stream can have several ids
        registered over its lifetime, and the older ones would otherwise stay
        in `_feedback_index` forever.
        """
        now = time.time()
        expired: set[str] = set()
        for stream_id, session in self._sessions.items():
            effective_ttl = self._FEEDBACK_SESSION_TTL if session.feedback_id else self.ttl
            if now - session.last_access > effective_ttl:
                expired.add(stream_id)

        if not expired:
            return

        for stream_id in expired:
            session = self._sessions.pop(stream_id, None)
            if not session:
                continue
            msg_id = session.msg_id
            # Only drop the msgid mapping if it still points at this stream.
            if msg_id and self._msg_index.get(msg_id) == stream_id:
                self._msg_index.pop(msg_id, None)
            if session.feedback_id:
                self._feedback_index.pop(session.feedback_id, None)

        # Collect first, delete afterwards: the dict must not change size while iterated.
        stale_feedback_ids = [fid for fid, sid in self._feedback_index.items() if sid in expired]
        for fid in stale_feedback_ids:
            del self._feedback_index[fid]
|
|
|
|
|
|
def _decrypt_file(encrypted_data: bytes, aes_key_str: str) -> bytes:
|
|
"""Decrypt AES-256-CBC encrypted file data.
|
|
|
|
Aligned with the official WeCom AI Bot Python SDK (crypto_utils.py).
|
|
|
|
Args:
|
|
encrypted_data: The raw encrypted bytes.
|
|
aes_key_str: Base64-encoded AES key (may lack padding).
|
|
|
|
Returns:
|
|
Decrypted bytes with PKCS#7 padding removed.
|
|
"""
|
|
if not encrypted_data:
|
|
raise ValueError('encrypted_data is empty')
|
|
if not aes_key_str:
|
|
raise ValueError('aes_key is empty')
|
|
|
|
# Python's base64.b64decode requires proper padding (length % 4 == 0).
|
|
# Node.js Buffer.from tolerates missing '=', so we must pad manually.
|
|
remainder = len(aes_key_str) % 4
|
|
if remainder != 0:
|
|
aes_key_str = aes_key_str + '=' * (4 - remainder)
|
|
key = base64.b64decode(aes_key_str)
|
|
|
|
iv = key[:16]
|
|
|
|
cipher = AES.new(key, AES.MODE_CBC, iv)
|
|
|
|
# Ensure encrypted data is aligned to AES block size (16 bytes).
|
|
# Node.js setAutoPadding(false) silently handles unaligned data,
|
|
# but PyCryptodome will raise an error.
|
|
block_size = 16
|
|
data_remainder = len(encrypted_data) % block_size
|
|
if data_remainder != 0:
|
|
encrypted_data = encrypted_data + b'\x00' * (block_size - data_remainder)
|
|
|
|
decrypted = cipher.decrypt(encrypted_data)
|
|
|
|
# Remove PKCS#7 padding with validation
|
|
if len(decrypted) == 0:
|
|
raise ValueError('Decrypted data is empty')
|
|
|
|
pad_len = decrypted[-1]
|
|
if pad_len < 1 or pad_len > 32 or pad_len > len(decrypted):
|
|
raise ValueError(f'Invalid PKCS#7 padding value: {pad_len}')
|
|
|
|
# Verify all padding bytes are consistent
|
|
for i in range(len(decrypted) - pad_len, len(decrypted)):
|
|
if decrypted[i] != pad_len:
|
|
raise ValueError('Invalid PKCS#7 padding: padding bytes mismatch')
|
|
|
|
return decrypted[: len(decrypted) - pad_len]
|
|
|
|
|
|
def _extract_filename(content_disposition: str) -> Optional[str]:
|
|
"""Extract filename from a Content-Disposition header value."""
|
|
if not content_disposition:
|
|
return None
|
|
# RFC 5987: filename*=UTF-8''xxx
|
|
utf8_match = re.search(r"filename\*=UTF-8''([^;\s]+)", content_disposition, re.IGNORECASE)
|
|
if utf8_match:
|
|
return unquote(utf8_match.group(1))
|
|
# Standard: filename="xxx" or filename=xxx
|
|
match = re.search(r'filename="?([^";\s]+)"?', content_disposition, re.IGNORECASE)
|
|
if match:
|
|
return unquote(match.group(1))
|
|
return None
|
|
|
|
|
|
def _bytes_to_data_uri(data: bytes) -> str:
|
|
"""Convert raw bytes to a data URI with auto-detected MIME type."""
|
|
if data.startswith(b'\xff\xd8'):
|
|
mime_type = 'image/jpeg'
|
|
elif data.startswith(b'\x89PNG'):
|
|
mime_type = 'image/png'
|
|
elif data.startswith((b'GIF87a', b'GIF89a')):
|
|
mime_type = 'image/gif'
|
|
elif data.startswith(b'BM'):
|
|
mime_type = 'image/bmp'
|
|
elif data.startswith(b'II*\x00') or data.startswith(b'MM\x00*'):
|
|
mime_type = 'image/tiff'
|
|
elif data[:4] == b'%PDF':
|
|
mime_type = 'application/pdf'
|
|
elif data[:4] == b'PK\x03\x04':
|
|
mime_type = 'application/zip'
|
|
else:
|
|
mime_type = 'application/octet-stream'
|
|
|
|
base64_str = base64.b64encode(data).decode('utf-8')
|
|
return f'data:{mime_type};base64,{base64_str}'
|
|
|
|
|
|
async def download_encrypted_file(
    download_url: str, aes_key: str, logger: EventLogger
) -> Tuple[Optional[bytes], Optional[str]]:
    """Download an AES-encrypted file and return its decrypted content.

    Failures (missing key, non-200 response, network error, decryption error)
    are logged through ``logger`` and reported as ``(None, None)``; nothing
    is raised to the caller.

    Args:
        download_url: The encrypted file download URL.
        aes_key: Base64-encoded AES key for decryption (per-message aeskey
            or platform EncodingAESKey).
        logger: Logger instance.

    Returns:
        A tuple of (decrypted_bytes, filename). The filename comes from the
        Content-Disposition response header and may be None.
    """
    if not download_url:
        return None, None
    if not aes_key:
        await logger.error('download_encrypted_file: aes_key is empty, cannot decrypt')
        return None, None

    try:
        async with httpx.AsyncClient(timeout=30.0) as http:
            resp = await http.get(download_url)
            status = resp.status_code
            if status != 200:
                await logger.error(f'Failed to download file (HTTP {status}): {resp.text[:200]}')
                return None, None
            payload = resp.content
            disposition = resp.headers.get('content-disposition', '')
            name: Optional[str] = _extract_filename(disposition)
    except Exception:
        await logger.error(f'Failed to download file: {traceback.format_exc()}')
        return None, None

    try:
        plain = _decrypt_file(payload, aes_key)
    except Exception:
        await logger.error(f'Failed to decrypt file: {traceback.format_exc()}')
        return None, None
    return plain, name
|
|
|
|
|
|
def _wecom_voice_meta(voice_info: dict[str, Any]) -> dict[str, Any]:
    """Build the voice metadata dict, accepting both long and short field names."""
    return {
        'url': voice_info.get('url'),
        'md5sum': voice_info.get('md5sum') or voice_info.get('md5'),
        'filesize': voice_info.get('filesize') or voice_info.get('size'),
        'sdkfileid': voice_info.get('sdkfileid') or voice_info.get('fileid'),
    }


def _wecom_video_meta(video_info: dict[str, Any]) -> dict[str, Any]:
    """Build the video metadata dict, accepting both long and short field names."""
    return {
        'url': video_info.get('url'),
        'filesize': video_info.get('filesize') or video_info.get('size'),
        'sdkfileid': video_info.get('sdkfileid') or video_info.get('fileid'),
        'md5sum': video_info.get('md5sum') or video_info.get('md5'),
        'filename': video_info.get('filename') or video_info.get('name'),
    }


def _wecom_file_meta(file_info: dict[str, Any]) -> dict[str, Any]:
    """Build the file metadata dict; the raw payload is kept under 'extra'."""
    return {
        'filename': file_info.get('filename') or file_info.get('name'),
        'filesize': file_info.get('filesize') or file_info.get('size'),
        'md5sum': file_info.get('md5sum') or file_info.get('md5'),
        'sdkfileid': file_info.get('sdkfileid') or file_info.get('fileid'),
        'download_url': file_info.get('url') or file_info.get('fileurl'),
        'extra': file_info,
    }


def _wecom_quoted_file_meta(file_info: dict[str, Any]) -> dict[str, Any]:
    """File metadata for a quoted message; the aeskey is appended to the URL when both exist."""
    file_data = _wecom_file_meta(file_info)
    download_url = file_data['download_url']
    aeskey = file_info.get('aeskey', '')
    if download_url and aeskey:
        file_data['download_url'] = download_url + f'?aeskey={aeskey}'
    return file_data


def _wecom_link_summary(link: dict[str, Any]) -> str:
    """Join a link's title and description (or digest) with a newline, skipping empty parts."""
    title = link.get('title', '')
    desc = link.get('description') or link.get('digest', '')
    return '\n'.join(filter(None, [title, desc]))


async def _parse_wecom_quote(quote_info: dict[str, Any], download_as_data_uri: Callable) -> dict[str, Any]:
    """Parse a quoted (referenced) message into a dict laid out like the main message.

    Images are downloaded through ``download_as_data_uri(url, aeskey)``. Files,
    voices and videos are not downloaded: their URL gets the aeskey appended
    (when both are present) so plugins can download and decrypt them.
    """
    quote_type = quote_info.get('msgtype', '')
    quote_data: dict[str, Any] = {'msgtype': quote_type}

    if quote_type == 'text':
        quote_data['content'] = quote_info.get('text', {}).get('content', '')
    elif quote_type == 'image':
        img_info = quote_info.get('image', {})
        base64_data = await download_as_data_uri(img_info.get('url', ''), img_info.get('aeskey', ''))
        if base64_data:
            quote_data['picurl'] = base64_data
            quote_data['images'] = [base64_data]
    elif quote_type == 'file':
        quote_data['file'] = _wecom_quoted_file_meta(quote_info.get('file', {}) or {})
    elif quote_type == 'voice':
        voice_info = quote_info.get('voice', {}) or {}
        voice_data = _wecom_voice_meta(voice_info)
        if voice_info.get('content'):
            quote_data['content'] = voice_info.get('content')
        download_url = voice_data['url']
        item_aeskey = voice_info.get('aeskey', '')
        if download_url and item_aeskey:
            voice_data['url'] = download_url + f'?aeskey={item_aeskey}'
        quote_data['voice'] = voice_data
    elif quote_type == 'video':
        video_info = quote_info.get('video', {}) or {}
        video_data = _wecom_video_meta(video_info)
        download_url = video_data['url']
        item_aeskey = video_info.get('aeskey', '')
        if download_url and item_aeskey:
            video_data['download_url'] = download_url + f'?aeskey={item_aeskey}'
        quote_data['video'] = video_data
    elif quote_type == 'link':
        link = quote_info.get('link', {})
        quote_data['link'] = link
        quote_data['content'] = _wecom_link_summary(link)
    elif quote_type == 'mixed':
        # Mixed quote: only text, image and file items are handled.
        texts = []
        images = []
        files = []
        for item in quote_info.get('mixed', {}).get('msg_item', []):
            item_type = item.get('msgtype')
            if item_type == 'text':
                texts.append(item.get('text', {}).get('content', ''))
            elif item_type == 'image':
                img_info = item.get('image', {})
                base64_data = await download_as_data_uri(img_info.get('url'), img_info.get('aeskey', ''))
                if base64_data:
                    images.append(base64_data)
            elif item_type == 'file':
                files.append(_wecom_quoted_file_meta(item.get('file', {}) or {}))
        if texts:
            quote_data['content'] = ' '.join(texts)
        if images:
            quote_data['images'] = images
            quote_data['picurl'] = images[0]
        if files:
            quote_data['files'] = files
            quote_data['file'] = files[0]

    return quote_data


async def parse_wecom_bot_message(
    msg_json: dict[str, Any], encoding_aes_key: str, logger: EventLogger
) -> dict[str, Any]:
    """Parse a decrypted WeCom AI Bot message JSON into a unified message dict.

    This is the shared message parsing logic used by both webhook and WebSocket modes.

    Standalone ``video`` and ``file`` messages are not downloaded here. Their
    download URL is returned with ``?aeskey=<key>`` appended so plugins can
    download and decrypt the payload themselves. When the message carries no
    URL, ``download_url`` is left as None instead of raising a ``TypeError``.

    Args:
        msg_json: The decrypted message JSON from WeCom.
        encoding_aes_key: Platform AES key, used when a media item has no
            per-message aeskey.
        logger: Logger instance.

    Returns:
        A dict suitable for constructing a WecomBotEvent.
    """
    message_data: dict[str, Any] = {}

    msg_type = msg_json.get('msgtype', '')
    if msg_type:
        message_data['msgtype'] = msg_type

    chat_type = msg_json.get('chattype', '')
    if chat_type in ('single', 'group'):
        message_data['type'] = chat_type

    # Items inside a mixed message are inlined as base64 only up to this size.
    max_inline_file_size = 5 * 1024 * 1024

    async def _safe_download(url: str, per_msg_aeskey: str = '') -> Tuple[Optional[bytes], Optional[str]]:
        """Download and decrypt a file, preferring per-message aeskey over platform key."""
        if not url:
            return None, None
        key = per_msg_aeskey or encoding_aes_key
        if not key:
            await logger.warning('No AES key available for file decryption, skipping download')
            return None, None
        return await download_encrypted_file(url, key, logger)

    async def _safe_download_as_data_uri(url: str, per_msg_aeskey: str = '') -> Optional[str]:
        """Download, decrypt, and convert to data URI for backward compatibility."""
        data, _filename = await _safe_download(url, per_msg_aeskey)
        if data:
            return _bytes_to_data_uri(data)
        return None

    if msg_type == 'text':
        message_data['content'] = msg_json.get('text', {}).get('content')
    elif msg_type == 'markdown':
        message_data['content'] = msg_json.get('markdown', {}).get('content') or msg_json.get('text', {}).get(
            'content', ''
        )
    elif msg_type == 'image':
        image_info = msg_json.get('image', {})
        base64_data = await _safe_download_as_data_uri(image_info.get('url', ''), image_info.get('aeskey', ''))
        if base64_data:
            message_data['picurl'] = base64_data
            message_data['images'] = [base64_data]
    elif msg_type == 'voice':
        # Standalone voice messages are not downloaded; only metadata and the
        # transcription text (when provided) are forwarded.
        voice_info = msg_json.get('voice', {}) or {}
        message_data['voice'] = _wecom_voice_meta(voice_info)
        if voice_info.get('content'):
            message_data['content'] = voice_info.get('content')
    elif msg_type == 'video':
        video_info = msg_json.get('video', {}) or {}
        video_data = _wecom_video_meta(video_info)
        download_url = video_data['url']
        per_msg_aeskey = video_info.get('aeskey', '')
        # The payload is encrypted, but for now it cannot be downloaded and
        # decrypted internally, so the link is returned with the aeskey
        # appended and plugins handle download and decryption.
        video_data['download_url'] = download_url + f'?aeskey={per_msg_aeskey}' if download_url else None
        message_data['video'] = video_data
    elif msg_type == 'file':
        file_info = msg_json.get('file', {}) or {}
        file_data = _wecom_file_meta(file_info)
        download_url = file_data['download_url']
        per_msg_aeskey = file_info.get('aeskey', '')
        # Same approach as for video: hand the link plus aeskey to plugins.
        # Without a URL there is nothing to append the key to.
        if download_url:
            file_data['download_url'] = download_url + f'?aeskey={per_msg_aeskey}'
        message_data['file'] = file_data
    elif msg_type == 'link':
        message_data['link'] = msg_json.get('link', {})
        message_data['content'] = _wecom_link_summary(message_data['link'])
    elif msg_type == 'mixed':
        items = msg_json.get('mixed', {}).get('msg_item', [])
        texts = []
        images = []
        files = []
        voices = []
        videos = []
        links = []
        for item in items:
            item_type = item.get('msgtype')
            if item_type == 'text':
                texts.append(item.get('text', {}).get('content', ''))
            elif item_type == 'image':
                img_info = item.get('image', {})
                base64_data = await _safe_download_as_data_uri(img_info.get('url'), img_info.get('aeskey', ''))
                if base64_data:
                    images.append(base64_data)
            elif item_type == 'file':
                file_info = item.get('file', {}) or {}
                file_data = _wecom_file_meta(file_info)
                if (file_data.get('filesize') or 0) <= max_inline_file_size:
                    file_bytes, dl_filename = await _safe_download(
                        file_data['download_url'], file_info.get('aeskey', '')
                    )
                    if file_bytes:
                        file_data['base64'] = _bytes_to_data_uri(file_bytes)
                    # Fall back to the filename reported by the download response.
                    if dl_filename and not file_data.get('filename'):
                        file_data['filename'] = dl_filename
                files.append(file_data)
            elif item_type == 'voice':
                voice_info = item.get('voice', {}) or {}
                voice_data = _wecom_voice_meta(voice_info)
                if voice_info.get('content'):
                    texts.append(voice_info.get('content'))
                if (voice_data.get('filesize') or 0) <= max_inline_file_size:
                    voice_base64 = await _safe_download_as_data_uri(voice_data['url'], voice_info.get('aeskey', ''))
                    if voice_base64:
                        voice_data['base64'] = voice_base64
                voices.append(voice_data)
            elif item_type == 'video':
                video_info = item.get('video', {}) or {}
                video_data = _wecom_video_meta(video_info)
                if (video_data.get('filesize') or 0) <= max_inline_file_size:
                    video_base64 = await _safe_download_as_data_uri(video_data['url'], video_info.get('aeskey', ''))
                    if video_base64:
                        video_data['base64'] = video_base64
                videos.append(video_data)
            elif item_type == 'link':
                links.append(item.get('link', {}))

        # The singular keys carry the first item of each kind.
        if texts:
            message_data['content'] = ' '.join(texts)
        if images:
            message_data['images'] = images
            message_data['picurl'] = images[0]
        if files:
            message_data['files'] = files
            message_data['file'] = files[0]
        if voices:
            message_data['voices'] = voices
            message_data['voice'] = voices[0]
        if videos:
            message_data['videos'] = videos
            message_data['video'] = videos[0]
        if links:
            message_data['link'] = links[0]
        if items:
            message_data['attachments'] = items
    else:
        # Unrecognised message type: keep the raw payload.
        message_data['raw_msg'] = msg_json

    from_info = msg_json.get('from', {})
    message_data['userid'] = from_info.get('userid', '')
    message_data['username'] = from_info.get('alias', '') or from_info.get('name', '') or from_info.get('userid', '')

    if chat_type == 'group':
        message_data['chatid'] = msg_json.get('chatid', '')
        message_data['chatname'] = msg_json.get('chatname', '') or msg_json.get('chatid', '')

    message_data['msgid'] = msg_json.get('msgid', '')

    if msg_json.get('aibotid'):
        message_data['aibotid'] = msg_json.get('aibotid', '')

    # Handle quote (referenced message) - important for group chat file references
    quote_info = msg_json.get('quote')
    if quote_info:
        message_data['quote'] = await _parse_wecom_quote(quote_info, _safe_download_as_data_uri)

    return message_data
|
|
|
|
|
|
class WecomBotClient:
|
|
def __init__(self, Token: str, EnCodingAESKey: str, Corpid: str, logger: EventLogger, unified_mode: bool = False):
    """WeCom AI bot client.

    Args:
        Token: Token used to verify WeCom callbacks.
        EnCodingAESKey: Key used to encrypt and decrypt WeCom messages.
        Corpid: Enterprise (corp) ID.
        logger: Logger instance.
        unified_mode: Whether the unified webhook mode is used (default False).

    Example:
        >>> client = WecomBotClient(Token='token', EnCodingAESKey='aeskey', Corpid='corp', logger=logger)
    """
    # Credentials and mode flags.
    self.Token = Token
    self.EnCodingAESKey = EnCodingAESKey
    self.Corpid = Corpid
    self.ReceiveId = ''
    self.unified_mode = unified_mode
    self.logger = logger

    # Runtime state.
    self._message_handlers = {'example': []}
    self.generated_content: dict[str, str] = {}
    self.msg_id_map: dict[str, int] = {}
    self.stream_sessions = StreamSessionManager(logger=logger)
    self.stream_poll_timeout = 0.5
    self._feedback_callback: Optional[Callable] = None

    self.app = Quart(__name__)
    # The standalone callback route is registered only outside unified mode.
    if not self.unified_mode:
        self.app.add_url_rule(
            '/callback/command', 'handle_callback', self.handle_callback_request, methods=['POST', 'GET']
        )
|
|
|
|
def set_feedback_callback(self, callback: Callable) -> None:
|
|
"""设置反馈回调函数。
|
|
|
|
Args:
|
|
callback: 反馈回调函数,签名: async def callback(feedback_id, feedback_type, feedback_content, inaccurate_reasons, session)
|
|
"""
|
|
self._feedback_callback = callback
|
|
|
|
@staticmethod
|
|
def _build_stream_payload(
|
|
stream_id: str, content: str, finish: bool, feedback_id: Optional[str] = None
|
|
) -> dict[str, Any]:
|
|
"""按照企业微信协议拼装返回报文。
|
|
|
|
Args:
|
|
stream_id: 企业微信会话 ID。
|
|
content: 推送的文本内容。
|
|
finish: 是否为最终片段。
|
|
feedback_id: 反馈 ID,用于接收用户点赞/点踩反馈。
|
|
|
|
Returns:
|
|
dict[str, Any]: 可直接加密返回的 payload。
|
|
|
|
Example:
|
|
组装 `{'msgtype': 'stream', 'stream': {'id': 'sid', ...}}` 结构。
|
|
"""
|
|
stream_payload = {
|
|
'id': stream_id,
|
|
'finish': finish,
|
|
'content': content,
|
|
}
|
|
if feedback_id:
|
|
stream_payload['feedback'] = {'id': feedback_id}
|
|
return {
|
|
'msgtype': 'stream',
|
|
'stream': stream_payload,
|
|
}
|
|
|
|
async def _encrypt_and_reply(self, payload: dict[str, Any], nonce: str) -> tuple[Response, int]:
    """Encrypt a reply payload and wrap it in the response returned to WeCom.

    The payload is serialised to JSON and encrypted with ``self.wxcpt``. The
    ``Encrypt`` element of the resulting XML becomes the ``encrypt`` field of
    the JSON response.

    Args:
        payload: Reply content to encrypt.
        nonce: The nonce from the WeCom callback parameters.

    Returns:
        Tuple[Response, int]: Quart Response object and status code. A 500
        response with ``{'error': 'encrypt_failed'}`` is returned when
        encryption fails or the encrypted envelope has no ``Encrypt`` text.

    Example:
        Called for both the first reply and refresh replies.
    """
    reply_plain_str = json.dumps(payload, ensure_ascii=False)
    reply_timestamp = str(int(time.time()))
    ret, encrypt_text = self.wxcpt.EncryptMsg(reply_plain_str, nonce, reply_timestamp)
    if ret != 0:
        await self.logger.error(f'加密失败: {ret}')
        return jsonify({'error': 'encrypt_failed'}), 500

    # findtext returns None when the element is missing, which avoids the
    # AttributeError that `find(...).text` would raise.
    encrypt = ET.fromstring(encrypt_text).findtext('Encrypt')
    if not encrypt:
        await self.logger.error('加密失败: missing <Encrypt> element in encrypted reply')
        return jsonify({'error': 'encrypt_failed'}), 500

    return jsonify({'encrypt': encrypt}), 200
|
|
|
|
async def _dispatch_event(self, event: wecombotevent.WecomBotEvent) -> None:
    """Run message handling for an event and log any failure.

    Meant to run as a background task so the first reply to the callback is
    not blocked by handler execution.

    Args:
        event: Internal event object built from the WeCom message.
    """
    try:
        await self._handle_message(event)
    except Exception:
        # Never let a failure escape the background task; record it instead.
        failure = traceback.format_exc()
        await self.logger.error(failure)
|
|
|
|
async def _handle_post_initial_response(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]:
    """Handle the first push of a WeCom message: reply with a stream ID and start processing.

    A fresh feedback ID is generated, stored on the session and registered
    with the session manager. If the message parses into an event and the
    session was newly created, handling is scheduled as a background task so
    the reply is returned immediately.

    Args:
        msg_json: Decrypted WeCom message JSON.
        nonce: The ``nonce`` value from the callback request.

    Returns:
        Tuple[Response, int]: Encrypted response (empty, unfinished stream
        payload carrying the stream ID and feedback ID) and its status code.
    """
    session, is_new = self.stream_sessions.create_or_get(msg_json)

    feedback_id = str(uuid.uuid4())
    session.feedback_id = feedback_id
    self.stream_sessions.register_feedback_id(session.stream_id, feedback_id)

    message_data = await self.get_message(msg_json)
    if message_data:
        message_data['stream_id'] = session.stream_id
        message_data['feedback_id'] = feedback_id
        try:
            event = wecombotevent.WecomBotEvent(message_data)
        except Exception:
            await self.logger.error(traceback.format_exc())
        else:
            # Only a newly created session starts handling; repeated pushes
            # of the same message reuse the existing stream.
            if is_new:
                task = asyncio.create_task(self._dispatch_event(event))
                # The event loop keeps only weak references to tasks, so hold
                # a strong one until completion to avoid the task being
                # garbage-collected mid-execution.
                pending = getattr(self, '_background_tasks', None)
                if pending is None:
                    pending = self._background_tasks = set()
                pending.add(task)
                task.add_done_callback(pending.discard)

    payload = self._build_stream_payload(session.stream_id, '', False, feedback_id)
    return await self._encrypt_and_reply(payload, nonce)
|
|
|
|
async def _handle_post_followup_response(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]:
    """Answer a WeCom stream refresh request with the next available fragment.

    Args:
        msg_json: Decrypted refresh request; the stream ID is read from
            ``msg_json['stream']['id']``.
        nonce: The ``nonce`` value from the callback request.

    Returns:
        Tuple[Response, int]: Encrypted response and status code. The payload
        is a finished empty stream when the stream ID is missing, an
        unfinished empty stream when nothing is available yet, and otherwise
        the consumed (or cached) content.
    """
    stream_id = msg_json.get('stream', {}).get('id', '')
    if not stream_id:
        await self.logger.error('刷新请求缺少 stream.id')
        return await self._encrypt_and_reply(self._build_stream_payload('', '', True), nonce)

    session = self.stream_sessions.get_session(stream_id)
    chunk = await self.stream_sessions.consume(stream_id, timeout=self.stream_poll_timeout)

    if not chunk:
        # Nothing streamed within the poll timeout: fall back to a result
        # cached under the session's message ID, if there is one.
        fallback = None
        if session and session.msg_id:
            fallback = self.generated_content.pop(session.msg_id, None)
        if fallback is None:
            waiting = self._build_stream_payload(stream_id, '', False)
            return await self._encrypt_and_reply(waiting, nonce)
        chunk = StreamChunk(content=fallback, is_final=True)

    payload = self._build_stream_payload(stream_id, chunk.content, chunk.is_final)
    if chunk.is_final:
        self.stream_sessions.mark_finished(stream_id)
    return await self._encrypt_and_reply(payload, nonce)
|
|
|
|
async def handle_callback_request(self):
    """WeCom callback entry point for standalone-port mode.

    Uses the module-level ``request`` object and forwards it to
    ``_handle_callback_internal``; intended to be registered directly as a
    route handler.

    Returns:
        The response produced by ``_handle_callback_internal`` for the
        current request.
    """
    current_request = request
    return await self._handle_callback_internal(current_request)
|
|
|
|
async def handle_unified_webhook(self, req):
    """Callback entry point for unified-webhook mode, where the request is passed in.

    Args:
        req: Request object to process.

    Returns:
        The response produced by ``_handle_callback_internal`` for ``req``.
    """
    response = await self._handle_callback_internal(req)
    return response
|
|
|
|
async def _handle_callback_internal(self, req):
    """Shared callback implementation: GET verification and POST message intake.

    A ``WXBizMsgCrypt`` helper is (re)built from the instance's token and AES
    key on every call and stored on ``self.wxcpt`` before dispatching.

    Args:
        req: Request object; only ``req.method`` is inspected here.

    Returns:
        The GET or POST handler's response, an empty 405 response for any
        other method, or a 500 response if anything raises.
    """
    try:
        self.wxcpt = WXBizMsgCrypt(self.Token, self.EnCodingAESKey, '')

        method = req.method
        if method == 'GET':
            outcome = await self._handle_get_callback(req)
        elif method == 'POST':
            outcome = await self._handle_post_callback(req)
        else:
            outcome = Response('', status=405)
        return outcome
    except Exception:
        # Last line of defence: log the traceback and answer with a generic 500.
        await self.logger.error(traceback.format_exc())
        return Response('Internal Server Error', status=500)
|
|
|
|
async def _handle_get_callback(self, req) -> tuple[Response, int] | Response:
    """Handle the WeCom GET request used to verify the callback URL.

    Args:
        req: Request whose query string carries ``msg_signature``,
            ``timestamp``, ``nonce`` and ``echostr``.

    Returns:
        A plain-text response with the decrypted echo string on success,
        a 400 response when any parameter is missing or empty, or a 403
        response when ``VerifyURL`` reports a non-zero code.
    """
    msg_signature, timestamp, nonce, echostr = (
        unquote(req.args.get(key, '')) for key in ('msg_signature', 'timestamp', 'nonce', 'echostr')
    )

    if not (msg_signature and timestamp and nonce and echostr):
        await self.logger.error('请求参数缺失')
        return Response('缺少参数', status=400)

    ret, decrypted_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
    if ret != 0:
        await self.logger.error('验证URL失败')
        return Response('验证失败', status=403)

    return Response(decrypted_str, mimetype='text/plain')
|
|
|
|
async def _handle_post_callback(self, req) -> tuple[Response, int] | Response:
    """Handle a WeCom POST callback: decrypt the body and route it.

    Expired stream sessions are cleaned up first. The decrypted message is
    then routed to the feedback handler (``event.eventtype`` equal to
    ``'feedback_event'``), the stream refresh handler (``msgtype`` equal to
    ``'stream'``) or the initial-message handler (everything else).

    Args:
        req: Request carrying the signature parameters in its query string
            and a JSON body with an ``encrypt`` field.

    Returns:
        The chosen handler's response, or a 400 response when the
        ``encrypt`` field is missing or decryption fails.
    """
    self.stream_sessions.cleanup()

    query = req.args
    msg_signature = unquote(query.get('msg_signature', ''))
    timestamp = unquote(query.get('timestamp', ''))
    nonce = unquote(query.get('nonce', ''))

    body = await req.get_json()
    encrypted_msg = (body or {}).get('encrypt', '')
    if not encrypted_msg:
        await self.logger.error("请求体中缺少 'encrypt' 字段")
        return Response('Bad Request', status=400)

    # DecryptMsg is given an XML envelope, so wrap the JSON-delivered ciphertext.
    xml_post_data = f'<xml><Encrypt><![CDATA[{encrypted_msg}]]></Encrypt></xml>'
    ret, decrypted_text = self.wxcpt.DecryptMsg(xml_post_data, msg_signature, timestamp, nonce)
    if ret != 0:
        await self.logger.error('解密失败')
        return Response('解密失败', status=400)

    msg_json = json.loads(decrypted_text)

    if msg_json.get('event', {}).get('eventtype', '') == 'feedback_event':
        return await self._handle_feedback_event(msg_json, nonce)
    if msg_json.get('msgtype') == 'stream':
        return await self._handle_post_followup_response(msg_json, nonce)
    return await self._handle_post_initial_response(msg_json, nonce)
|
|
|
|
async def _handle_feedback_event(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]:
    """Handle a WeCom user feedback event (like / dislike).

    The feedback fields are read from ``msg_json['event']['feedback_event']``,
    the matching stream session is looked up by feedback ID, and every
    handler registered under ``'feedback'`` plus the optional feedback
    callback is awaited. A failure in one receiver is logged and does not
    stop the others.

    Args:
        msg_json: Decrypted WeCom feedback event JSON.
        nonce: The ``nonce`` value from the callback request.

    Returns:
        Tuple[Response, int]: Encrypted empty reply and its status code.

    Note:
        The WeCom protocol currently only supports an empty reply to
        feedback events.
    """
    try:
        details = msg_json.get('event', {}).get('feedback_event', {})
        feedback_id = details.get('id', '')
        feedback_type = details.get('type', 0)
        feedback_content = details.get('content', '')
        inaccurate_reasons = details.get('inaccurate_reason_list', [])

        await self.logger.info(
            f'收到用户反馈事件: feedback_id={feedback_id}, type={feedback_type}, '
            f'content={feedback_content}, reasons={inaccurate_reasons}'
        )

        session = self.stream_sessions.get_session_by_feedback_id(feedback_id)
        if not session:
            await self.logger.warning(f'未找到 feedback_id={feedback_id} 对应的会话,仍将记录反馈')
        else:
            await self.logger.info(
                f'反馈关联到会话: stream_id={session.stream_id}, msg_id={session.msg_id}, user_id={session.user_id}'
            )

        # Same keyword arguments for every receiver; dispatched whether or
        # not a session was found.
        call_kwargs = {
            'feedback_id': feedback_id,
            'feedback_type': feedback_type,
            'feedback_content': feedback_content,
            'inaccurate_reasons': inaccurate_reasons,
            'session': session,
        }

        for handler in self._message_handlers.get('feedback', []):
            try:
                await handler(**call_kwargs)
            except Exception:
                await self.logger.error(traceback.format_exc())

        if self._feedback_callback:
            try:
                await self._feedback_callback(**call_kwargs)
            except Exception:
                await self.logger.error(traceback.format_exc())

    except Exception:
        await self.logger.error(traceback.format_exc())

    return await self._encrypt_and_reply({}, nonce)
|
|
|
|
async def get_message(self, msg_json):
    """Parse a decrypted WeCom message via ``parse_wecom_bot_message``.

    Args:
        msg_json: Decrypted callback message.

    Returns:
        The result of ``parse_wecom_bot_message`` called with the message,
        the instance's AES key and its logger.
    """
    parsed = await parse_wecom_bot_message(msg_json, self.EnCodingAESKey, self.logger)
    return parsed
|
|
|
|
async def _handle_message(self, event: wecombotevent.WecomBotEvent):
    """Deliver an event to the handlers registered for its type, once per message ID.

    ``msg_id_map`` counts how often each message ID has been seen. The first
    occurrence is dispatched; later occurrences only increment the counter
    and are dropped, so a message delivered more than once is handled a
    single time.

    Any exception (including one raised by a handler) is reported through
    ``self.logger`` like everywhere else in this class, instead of being
    printed to stdout, and is not propagated.

    Args:
        event: Internal event object; ``message_id`` and ``type`` are read.
    """
    try:
        message_id = event.message_id
        if message_id in self.msg_id_map:
            # Duplicate delivery of an already-dispatched message.
            self.msg_id_map[message_id] += 1
            return
        self.msg_id_map[message_id] = 1

        msg_type = event.type
        if msg_type in self._message_handlers:
            for handler in self._message_handlers[msg_type]:
                await handler(event)
    except Exception:
        await self.logger.error(traceback.format_exc())
|
|
|
|
async def push_stream_chunk(self, msg_id: str, content: str, is_final: bool = False) -> bool:
    """Publish a fragment to the stream session associated with a message.

    Args:
        msg_id: Original WeCom message ID.
        content: Fragment text to publish.
        is_final: Whether this is the last fragment; if so the stream is
            also marked as finished.

    Returns:
        bool: True when the fragment was published to a stream, False when
        no stream is associated with ``msg_id``.
    """
    sessions = self.stream_sessions

    # No stream mapped to this msg_id means the message is not being streamed.
    stream_id = sessions.get_stream_id_by_msg(msg_id)
    if not stream_id:
        return False

    await sessions.publish(stream_id, StreamChunk(content=content, is_final=is_final))
    if is_final:
        sessions.mark_finished(stream_id)
    return True
|
|
|
|
async def set_message(self, msg_id: str, content: str):
    """Deliver a final reply, caching it when it cannot be streamed.

    Kept for compatibility with the non-streaming flow: the content is first
    offered to ``push_stream_chunk`` as a final fragment; if that reports
    False, it is stored in ``generated_content`` under ``msg_id`` so a later
    refresh request can return it.

    Args:
        msg_id: WeCom message ID.
        content: Final reply text.
    """
    if await self.push_stream_chunk(msg_id, content, is_final=True):
        return
    self.generated_content[msg_id] = content
|
|
|
|
def on_message(self, msg_type: str):
    """Return a decorator that registers a handler for events of ``msg_type``.

    Args:
        msg_type: Event type key under which the handler is stored.

    Returns:
        A decorator that appends the function to the handler list for
        ``msg_type`` (creating the list if needed) and returns the function
        unchanged.
    """

    def decorator(func: Callable[[wecombotevent.WecomBotEvent], None]):
        self._message_handlers.setdefault(msg_type, []).append(func)
        return func

    return decorator
|
|
|
|
def on_feedback(self):
    """Return a decorator that registers a handler for feedback events.

    Returns:
        A decorator that appends the function to the handler list stored
        under the ``'feedback'`` key (creating the list if needed) and
        returns the function unchanged.
    """

    def decorator(func: Callable):
        self._message_handlers.setdefault('feedback', []).append(func)
        return func

    return decorator
|
|
|
|
async def download_url_to_base64(self, download_url, encoding_aes_key):
    """Download an encrypted file and return it converted by ``_bytes_to_data_uri``.

    Args:
        download_url: URL passed to ``download_encrypted_file``.
        encoding_aes_key: AES key passed to ``download_encrypted_file``.

    Returns:
        The value of ``_bytes_to_data_uri`` for the downloaded data, or
        None when the download yields no data.
    """
    # The second element of the result (a file name) is not needed here.
    payload, _ = await download_encrypted_file(download_url, encoding_aes_key, self.logger)
    return _bytes_to_data_uri(payload) if payload else None
|
|
|
|
async def run_task(self, host: str, port: int, *args, **kwargs):
    """Start the Quart application.

    Args:
        host: Interface to bind to.
        port: Port to listen on.
        *args: Extra positional arguments forwarded to ``self.app.run_task``
            after ``host`` and ``port``.
        **kwargs: Extra keyword arguments forwarded to ``self.app.run_task``.
    """
    # host/port are passed positionally: with ``host=..., port=..., *args``
    # any extra positional argument would be bound to ``host`` a second time
    # and raise TypeError.
    await self.app.run_task(host, port, *args, **kwargs)
|