diff --git a/libs/wecom_ai_bot_api/api.py b/libs/wecom_ai_bot_api/api.py
index 41d379a6..9568eab4 100644
--- a/libs/wecom_ai_bot_api/api.py
+++ b/libs/wecom_ai_bot_api/api.py
@@ -295,7 +295,7 @@ class WecomBotClient:
except Exception:
await self.logger.error(traceback.format_exc())
- async def _handle_initial_message(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]:
+ async def _handle_post_initial_response(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]:
"""处理企业微信首次推送的消息,返回 stream_id 并开启流水线。
Args:
@@ -324,7 +324,7 @@ class WecomBotClient:
payload = self._build_stream_payload(session.stream_id, '', False)
return await self._encrypt_and_reply(payload, nonce)
- async def _handle_stream_refresh(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]:
+ async def _handle_post_followup_response(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]:
"""处理企业微信的流式刷新请求,按需返回增量片段。
Args:
@@ -375,57 +375,64 @@ class WecomBotClient:
await self.logger.info(f'{request.method} {request.url} {str(request.args)}')
if request.method == 'GET':
- # GET 用于验证回调 URL,有效期内直接返回微信给的 echostr
- msg_signature = unquote(request.args.get('msg_signature', ''))
- timestamp = unquote(request.args.get('timestamp', ''))
- nonce = unquote(request.args.get('nonce', ''))
- echostr = unquote(request.args.get('echostr', ''))
+ return await self._handle_get_callback()
- if not all([msg_signature, timestamp, nonce, echostr]):
- await self.logger.error('请求参数缺失')
- return Response('缺少参数', status=400)
+ if request.method == 'POST':
+ return await self._handle_post_callback()
- ret, decrypted_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
- if ret != 0:
- await self.logger.error('验证URL失败')
- return Response('验证失败', status=403)
-
- return Response(decrypted_str, mimetype='text/plain')
-
- if request.method != 'POST':
- return Response('', status=405)
-
- self.stream_sessions.cleanup()
-
- msg_signature = unquote(request.args.get('msg_signature', ''))
- timestamp = unquote(request.args.get('timestamp', ''))
- nonce = unquote(request.args.get('nonce', ''))
-
- encrypted_json = await request.get_json()
- encrypted_msg = (encrypted_json or {}).get('encrypt', '')
- if not encrypted_msg:
- await self.logger.error("请求体中缺少 'encrypt' 字段")
- return Response('Bad Request', status=400)
-
- xml_post_data = f""
- ret, decrypted_xml = self.wxcpt.DecryptMsg(xml_post_data, msg_signature, timestamp, nonce)
- if ret != 0:
- await self.logger.error('解密失败')
- return Response('解密失败', status=400)
-
- msg_json = json.loads(decrypted_xml)
-
- if msg_json.get('msgtype') == 'stream':
- # 企业微信刷新请求:尝试从队列中取出增量回复
- return await self._handle_stream_refresh(msg_json, nonce)
-
- # 首次请求:快速返回 stream_id 并异步处理流水线
- return await self._handle_initial_message(msg_json, nonce)
+ return Response('', status=405)
except Exception:
await self.logger.error(traceback.format_exc())
return Response('Internal Server Error', status=500)
+ async def _handle_get_callback(self) -> tuple[Response, int] | Response:
+ """处理企业微信的 GET 验证请求。"""
+
+ msg_signature = unquote(request.args.get('msg_signature', ''))
+ timestamp = unquote(request.args.get('timestamp', ''))
+ nonce = unquote(request.args.get('nonce', ''))
+ echostr = unquote(request.args.get('echostr', ''))
+
+ if not all([msg_signature, timestamp, nonce, echostr]):
+ await self.logger.error('请求参数缺失')
+ return Response('缺少参数', status=400)
+
+ ret, decrypted_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
+ if ret != 0:
+ await self.logger.error('验证URL失败')
+ return Response('验证失败', status=403)
+
+ return Response(decrypted_str, mimetype='text/plain')
+
+ async def _handle_post_callback(self) -> tuple[Response, int] | Response:
+ """处理企业微信的 POST 回调请求。"""
+
+ self.stream_sessions.cleanup()
+
+ msg_signature = unquote(request.args.get('msg_signature', ''))
+ timestamp = unquote(request.args.get('timestamp', ''))
+ nonce = unquote(request.args.get('nonce', ''))
+
+ encrypted_json = await request.get_json()
+ encrypted_msg = (encrypted_json or {}).get('encrypt', '')
+ if not encrypted_msg:
+ await self.logger.error("请求体中缺少 'encrypt' 字段")
+ return Response('Bad Request', status=400)
+
+ xml_post_data = f""
+ ret, decrypted_xml = self.wxcpt.DecryptMsg(xml_post_data, msg_signature, timestamp, nonce)
+ if ret != 0:
+ await self.logger.error('解密失败')
+ return Response('解密失败', status=400)
+
+ msg_json = json.loads(decrypted_xml)
+
+ if msg_json.get('msgtype') == 'stream':
+ return await self._handle_post_followup_response(msg_json, nonce)
+
+ return await self._handle_post_initial_response(msg_json, nonce)
+
async def get_message(self, msg_json):
message_data = {}