From 53ecd0933ead339d36de18116dfd9f8dcf226179 Mon Sep 17 00:00:00 2001 From: Alfons Date: Tue, 28 Oct 2025 18:12:35 +0800 Subject: [PATCH 1/2] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0=E4=BC=81?= =?UTF-8?q?=E4=B8=9A=E5=BE=AE=E4=BF=A1=E6=99=BA=E8=83=BD=E6=9C=BA=E5=99=A8?= =?UTF-8?q?=E4=BA=BA=E6=B5=81=E5=BC=8F=E5=93=8D=E5=BA=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 重构 WecomBotClient,支持流式会话管理和队列机制 - 新增 StreamSession 和 StreamSessionManager 类管理流式上下文 - 实现 reply_message_chunk 接口支持流式输出 - 优化消息处理流程,支持异步流式响应 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- libs/wecom_ai_bot_api/api.py | 585 +++++++++++++++++++++++-------- pkg/platform/sources/wecombot.py | 44 +++ 2 files changed, 477 insertions(+), 152 deletions(-) diff --git a/libs/wecom_ai_bot_api/api.py b/libs/wecom_ai_bot_api/api.py index 6b5dc573..41d379a6 100644 --- a/libs/wecom_ai_bot_api/api.py +++ b/libs/wecom_ai_bot_api/api.py @@ -1,189 +1,445 @@ +import asyncio +import base64 import json import time +import traceback import uuid import xml.etree.ElementTree as ET +from dataclasses import dataclass, field +from typing import Any, Callable, Optional from urllib.parse import unquote -import hashlib -import traceback import httpx -from libs.wecom_ai_bot_api.WXBizMsgCrypt3 import WXBizMsgCrypt -from quart import Quart, request, Response, jsonify -import langbot_plugin.api.entities.builtin.platform.message as platform_message -import asyncio -from libs.wecom_ai_bot_api import wecombotevent -from typing import Callable -import base64 from Crypto.Cipher import AES +from quart import Quart, request, Response, jsonify + +from libs.wecom_ai_bot_api import wecombotevent +from libs.wecom_ai_bot_api.WXBizMsgCrypt3 import WXBizMsgCrypt from pkg.platform.logger import EventLogger +@dataclass +class StreamChunk: + """描述单次推送给企业微信的流式片段。""" + + # 需要返回给企业微信的文本内容 + content: str + + # 标记是否为最终片段,对应企业微信协议里的 finish 字段 + is_final: bool = False + + # 预留额外元信息,未来支持多模态扩展时可使用 + meta: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class StreamSession: + """维护一次企业微信流式会话的上下文。""" + + # 企业微信要求的 stream_id,用于标识后续刷新请求 + stream_id: str + + # 原始消息的 msgid,便于与流水线消息对应 + msg_id: str + + # 群聊会话标识(单聊时为空) + chat_id: Optional[str] + + # 触发消息的发送者 + user_id: Optional[str] + + # 会话创建时间 + created_at: float = field(default_factory=time.time) + + # 最近一次被访问的时间,cleanup 依据该值判断过期 + last_access: float = field(default_factory=time.time) + + # 将流水线增量结果缓存到队列,刷新请求逐条消费 + queue: asyncio.Queue = field(default_factory=asyncio.Queue) + + # 是否已经完成(收到最终片段) + finished: bool = False + + # 缓存最近一次片段,处理重试或超时兜底 + last_chunk: Optional[StreamChunk] = None + + +class StreamSessionManager: + """管理 stream 会话的生命周期,并负责队列的生产消费。""" + + def __init__(self, logger: EventLogger, ttl: int = 60) -> None: + self.logger = logger + + self.ttl = ttl # 超时时间(秒),超过该时间未被访问的会话会被清理由 cleanup + self._sessions: dict[str, StreamSession] = {} # stream_id -> StreamSession 映射 + self._msg_index: dict[str, str] = {} # msgid -> stream_id 映射,便于流水线根据消息 ID 找到会话 + + def get_stream_id_by_msg(self, msg_id: str) -> Optional[str]: + if not msg_id: + return None + return self._msg_index.get(msg_id) + + def get_session(self, stream_id: str) -> Optional[StreamSession]: + return self._sessions.get(stream_id) + + def create_or_get(self, msg_json: dict[str, Any]) -> tuple[StreamSession, bool]: + """根据企业微信回调创建或获取会话。 + + Args: + msg_json: 企业微信解密后的回调 JSON。 + + Returns: + Tuple[StreamSession, bool]: `StreamSession` 为会话实例,`bool` 指示是否为新建会话。 + + Example: + 在首次回调中调用,得到 `is_new=True` 后再触发流水线。 + """ + msg_id = msg_json.get('msgid', '') + if msg_id and msg_id in self._msg_index: + stream_id = self._msg_index[msg_id] + session = self._sessions.get(stream_id) + if session: + session.last_access = time.time() + return session, False + + stream_id = str(uuid.uuid4()) + session = StreamSession( + stream_id=stream_id, + msg_id=msg_id, + chat_id=msg_json.get('chatid'), + user_id=msg_json.get('from', {}).get('userid'), + ) + + if msg_id: + self._msg_index[msg_id] = stream_id + self._sessions[stream_id] = session + return session, True + + async def publish(self, stream_id: str, chunk: StreamChunk) -> bool: + """向 stream 队列写入新的增量片段。 + + Args: + stream_id: 企业微信分配的流式会话 ID。 + chunk: 待发送的增量片段。 + + Returns: + bool: 当流式队列存在并成功入队时返回 True。 + + Example: + 在收到模型增量后调用 `await manager.publish('sid', StreamChunk('hello'))`。 + """ + session = self._sessions.get(stream_id) + if not session: + return False + + session.last_access = time.time() + session.last_chunk = chunk + + try: + session.queue.put_nowait(chunk) + except asyncio.QueueFull: + # 默认无界队列,此处兜底防御 + await session.queue.put(chunk) + + if chunk.is_final: + session.finished = True + + return True + + async def consume(self, stream_id: str, timeout: float = 0.5) -> Optional[StreamChunk]: + """从队列中取出一个片段,若超时返回 None。 + + Args: + stream_id: 企业微信流式会话 ID。 + timeout: 取片段的最长等待时间(秒)。 + + Returns: + Optional[StreamChunk]: 成功时返回片段,超时或会话不存在时返回 None。 + + Example: + 企业微信刷新到达时调用,若队列有数据则立即返回 `StreamChunk`。 + """ + session = self._sessions.get(stream_id) + if not session: + return None + + session.last_access = time.time() + + try: + chunk = await asyncio.wait_for(session.queue.get(), timeout) + session.last_access = time.time() + if chunk.is_final: + session.finished = True + return chunk + except asyncio.TimeoutError: + if session.finished and session.last_chunk: + return session.last_chunk + return None + + def mark_finished(self, stream_id: str) -> None: + session = self._sessions.get(stream_id) + if session: + session.finished = True + session.last_access = time.time() + + def cleanup(self) -> None: + """定期清理过期会话,防止队列与映射无上限累积。""" + now = time.time() + expired: list[str] = [] + for stream_id, session in self._sessions.items(): + if now - session.last_access > self.ttl: + expired.append(stream_id) + + for stream_id in expired: + session = self._sessions.pop(stream_id, None) + if not session: + continue + msg_id = session.msg_id + if msg_id and self._msg_index.get(msg_id) == stream_id: + self._msg_index.pop(msg_id, None) + class WecomBotClient: - def __init__(self,Token:str,EnCodingAESKey:str,Corpid:str,logger:EventLogger): - self.Token=Token - self.EnCodingAESKey=EnCodingAESKey - self.Corpid=Corpid + def __init__(self, Token: str, EnCodingAESKey: str, Corpid: str, logger: EventLogger): + """企业微信智能机器人客户端。 + + Args: + Token: 企业微信回调验证使用的 token。 + EnCodingAESKey: 企业微信消息加解密密钥。 + Corpid: 企业 ID。 + logger: 日志记录器。 + + Example: + >>> client = WecomBotClient(Token='token', EnCodingAESKey='aeskey', Corpid='corp', logger=logger) + """ + + self.Token = Token + self.EnCodingAESKey = EnCodingAESKey + self.Corpid = Corpid self.ReceiveId = '' self.app = Quart(__name__) self.app.add_url_rule( '/callback/command', 'handle_callback', self.handle_callback_request, - methods=['POST','GET'] + methods=['POST', 'GET'] ) self._message_handlers = { 'example': [], } - self.user_stream_map = {} self.logger = logger - self.generated_content = {} - self.msg_id_map = {} + self.generated_content: dict[str, str] = {} + self.msg_id_map: dict[str, int] = {} + self.stream_sessions = StreamSessionManager(logger=logger) + self.stream_poll_timeout = 0.5 - async def sha1_signature(token: str, timestamp: str, nonce: str, encrypt: str) -> str: - raw = "".join(sorted([token, timestamp, nonce, encrypt])) - return hashlib.sha1(raw.encode("utf-8")).hexdigest() - - async def handle_callback_request(self): + @staticmethod + def _build_stream_payload(stream_id: str, content: str, finish: bool) -> dict[str, Any]: + """按照企业微信协议拼装返回报文。 + + Args: + stream_id: 企业微信会话 ID。 + content: 推送的文本内容。 + finish: 是否为最终片段。 + + Returns: + dict[str, Any]: 可直接加密返回的 payload。 + + Example: + 组装 `{'msgtype': 'stream', 'stream': {'id': 'sid', ...}}` 结构。 + """ + return { + 'msgtype': 'stream', + 'stream': { + 'id': stream_id, + 'finish': finish, + 'content': content, + }, + } + + async def _encrypt_and_reply(self, payload: dict[str, Any], nonce: str) -> tuple[Response, int]: + """对响应进行加密封装并返回给企业微信。 + + Args: + payload: 待加密的响应内容。 + nonce: 企业微信回调参数中的 nonce。 + + Returns: + Tuple[Response, int]: Quart Response 对象及状态码。 + + Example: + 在首包或刷新场景中调用以生成加密响应。 + """ + reply_plain_str = json.dumps(payload, ensure_ascii=False) + reply_timestamp = str(int(time.time())) + ret, encrypt_text = self.wxcpt.EncryptMsg(reply_plain_str, nonce, reply_timestamp) + if ret != 0: + await self.logger.error(f'加密失败: {ret}') + return jsonify({'error': 'encrypt_failed'}), 500 + + root = ET.fromstring(encrypt_text) + encrypt = root.find('Encrypt').text + resp = { + 'encrypt': encrypt, + } + return jsonify(resp), 200 + + async def _dispatch_event(self, event: wecombotevent.WecomBotEvent) -> None: + """异步触发流水线处理,避免阻塞首包响应。 + + Args: + event: 由企业微信消息转换的内部事件对象。 + """ try: - self.wxcpt=WXBizMsgCrypt(self.Token,self.EnCodingAESKey,'') + await self._handle_message(event) + except Exception: + await self.logger.error(traceback.format_exc()) - if request.method == "GET": + async def _handle_initial_message(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]: + """处理企业微信首次推送的消息,返回 stream_id 并开启流水线。 - msg_signature = unquote(request.args.get("msg_signature", "")) - timestamp = unquote(request.args.get("timestamp", "")) - nonce = unquote(request.args.get("nonce", "")) - echostr = unquote(request.args.get("echostr", "")) + Args: + msg_json: 解密后的企业微信消息 JSON。 + nonce: 企业微信回调参数 nonce。 + + Returns: + Tuple[Response, int]: Quart Response 及状态码。 + + Example: + 首次回调时调用,立即返回带 `stream_id` 的响应。 + """ + session, is_new = self.stream_sessions.create_or_get(msg_json) + + message_data = await self.get_message(msg_json) + if message_data: + message_data['stream_id'] = session.stream_id + try: + event = wecombotevent.WecomBotEvent(message_data) + except Exception: + await self.logger.error(traceback.format_exc()) + else: + if is_new: + asyncio.create_task(self._dispatch_event(event)) + + payload = self._build_stream_payload(session.stream_id, '', False) + return await self._encrypt_and_reply(payload, nonce) + + async def _handle_stream_refresh(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]: + """处理企业微信的流式刷新请求,按需返回增量片段。 + + Args: + msg_json: 解密后的企业微信刷新请求。 + nonce: 企业微信回调参数 nonce。 + + Returns: + Tuple[Response, int]: Quart Response 及状态码。 + + Example: + 在刷新请求中调用,按需返回增量片段。 + """ + stream_info = msg_json.get('stream', {}) + stream_id = stream_info.get('id', '') + if not stream_id: + await self.logger.error('刷新请求缺少 stream.id') + return await self._encrypt_and_reply(self._build_stream_payload('', '', True), nonce) + + session = self.stream_sessions.get_session(stream_id) + chunk = await self.stream_sessions.consume(stream_id, timeout=self.stream_poll_timeout) + + if not chunk: + cached_content = None + if session and session.msg_id: + cached_content = self.generated_content.pop(session.msg_id, None) + if cached_content is not None: + chunk = StreamChunk(content=cached_content, is_final=True) + else: + payload = self._build_stream_payload(stream_id, '', False) + return await self._encrypt_and_reply(payload, nonce) + + payload = self._build_stream_payload(stream_id, chunk.content, chunk.is_final) + if chunk.is_final: + self.stream_sessions.mark_finished(stream_id) + return await self._encrypt_and_reply(payload, nonce) + + async def handle_callback_request(self): + """企业微信回调入口。 + + Returns: + Quart Response: 根据请求类型返回验证、首包或刷新结果。 + + Example: + 作为 Quart 路由处理函数直接注册并使用。 + """ + try: + self.wxcpt = WXBizMsgCrypt(self.Token, self.EnCodingAESKey, '') + await self.logger.info(f'{request.method} {request.url} {str(request.args)}') + + if request.method == 'GET': + # GET 用于验证回调 URL,有效期内直接返回微信给的 echostr + msg_signature = unquote(request.args.get('msg_signature', '')) + timestamp = unquote(request.args.get('timestamp', '')) + nonce = unquote(request.args.get('nonce', '')) + echostr = unquote(request.args.get('echostr', '')) if not all([msg_signature, timestamp, nonce, echostr]): - await self.logger.error("请求参数缺失") - return Response("缺少参数", status=400) + await self.logger.error('请求参数缺失') + return Response('缺少参数', status=400) ret, decrypted_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr) if ret != 0: - - await self.logger.error("验证URL失败") - return Response("验证失败", status=403) + await self.logger.error('验证URL失败') + return Response('验证失败', status=403) - return Response(decrypted_str, mimetype="text/plain") + return Response(decrypted_str, mimetype='text/plain') - elif request.method == "POST": - msg_signature = unquote(request.args.get("msg_signature", "")) - timestamp = unquote(request.args.get("timestamp", "")) - nonce = unquote(request.args.get("nonce", "")) + if request.method != 'POST': + return Response('', status=405) - try: - timeout = 3 - interval = 0.1 - start_time = time.monotonic() - encrypted_json = await request.get_json() - encrypted_msg = encrypted_json.get("encrypt", "") - if not encrypted_msg: - await self.logger.error("请求体中缺少 'encrypt' 字段") + self.stream_sessions.cleanup() - xml_post_data = f"" - ret, decrypted_xml = self.wxcpt.DecryptMsg(xml_post_data, msg_signature, timestamp, nonce) - if ret != 0: - await self.logger.error("解密失败") + msg_signature = unquote(request.args.get('msg_signature', '')) + timestamp = unquote(request.args.get('timestamp', '')) + nonce = unquote(request.args.get('nonce', '')) + encrypted_json = await request.get_json() + encrypted_msg = (encrypted_json or {}).get('encrypt', '') + if not encrypted_msg: + await self.logger.error("请求体中缺少 'encrypt' 字段") + return Response('Bad Request', status=400) - msg_json = json.loads(decrypted_xml) - - from_user_id = msg_json.get("from", {}).get("userid") - chatid = msg_json.get("chatid", "") - - message_data = await self.get_message(msg_json) - - + xml_post_data = f"" + ret, decrypted_xml = self.wxcpt.DecryptMsg(xml_post_data, msg_signature, timestamp, nonce) + if ret != 0: + await self.logger.error('解密失败') + return Response('解密失败', status=400) - if message_data: - try: - event = wecombotevent.WecomBotEvent(message_data) - if event: - await self._handle_message(event) - except Exception as e: - await self.logger.error(traceback.format_exc()) - print(traceback.format_exc()) + msg_json = json.loads(decrypted_xml) - start_time = time.time() - try: - if msg_json.get('chattype','') == 'single': - if from_user_id in self.user_stream_map: - stream_id = self.user_stream_map[from_user_id] - else: - stream_id =str(uuid.uuid4()) - self.user_stream_map[from_user_id] = stream_id - + if msg_json.get('msgtype') == 'stream': + # 企业微信刷新请求:尝试从队列中取出增量回复 + return await self._handle_stream_refresh(msg_json, nonce) - else: - - if chatid in self.user_stream_map: - stream_id = self.user_stream_map[chatid] - else: - stream_id = str(uuid.uuid4()) - self.user_stream_map[chatid] = stream_id - except Exception as e: - await self.logger.error(traceback.format_exc()) - print(traceback.format_exc()) - while True: - content = self.generated_content.pop(msg_json['msgid'],None) - if content: - reply_plain = { - "msgtype": "stream", - "stream": { - "id": stream_id, - "finish": True, - "content": content - } - } - reply_plain_str = json.dumps(reply_plain, ensure_ascii=False) + # 首次请求:快速返回 stream_id 并异步处理流水线 + return await self._handle_initial_message(msg_json, nonce) - reply_timestamp = str(int(time.time())) - ret, encrypt_text = self.wxcpt.EncryptMsg(reply_plain_str, nonce, reply_timestamp) - if ret != 0: - - await self.logger.error("加密失败"+str(ret)) - - - root = ET.fromstring(encrypt_text) - encrypt = root.find("Encrypt").text - resp = { - "encrypt": encrypt, - } - return jsonify(resp), 200 - - if time.time() - start_time > timeout: - break - - await asyncio.sleep(interval) - - if self.msg_id_map.get(message_data['msgid'], 1) == 3: - await self.logger.error('请求失效:暂不支持智能机器人超过7秒的请求,如有需求,请联系 LangBot 团队。') - return '' - - except Exception as e: - await self.logger.error(traceback.format_exc()) - print(traceback.format_exc()) - - except Exception as e: + except Exception: await self.logger.error(traceback.format_exc()) - print(traceback.format_exc()) + return Response('Internal Server Error', status=500) - - async def get_message(self,msg_json): + async def get_message(self, msg_json): message_data = {} - if msg_json.get('chattype','') == 'single': + if msg_json.get('chattype', '') == 'single': message_data['type'] = 'single' - elif msg_json.get('chattype','') == 'group': + elif msg_json.get('chattype', '') == 'group': message_data['type'] = 'group' if msg_json.get('msgtype') == 'text': - message_data['content'] = msg_json.get('text',{}).get('content') + message_data['content'] = msg_json.get('text', {}).get('content') elif msg_json.get('msgtype') == 'image': - picurl = msg_json.get('image', {}).get('url','') - base64 = await self.download_url_to_base64(picurl,self.EnCodingAESKey) - message_data['picurl'] = base64 + picurl = msg_json.get('image', {}).get('url', '') + base64 = await self.download_url_to_base64(picurl, self.EnCodingAESKey) + message_data['picurl'] = base64 elif msg_json.get('msgtype') == 'mixed': items = msg_json.get('mixed', {}).get('msg_item', []) texts = [] @@ -197,8 +453,8 @@ class WecomBotClient: if texts: message_data['content'] = "".join(texts) # 拼接所有 text if picurl: - base64 = await self.download_url_to_base64(picurl,self.EnCodingAESKey) - message_data['picurl'] = base64 # 只保留第一个 image + base64 = await self.download_url_to_base64(picurl, self.EnCodingAESKey) + message_data['picurl'] = base64 # 只保留第一个 image message_data['userid'] = msg_json.get('from', {}).get('userid', '') message_data['msgid'] = msg_json.get('msgid', '') @@ -207,7 +463,7 @@ class WecomBotClient: message_data['aibotid'] = msg_json.get('aibotid', '') return message_data - + async def _handle_message(self, event: wecombotevent.WecomBotEvent): """ 处理消息事件。 @@ -223,10 +479,46 @@ class WecomBotClient: for handler in self._message_handlers[msg_type]: await handler(event) except Exception: - print(traceback.format_exc()) + print(traceback.format_exc()) + + async def push_stream_chunk(self, msg_id: str, content: str, is_final: bool = False) -> bool: + """将流水线片段推送到 stream 会话。 + + Args: + msg_id: 原始企业微信消息 ID。 + content: 模型产生的片段内容。 + is_final: 是否为最终片段。 + + Returns: + bool: 当成功写入流式队列时返回 True。 + + Example: + 在流水线 `reply_message_chunk` 中调用,将增量推送至企业微信。 + """ + # 根据 msg_id 找到对应 stream 会话,如果不存在说明当前消息非流式 + stream_id = self.stream_sessions.get_stream_id_by_msg(msg_id) + if not stream_id: + return False + + chunk = StreamChunk(content=content, is_final=is_final) + await self.stream_sessions.publish(stream_id, chunk) + if is_final: + self.stream_sessions.mark_finished(stream_id) + return True async def set_message(self, msg_id: str, content: str): - self.generated_content[msg_id] = content + """兼容旧逻辑:若无法流式返回则缓存最终结果。 + + Args: + msg_id: 企业微信消息 ID。 + content: 最终回复的文本内容。 + + Example: + 在非流式场景下缓存最终结果以备刷新时返回。 + """ + handled = await self.push_stream_chunk(msg_id, content, is_final=True) + if not handled: + self.generated_content[msg_id] = content def on_message(self, msg_type: str): def decorator(func: Callable[[wecombotevent.WecomBotEvent], None]): @@ -237,7 +529,6 @@ class WecomBotClient: return decorator - async def download_url_to_base64(self, download_url, encoding_aes_key): async with httpx.AsyncClient() as client: response = await client.get(download_url) @@ -247,26 +538,22 @@ class WecomBotClient: encrypted_bytes = response.content - aes_key = base64.b64decode(encoding_aes_key + "=") # base64 补齐 iv = aes_key[:16] - cipher = AES.new(aes_key, AES.MODE_CBC, iv) decrypted = cipher.decrypt(encrypted_bytes) - pad_len = decrypted[-1] decrypted = decrypted[:-pad_len] - - if decrypted.startswith(b"\xff\xd8"): # JPEG + if decrypted.startswith(b"\xff\xd8"): # JPEG mime_type = "image/jpeg" elif decrypted.startswith(b"\x89PNG"): # PNG mime_type = "image/png" elif decrypted.startswith((b"GIF87a", b"GIF89a")): # GIF mime_type = "image/gif" - elif decrypted.startswith(b"BM"): # BMP + elif decrypted.startswith(b"BM"): # BMP mime_type = "image/bmp" elif decrypted.startswith(b"II*\x00") or decrypted.startswith(b"MM\x00*"): # TIFF mime_type = "image/tiff" @@ -276,15 +563,9 @@ class WecomBotClient: # 转 base64 base64_str = base64.b64encode(decrypted).decode("utf-8") return f"data:{mime_type};base64,{base64_str}" - async def run_task(self, host: str, port: int, *args, **kwargs): """ 启动 Quart 应用。 """ await self.app.run_task(host=host, port=port, *args, **kwargs) - - - - - diff --git a/pkg/platform/sources/wecombot.py b/pkg/platform/sources/wecombot.py index 9487b637..e5f2d1b5 100644 --- a/pkg/platform/sources/wecombot.py +++ b/pkg/platform/sources/wecombot.py @@ -117,6 +117,50 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter): content = await self.message_converter.yiri2target(message) await self.bot.set_message(message_source.source_platform_object.message_id, content) + async def reply_message_chunk( + self, + message_source: platform_events.MessageEvent, + bot_message, + message: platform_message.MessageChain, + quote_origin: bool = False, + is_final: bool = False, + ): + """将流水线增量输出写入企业微信 stream 会话。 + + Args: + message_source: 流水线提供的原始消息事件。 + bot_message: 当前片段对应的模型元信息(未使用)。 + message: 需要回复的消息链。 + quote_origin: 是否引用原消息(企业微信暂不支持)。 + is_final: 标记当前片段是否为最终回复。 + + Returns: + dict: 包含 `stream` 键,标识写入是否成功。 + + Example: + 在流水线 `reply_message_chunk` 调用中自动触发,无需手动调用。 + """ + # 转换为纯文本(智能机器人当前协议仅支持文本流) + content = await self.message_converter.yiri2target(message) + msg_id = message_source.source_platform_object.message_id + + # 将片段推送到 WecomBotClient 中的队列,返回值用于判断是否走降级逻辑 + success = await self.bot.push_stream_chunk(msg_id, content, is_final=is_final) + if not success and is_final: + # 未命中流式队列时使用旧有 set_message 兜底 + await self.bot.set_message(msg_id, content) + return {'stream': success} + + async def is_stream_output_supported(self) -> bool: + """智能机器人侧默认开启流式能力。 + + Returns: + bool: 恒定返回 True。 + + Example: + 流水线执行阶段会调用此方法以确认是否启用流式。""" + return True + async def send_message(self, target_type, target_id, message): pass From 69767ebdb4fda644e57229e7d3811acc624825ef Mon Sep 17 00:00:00 2001 From: Alfonsxh Date: Tue, 28 Oct 2025 18:30:55 +0800 Subject: [PATCH 2/2] refactor: split WeCom callback handlers --- libs/wecom_ai_bot_api/api.py | 99 +++++++++++++++++++----------------- 1 file changed, 53 insertions(+), 46 deletions(-) diff --git a/libs/wecom_ai_bot_api/api.py b/libs/wecom_ai_bot_api/api.py index 41d379a6..9568eab4 100644 --- a/libs/wecom_ai_bot_api/api.py +++ b/libs/wecom_ai_bot_api/api.py @@ -295,7 +295,7 @@ class WecomBotClient: except Exception: await self.logger.error(traceback.format_exc()) - async def _handle_initial_message(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]: + async def _handle_post_initial_response(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]: """处理企业微信首次推送的消息,返回 stream_id 并开启流水线。 Args: @@ -324,7 +324,7 @@ class WecomBotClient: payload = self._build_stream_payload(session.stream_id, '', False) return await self._encrypt_and_reply(payload, nonce) - async def _handle_stream_refresh(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]: + async def _handle_post_followup_response(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]: """处理企业微信的流式刷新请求,按需返回增量片段。 Args: @@ -375,57 +375,64 @@ class WecomBotClient: await self.logger.info(f'{request.method} {request.url} {str(request.args)}') if request.method == 'GET': - # GET 用于验证回调 URL,有效期内直接返回微信给的 echostr - msg_signature = unquote(request.args.get('msg_signature', '')) - timestamp = unquote(request.args.get('timestamp', '')) - nonce = unquote(request.args.get('nonce', '')) - echostr = unquote(request.args.get('echostr', '')) + return await self._handle_get_callback() - if not all([msg_signature, timestamp, nonce, echostr]): - await self.logger.error('请求参数缺失') - return Response('缺少参数', status=400) + if request.method == 'POST': + return await self._handle_post_callback() - ret, decrypted_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr) - if ret != 0: - await self.logger.error('验证URL失败') - return Response('验证失败', status=403) - - return Response(decrypted_str, mimetype='text/plain') - - if request.method != 'POST': - return Response('', status=405) - - self.stream_sessions.cleanup() - - msg_signature = unquote(request.args.get('msg_signature', '')) - timestamp = unquote(request.args.get('timestamp', '')) - nonce = unquote(request.args.get('nonce', '')) - - encrypted_json = await request.get_json() - encrypted_msg = (encrypted_json or {}).get('encrypt', '') - if not encrypted_msg: - await self.logger.error("请求体中缺少 'encrypt' 字段") - return Response('Bad Request', status=400) - - xml_post_data = f"" - ret, decrypted_xml = self.wxcpt.DecryptMsg(xml_post_data, msg_signature, timestamp, nonce) - if ret != 0: - await self.logger.error('解密失败') - return Response('解密失败', status=400) - - msg_json = json.loads(decrypted_xml) - - if msg_json.get('msgtype') == 'stream': - # 企业微信刷新请求:尝试从队列中取出增量回复 - return await self._handle_stream_refresh(msg_json, nonce) - - # 首次请求:快速返回 stream_id 并异步处理流水线 - return await self._handle_initial_message(msg_json, nonce) + return Response('', status=405) except Exception: await self.logger.error(traceback.format_exc()) return Response('Internal Server Error', status=500) + async def _handle_get_callback(self) -> tuple[Response, int] | Response: + """处理企业微信的 GET 验证请求。""" + + msg_signature = unquote(request.args.get('msg_signature', '')) + timestamp = unquote(request.args.get('timestamp', '')) + nonce = unquote(request.args.get('nonce', '')) + echostr = unquote(request.args.get('echostr', '')) + + if not all([msg_signature, timestamp, nonce, echostr]): + await self.logger.error('请求参数缺失') + return Response('缺少参数', status=400) + + ret, decrypted_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr) + if ret != 0: + await self.logger.error('验证URL失败') + return Response('验证失败', status=403) + + return Response(decrypted_str, mimetype='text/plain') + + async def _handle_post_callback(self) -> tuple[Response, int] | Response: + """处理企业微信的 POST 回调请求。""" + + self.stream_sessions.cleanup() + + msg_signature = unquote(request.args.get('msg_signature', '')) + timestamp = unquote(request.args.get('timestamp', '')) + nonce = unquote(request.args.get('nonce', '')) + + encrypted_json = await request.get_json() + encrypted_msg = (encrypted_json or {}).get('encrypt', '') + if not encrypted_msg: + await self.logger.error("请求体中缺少 'encrypt' 字段") + return Response('Bad Request', status=400) + + xml_post_data = f"" + ret, decrypted_xml = self.wxcpt.DecryptMsg(xml_post_data, msg_signature, timestamp, nonce) + if ret != 0: + await self.logger.error('解密失败') + return Response('解密失败', status=400) + + msg_json = json.loads(decrypted_xml) + + if msg_json.get('msgtype') == 'stream': + return await self._handle_post_followup_response(msg_json, nonce) + + return await self._handle_post_initial_response(msg_json, nonce) + async def get_message(self, msg_json): message_data = {}