diff --git a/libs/wecom_ai_bot_api/api.py b/libs/wecom_ai_bot_api/api.py
index 6b5dc573..9568eab4 100644
--- a/libs/wecom_ai_bot_api/api.py
+++ b/libs/wecom_ai_bot_api/api.py
@@ -1,189 +1,452 @@
+import asyncio
+import base64
import json
import time
+import traceback
import uuid
import xml.etree.ElementTree as ET
+from dataclasses import dataclass, field
+from typing import Any, Callable, Optional
from urllib.parse import unquote
-import hashlib
-import traceback
import httpx
-from libs.wecom_ai_bot_api.WXBizMsgCrypt3 import WXBizMsgCrypt
-from quart import Quart, request, Response, jsonify
-import langbot_plugin.api.entities.builtin.platform.message as platform_message
-import asyncio
-from libs.wecom_ai_bot_api import wecombotevent
-from typing import Callable
-import base64
from Crypto.Cipher import AES
+from quart import Quart, request, Response, jsonify
+
+from libs.wecom_ai_bot_api import wecombotevent
+from libs.wecom_ai_bot_api.WXBizMsgCrypt3 import WXBizMsgCrypt
from pkg.platform.logger import EventLogger
+@dataclass
+class StreamChunk:
+ """描述单次推送给企业微信的流式片段。"""
+
+ # 需要返回给企业微信的文本内容
+ content: str
+
+ # 标记是否为最终片段,对应企业微信协议里的 finish 字段
+ is_final: bool = False
+
+ # 预留额外元信息,未来支持多模态扩展时可使用
+ meta: dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class StreamSession:
+ """维护一次企业微信流式会话的上下文。"""
+
+ # 企业微信要求的 stream_id,用于标识后续刷新请求
+ stream_id: str
+
+ # 原始消息的 msgid,便于与流水线消息对应
+ msg_id: str
+
+ # 群聊会话标识(单聊时为空)
+ chat_id: Optional[str]
+
+ # 触发消息的发送者
+ user_id: Optional[str]
+
+ # 会话创建时间
+ created_at: float = field(default_factory=time.time)
+
+ # 最近一次被访问的时间,cleanup 依据该值判断过期
+ last_access: float = field(default_factory=time.time)
+
+ # 将流水线增量结果缓存到队列,刷新请求逐条消费
+ queue: asyncio.Queue = field(default_factory=asyncio.Queue)
+
+ # 是否已经完成(收到最终片段)
+ finished: bool = False
+
+ # 缓存最近一次片段,处理重试或超时兜底
+ last_chunk: Optional[StreamChunk] = None
+
+
+class StreamSessionManager:
+ """管理 stream 会话的生命周期,并负责队列的生产消费。"""
+
+ def __init__(self, logger: EventLogger, ttl: int = 60) -> None:
+ self.logger = logger
+
+ self.ttl = ttl # 超时时间(秒),超过该时间未被访问的会话会被清理由 cleanup
+ self._sessions: dict[str, StreamSession] = {} # stream_id -> StreamSession 映射
+ self._msg_index: dict[str, str] = {} # msgid -> stream_id 映射,便于流水线根据消息 ID 找到会话
+
+ def get_stream_id_by_msg(self, msg_id: str) -> Optional[str]:
+ if not msg_id:
+ return None
+ return self._msg_index.get(msg_id)
+
+ def get_session(self, stream_id: str) -> Optional[StreamSession]:
+ return self._sessions.get(stream_id)
+
+ def create_or_get(self, msg_json: dict[str, Any]) -> tuple[StreamSession, bool]:
+ """根据企业微信回调创建或获取会话。
+
+ Args:
+ msg_json: 企业微信解密后的回调 JSON。
+
+ Returns:
+ Tuple[StreamSession, bool]: `StreamSession` 为会话实例,`bool` 指示是否为新建会话。
+
+ Example:
+ 在首次回调中调用,得到 `is_new=True` 后再触发流水线。
+ """
+ msg_id = msg_json.get('msgid', '')
+ if msg_id and msg_id in self._msg_index:
+ stream_id = self._msg_index[msg_id]
+ session = self._sessions.get(stream_id)
+ if session:
+ session.last_access = time.time()
+ return session, False
+
+ stream_id = str(uuid.uuid4())
+ session = StreamSession(
+ stream_id=stream_id,
+ msg_id=msg_id,
+ chat_id=msg_json.get('chatid'),
+ user_id=msg_json.get('from', {}).get('userid'),
+ )
+
+ if msg_id:
+ self._msg_index[msg_id] = stream_id
+ self._sessions[stream_id] = session
+ return session, True
+
+ async def publish(self, stream_id: str, chunk: StreamChunk) -> bool:
+ """向 stream 队列写入新的增量片段。
+
+ Args:
+ stream_id: 企业微信分配的流式会话 ID。
+ chunk: 待发送的增量片段。
+
+ Returns:
+ bool: 当流式队列存在并成功入队时返回 True。
+
+ Example:
+ 在收到模型增量后调用 `await manager.publish('sid', StreamChunk('hello'))`。
+ """
+ session = self._sessions.get(stream_id)
+ if not session:
+ return False
+
+ session.last_access = time.time()
+ session.last_chunk = chunk
+
+ try:
+ session.queue.put_nowait(chunk)
+ except asyncio.QueueFull:
+ # 默认无界队列,此处兜底防御
+ await session.queue.put(chunk)
+
+ if chunk.is_final:
+ session.finished = True
+
+ return True
+
+ async def consume(self, stream_id: str, timeout: float = 0.5) -> Optional[StreamChunk]:
+ """从队列中取出一个片段,若超时返回 None。
+
+ Args:
+ stream_id: 企业微信流式会话 ID。
+ timeout: 取片段的最长等待时间(秒)。
+
+ Returns:
+ Optional[StreamChunk]: 成功时返回片段,超时或会话不存在时返回 None。
+
+ Example:
+ 企业微信刷新到达时调用,若队列有数据则立即返回 `StreamChunk`。
+ """
+ session = self._sessions.get(stream_id)
+ if not session:
+ return None
+
+ session.last_access = time.time()
+
+ try:
+ chunk = await asyncio.wait_for(session.queue.get(), timeout)
+ session.last_access = time.time()
+ if chunk.is_final:
+ session.finished = True
+ return chunk
+ except asyncio.TimeoutError:
+ if session.finished and session.last_chunk:
+ return session.last_chunk
+ return None
+
+ def mark_finished(self, stream_id: str) -> None:
+ session = self._sessions.get(stream_id)
+ if session:
+ session.finished = True
+ session.last_access = time.time()
+
+ def cleanup(self) -> None:
+ """定期清理过期会话,防止队列与映射无上限累积。"""
+ now = time.time()
+ expired: list[str] = []
+ for stream_id, session in self._sessions.items():
+ if now - session.last_access > self.ttl:
+ expired.append(stream_id)
+
+ for stream_id in expired:
+ session = self._sessions.pop(stream_id, None)
+ if not session:
+ continue
+ msg_id = session.msg_id
+ if msg_id and self._msg_index.get(msg_id) == stream_id:
+ self._msg_index.pop(msg_id, None)
+
class WecomBotClient:
- def __init__(self,Token:str,EnCodingAESKey:str,Corpid:str,logger:EventLogger):
- self.Token=Token
- self.EnCodingAESKey=EnCodingAESKey
- self.Corpid=Corpid
+ def __init__(self, Token: str, EnCodingAESKey: str, Corpid: str, logger: EventLogger):
+ """企业微信智能机器人客户端。
+
+ Args:
+ Token: 企业微信回调验证使用的 token。
+ EnCodingAESKey: 企业微信消息加解密密钥。
+ Corpid: 企业 ID。
+ logger: 日志记录器。
+
+ Example:
+ >>> client = WecomBotClient(Token='token', EnCodingAESKey='aeskey', Corpid='corp', logger=logger)
+ """
+
+ self.Token = Token
+ self.EnCodingAESKey = EnCodingAESKey
+ self.Corpid = Corpid
self.ReceiveId = ''
self.app = Quart(__name__)
self.app.add_url_rule(
'/callback/command',
'handle_callback',
self.handle_callback_request,
- methods=['POST','GET']
+ methods=['POST', 'GET']
)
self._message_handlers = {
'example': [],
}
- self.user_stream_map = {}
self.logger = logger
- self.generated_content = {}
- self.msg_id_map = {}
+ self.generated_content: dict[str, str] = {}
+ self.msg_id_map: dict[str, int] = {}
+ self.stream_sessions = StreamSessionManager(logger=logger)
+ self.stream_poll_timeout = 0.5
- async def sha1_signature(token: str, timestamp: str, nonce: str, encrypt: str) -> str:
- raw = "".join(sorted([token, timestamp, nonce, encrypt]))
- return hashlib.sha1(raw.encode("utf-8")).hexdigest()
-
- async def handle_callback_request(self):
+ @staticmethod
+ def _build_stream_payload(stream_id: str, content: str, finish: bool) -> dict[str, Any]:
+ """按照企业微信协议拼装返回报文。
+
+ Args:
+ stream_id: 企业微信会话 ID。
+ content: 推送的文本内容。
+ finish: 是否为最终片段。
+
+ Returns:
+ dict[str, Any]: 可直接加密返回的 payload。
+
+ Example:
+ 组装 `{'msgtype': 'stream', 'stream': {'id': 'sid', ...}}` 结构。
+ """
+ return {
+ 'msgtype': 'stream',
+ 'stream': {
+ 'id': stream_id,
+ 'finish': finish,
+ 'content': content,
+ },
+ }
+
+ async def _encrypt_and_reply(self, payload: dict[str, Any], nonce: str) -> tuple[Response, int]:
+ """对响应进行加密封装并返回给企业微信。
+
+ Args:
+ payload: 待加密的响应内容。
+ nonce: 企业微信回调参数中的 nonce。
+
+ Returns:
+ Tuple[Response, int]: Quart Response 对象及状态码。
+
+ Example:
+ 在首包或刷新场景中调用以生成加密响应。
+ """
+ reply_plain_str = json.dumps(payload, ensure_ascii=False)
+ reply_timestamp = str(int(time.time()))
+ ret, encrypt_text = self.wxcpt.EncryptMsg(reply_plain_str, nonce, reply_timestamp)
+ if ret != 0:
+ await self.logger.error(f'加密失败: {ret}')
+ return jsonify({'error': 'encrypt_failed'}), 500
+
+ root = ET.fromstring(encrypt_text)
+ encrypt = root.find('Encrypt').text
+ resp = {
+ 'encrypt': encrypt,
+ }
+ return jsonify(resp), 200
+
+ async def _dispatch_event(self, event: wecombotevent.WecomBotEvent) -> None:
+ """异步触发流水线处理,避免阻塞首包响应。
+
+ Args:
+ event: 由企业微信消息转换的内部事件对象。
+ """
try:
- self.wxcpt=WXBizMsgCrypt(self.Token,self.EnCodingAESKey,'')
-
- if request.method == "GET":
-
- msg_signature = unquote(request.args.get("msg_signature", ""))
- timestamp = unquote(request.args.get("timestamp", ""))
- nonce = unquote(request.args.get("nonce", ""))
- echostr = unquote(request.args.get("echostr", ""))
-
- if not all([msg_signature, timestamp, nonce, echostr]):
- await self.logger.error("请求参数缺失")
- return Response("缺少参数", status=400)
-
- ret, decrypted_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
- if ret != 0:
-
- await self.logger.error("验证URL失败")
- return Response("验证失败", status=403)
-
- return Response(decrypted_str, mimetype="text/plain")
-
- elif request.method == "POST":
- msg_signature = unquote(request.args.get("msg_signature", ""))
- timestamp = unquote(request.args.get("timestamp", ""))
- nonce = unquote(request.args.get("nonce", ""))
-
- try:
- timeout = 3
- interval = 0.1
- start_time = time.monotonic()
- encrypted_json = await request.get_json()
- encrypted_msg = encrypted_json.get("encrypt", "")
- if not encrypted_msg:
- await self.logger.error("请求体中缺少 'encrypt' 字段")
-
- xml_post_data = f""
- ret, decrypted_xml = self.wxcpt.DecryptMsg(xml_post_data, msg_signature, timestamp, nonce)
- if ret != 0:
- await self.logger.error("解密失败")
-
-
- msg_json = json.loads(decrypted_xml)
-
- from_user_id = msg_json.get("from", {}).get("userid")
- chatid = msg_json.get("chatid", "")
-
- message_data = await self.get_message(msg_json)
-
-
-
- if message_data:
- try:
- event = wecombotevent.WecomBotEvent(message_data)
- if event:
- await self._handle_message(event)
- except Exception as e:
- await self.logger.error(traceback.format_exc())
- print(traceback.format_exc())
-
- start_time = time.time()
- try:
- if msg_json.get('chattype','') == 'single':
- if from_user_id in self.user_stream_map:
- stream_id = self.user_stream_map[from_user_id]
- else:
- stream_id =str(uuid.uuid4())
- self.user_stream_map[from_user_id] = stream_id
-
-
- else:
-
- if chatid in self.user_stream_map:
- stream_id = self.user_stream_map[chatid]
- else:
- stream_id = str(uuid.uuid4())
- self.user_stream_map[chatid] = stream_id
- except Exception as e:
- await self.logger.error(traceback.format_exc())
- print(traceback.format_exc())
- while True:
- content = self.generated_content.pop(msg_json['msgid'],None)
- if content:
- reply_plain = {
- "msgtype": "stream",
- "stream": {
- "id": stream_id,
- "finish": True,
- "content": content
- }
- }
- reply_plain_str = json.dumps(reply_plain, ensure_ascii=False)
-
- reply_timestamp = str(int(time.time()))
- ret, encrypt_text = self.wxcpt.EncryptMsg(reply_plain_str, nonce, reply_timestamp)
- if ret != 0:
-
- await self.logger.error("加密失败"+str(ret))
-
-
- root = ET.fromstring(encrypt_text)
- encrypt = root.find("Encrypt").text
- resp = {
- "encrypt": encrypt,
- }
- return jsonify(resp), 200
-
- if time.time() - start_time > timeout:
- break
-
- await asyncio.sleep(interval)
-
- if self.msg_id_map.get(message_data['msgid'], 1) == 3:
- await self.logger.error('请求失效:暂不支持智能机器人超过7秒的请求,如有需求,请联系 LangBot 团队。')
- return ''
-
- except Exception as e:
- await self.logger.error(traceback.format_exc())
- print(traceback.format_exc())
-
- except Exception as e:
+ await self._handle_message(event)
+ except Exception:
await self.logger.error(traceback.format_exc())
- print(traceback.format_exc())
-
- async def get_message(self,msg_json):
+ async def _handle_post_initial_response(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]:
+ """处理企业微信首次推送的消息,返回 stream_id 并开启流水线。
+
+ Args:
+ msg_json: 解密后的企业微信消息 JSON。
+ nonce: 企业微信回调参数 nonce。
+
+ Returns:
+ Tuple[Response, int]: Quart Response 及状态码。
+
+ Example:
+ 首次回调时调用,立即返回带 `stream_id` 的响应。
+ """
+ session, is_new = self.stream_sessions.create_or_get(msg_json)
+
+ message_data = await self.get_message(msg_json)
+ if message_data:
+ message_data['stream_id'] = session.stream_id
+ try:
+ event = wecombotevent.WecomBotEvent(message_data)
+ except Exception:
+ await self.logger.error(traceback.format_exc())
+ else:
+ if is_new:
+ asyncio.create_task(self._dispatch_event(event))
+
+ payload = self._build_stream_payload(session.stream_id, '', False)
+ return await self._encrypt_and_reply(payload, nonce)
+
+ async def _handle_post_followup_response(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]:
+ """处理企业微信的流式刷新请求,按需返回增量片段。
+
+ Args:
+ msg_json: 解密后的企业微信刷新请求。
+ nonce: 企业微信回调参数 nonce。
+
+ Returns:
+ Tuple[Response, int]: Quart Response 及状态码。
+
+ Example:
+ 在刷新请求中调用,按需返回增量片段。
+ """
+ stream_info = msg_json.get('stream', {})
+ stream_id = stream_info.get('id', '')
+ if not stream_id:
+ await self.logger.error('刷新请求缺少 stream.id')
+ return await self._encrypt_and_reply(self._build_stream_payload('', '', True), nonce)
+
+ session = self.stream_sessions.get_session(stream_id)
+ chunk = await self.stream_sessions.consume(stream_id, timeout=self.stream_poll_timeout)
+
+ if not chunk:
+ cached_content = None
+ if session and session.msg_id:
+ cached_content = self.generated_content.pop(session.msg_id, None)
+ if cached_content is not None:
+ chunk = StreamChunk(content=cached_content, is_final=True)
+ else:
+ payload = self._build_stream_payload(stream_id, '', False)
+ return await self._encrypt_and_reply(payload, nonce)
+
+ payload = self._build_stream_payload(stream_id, chunk.content, chunk.is_final)
+ if chunk.is_final:
+ self.stream_sessions.mark_finished(stream_id)
+ return await self._encrypt_and_reply(payload, nonce)
+
+ async def handle_callback_request(self):
+ """企业微信回调入口。
+
+ Returns:
+ Quart Response: 根据请求类型返回验证、首包或刷新结果。
+
+ Example:
+ 作为 Quart 路由处理函数直接注册并使用。
+ """
+ try:
+ self.wxcpt = WXBizMsgCrypt(self.Token, self.EnCodingAESKey, '')
+ await self.logger.info(f'{request.method} {request.url} {str(request.args)}')
+
+ if request.method == 'GET':
+ return await self._handle_get_callback()
+
+ if request.method == 'POST':
+ return await self._handle_post_callback()
+
+ return Response('', status=405)
+
+ except Exception:
+ await self.logger.error(traceback.format_exc())
+ return Response('Internal Server Error', status=500)
+
+ async def _handle_get_callback(self) -> tuple[Response, int] | Response:
+ """处理企业微信的 GET 验证请求。"""
+
+ msg_signature = unquote(request.args.get('msg_signature', ''))
+ timestamp = unquote(request.args.get('timestamp', ''))
+ nonce = unquote(request.args.get('nonce', ''))
+ echostr = unquote(request.args.get('echostr', ''))
+
+ if not all([msg_signature, timestamp, nonce, echostr]):
+ await self.logger.error('请求参数缺失')
+ return Response('缺少参数', status=400)
+
+ ret, decrypted_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
+ if ret != 0:
+ await self.logger.error('验证URL失败')
+ return Response('验证失败', status=403)
+
+ return Response(decrypted_str, mimetype='text/plain')
+
+ async def _handle_post_callback(self) -> tuple[Response, int] | Response:
+ """处理企业微信的 POST 回调请求。"""
+
+ self.stream_sessions.cleanup()
+
+ msg_signature = unquote(request.args.get('msg_signature', ''))
+ timestamp = unquote(request.args.get('timestamp', ''))
+ nonce = unquote(request.args.get('nonce', ''))
+
+ encrypted_json = await request.get_json()
+ encrypted_msg = (encrypted_json or {}).get('encrypt', '')
+ if not encrypted_msg:
+ await self.logger.error("请求体中缺少 'encrypt' 字段")
+ return Response('Bad Request', status=400)
+
+ xml_post_data = f""
+ ret, decrypted_xml = self.wxcpt.DecryptMsg(xml_post_data, msg_signature, timestamp, nonce)
+ if ret != 0:
+ await self.logger.error('解密失败')
+ return Response('解密失败', status=400)
+
+ msg_json = json.loads(decrypted_xml)
+
+ if msg_json.get('msgtype') == 'stream':
+ return await self._handle_post_followup_response(msg_json, nonce)
+
+ return await self._handle_post_initial_response(msg_json, nonce)
+
+ async def get_message(self, msg_json):
message_data = {}
- if msg_json.get('chattype','') == 'single':
+ if msg_json.get('chattype', '') == 'single':
message_data['type'] = 'single'
- elif msg_json.get('chattype','') == 'group':
+ elif msg_json.get('chattype', '') == 'group':
message_data['type'] = 'group'
if msg_json.get('msgtype') == 'text':
- message_data['content'] = msg_json.get('text',{}).get('content')
+ message_data['content'] = msg_json.get('text', {}).get('content')
elif msg_json.get('msgtype') == 'image':
- picurl = msg_json.get('image', {}).get('url','')
- base64 = await self.download_url_to_base64(picurl,self.EnCodingAESKey)
- message_data['picurl'] = base64
+ picurl = msg_json.get('image', {}).get('url', '')
+ base64 = await self.download_url_to_base64(picurl, self.EnCodingAESKey)
+ message_data['picurl'] = base64
elif msg_json.get('msgtype') == 'mixed':
items = msg_json.get('mixed', {}).get('msg_item', [])
texts = []
@@ -197,8 +460,8 @@ class WecomBotClient:
if texts:
message_data['content'] = "".join(texts) # 拼接所有 text
if picurl:
- base64 = await self.download_url_to_base64(picurl,self.EnCodingAESKey)
- message_data['picurl'] = base64 # 只保留第一个 image
+ base64 = await self.download_url_to_base64(picurl, self.EnCodingAESKey)
+ message_data['picurl'] = base64 # 只保留第一个 image
message_data['userid'] = msg_json.get('from', {}).get('userid', '')
message_data['msgid'] = msg_json.get('msgid', '')
@@ -207,7 +470,7 @@ class WecomBotClient:
message_data['aibotid'] = msg_json.get('aibotid', '')
return message_data
-
+
async def _handle_message(self, event: wecombotevent.WecomBotEvent):
"""
处理消息事件。
@@ -223,10 +486,46 @@ class WecomBotClient:
for handler in self._message_handlers[msg_type]:
await handler(event)
except Exception:
- print(traceback.format_exc())
+ print(traceback.format_exc())
+
+ async def push_stream_chunk(self, msg_id: str, content: str, is_final: bool = False) -> bool:
+ """将流水线片段推送到 stream 会话。
+
+ Args:
+ msg_id: 原始企业微信消息 ID。
+ content: 模型产生的片段内容。
+ is_final: 是否为最终片段。
+
+ Returns:
+ bool: 当成功写入流式队列时返回 True。
+
+ Example:
+ 在流水线 `reply_message_chunk` 中调用,将增量推送至企业微信。
+ """
+ # 根据 msg_id 找到对应 stream 会话,如果不存在说明当前消息非流式
+ stream_id = self.stream_sessions.get_stream_id_by_msg(msg_id)
+ if not stream_id:
+ return False
+
+ chunk = StreamChunk(content=content, is_final=is_final)
+ await self.stream_sessions.publish(stream_id, chunk)
+ if is_final:
+ self.stream_sessions.mark_finished(stream_id)
+ return True
async def set_message(self, msg_id: str, content: str):
- self.generated_content[msg_id] = content
+ """兼容旧逻辑:若无法流式返回则缓存最终结果。
+
+ Args:
+ msg_id: 企业微信消息 ID。
+ content: 最终回复的文本内容。
+
+ Example:
+ 在非流式场景下缓存最终结果以备刷新时返回。
+ """
+ handled = await self.push_stream_chunk(msg_id, content, is_final=True)
+ if not handled:
+ self.generated_content[msg_id] = content
def on_message(self, msg_type: str):
def decorator(func: Callable[[wecombotevent.WecomBotEvent], None]):
@@ -237,7 +536,6 @@ class WecomBotClient:
return decorator
-
async def download_url_to_base64(self, download_url, encoding_aes_key):
async with httpx.AsyncClient() as client:
response = await client.get(download_url)
@@ -247,26 +545,22 @@ class WecomBotClient:
encrypted_bytes = response.content
-
aes_key = base64.b64decode(encoding_aes_key + "=") # base64 补齐
iv = aes_key[:16]
-
cipher = AES.new(aes_key, AES.MODE_CBC, iv)
decrypted = cipher.decrypt(encrypted_bytes)
-
pad_len = decrypted[-1]
decrypted = decrypted[:-pad_len]
-
- if decrypted.startswith(b"\xff\xd8"): # JPEG
+ if decrypted.startswith(b"\xff\xd8"): # JPEG
mime_type = "image/jpeg"
elif decrypted.startswith(b"\x89PNG"): # PNG
mime_type = "image/png"
elif decrypted.startswith((b"GIF87a", b"GIF89a")): # GIF
mime_type = "image/gif"
- elif decrypted.startswith(b"BM"): # BMP
+ elif decrypted.startswith(b"BM"): # BMP
mime_type = "image/bmp"
elif decrypted.startswith(b"II*\x00") or decrypted.startswith(b"MM\x00*"): # TIFF
mime_type = "image/tiff"
@@ -276,15 +570,9 @@ class WecomBotClient:
# 转 base64
base64_str = base64.b64encode(decrypted).decode("utf-8")
return f"data:{mime_type};base64,{base64_str}"
-
async def run_task(self, host: str, port: int, *args, **kwargs):
"""
启动 Quart 应用。
"""
await self.app.run_task(host=host, port=port, *args, **kwargs)
-
-
-
-
-
diff --git a/pkg/platform/sources/wecombot.py b/pkg/platform/sources/wecombot.py
index 9487b637..e5f2d1b5 100644
--- a/pkg/platform/sources/wecombot.py
+++ b/pkg/platform/sources/wecombot.py
@@ -117,6 +117,50 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
content = await self.message_converter.yiri2target(message)
await self.bot.set_message(message_source.source_platform_object.message_id, content)
+ async def reply_message_chunk(
+ self,
+ message_source: platform_events.MessageEvent,
+ bot_message,
+ message: platform_message.MessageChain,
+ quote_origin: bool = False,
+ is_final: bool = False,
+ ):
+ """将流水线增量输出写入企业微信 stream 会话。
+
+ Args:
+ message_source: 流水线提供的原始消息事件。
+ bot_message: 当前片段对应的模型元信息(未使用)。
+ message: 需要回复的消息链。
+ quote_origin: 是否引用原消息(企业微信暂不支持)。
+ is_final: 标记当前片段是否为最终回复。
+
+ Returns:
+ dict: 包含 `stream` 键,标识写入是否成功。
+
+ Example:
+ 在流水线 `reply_message_chunk` 调用中自动触发,无需手动调用。
+ """
+ # 转换为纯文本(智能机器人当前协议仅支持文本流)
+ content = await self.message_converter.yiri2target(message)
+ msg_id = message_source.source_platform_object.message_id
+
+ # 将片段推送到 WecomBotClient 中的队列,返回值用于判断是否走降级逻辑
+ success = await self.bot.push_stream_chunk(msg_id, content, is_final=is_final)
+ if not success and is_final:
+ # 未命中流式队列时使用旧有 set_message 兜底
+ await self.bot.set_message(msg_id, content)
+ return {'stream': success}
+
+ async def is_stream_output_supported(self) -> bool:
+ """智能机器人侧默认开启流式能力。
+
+ Returns:
+ bool: 恒定返回 True。
+
+ Example:
+ 流水线执行阶段会调用此方法以确认是否启用流式。"""
+ return True
+
async def send_message(self, target_type, target_id, message):
pass