diff --git a/libs/official_account_api/api.py b/libs/official_account_api/api.py index 0ed4d8b4..af427063 100644 --- a/libs/official_account_api/api.py +++ b/libs/official_account_api/api.py @@ -1,4 +1,5 @@ # 微信公众号的加解密算法与企业微信一样,所以直接使用企业微信的加解密算法文件 +from collections import deque import time import traceback from ..wecom_api.WXBizMsgCrypt3 import WXBizMsgCrypt @@ -12,6 +13,7 @@ import httpx import asyncio import time import xml.etree.ElementTree as ET +from pkg.platform.sources import officialaccount as oa @@ -169,6 +171,139 @@ class OAClient(): await handler(event) + +class OAClientForLongerResponse(): + + def __init__(self,token:str,EncodingAESKey:str,AppID:str,Appsecret:str): + self.token = token + self.aes = EncodingAESKey + self.appid = AppID + self.appsecret = Appsecret + self.base_url = 'https://api.weixin.qq.com' + self.access_token = '' + self.app = Quart(__name__) + self.app.add_url_rule('/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST']) + self._message_handlers = { + "example":[], + } + self.access_token_expiry_time = None + + async def handle_callback_request(self): + try: + start_time = time.time() + signature = request.args.get("signature", "") + timestamp = request.args.get("timestamp", "") + nonce = request.args.get("nonce", "") + echostr = request.args.get("echostr", "") + msg_signature = request.args.get("msg_signature", "") + + if msg_signature is None: + raise Exception("msg_signature不在请求体中") + + if request.method == 'GET': + check_str = "".join(sorted([self.token, timestamp, nonce])) + check_signature = hashlib.sha1(check_str.encode("utf-8")).hexdigest() + return echostr if check_signature == signature else "拒绝请求" + + elif request.method == "POST": + encryt_msg = await request.data + wxcpt = WXBizMsgCrypt(self.token, self.aes, self.appid) + ret, xml_msg = wxcpt.DecryptMsg(encryt_msg, msg_signature, timestamp, nonce) + xml_msg = xml_msg.decode('utf-8') + + if ret != 0: + raise Exception("消息解密失败") + + # 解析 XML + root = ET.fromstring(xml_msg) + from_user = root.find("FromUserName").text + to_user = root.find("ToUserName").text + + + from pkg.platform.sources import officialaccount as oa + + + if oa.msg_queue.get(from_user) and oa.msg_queue[from_user][0]["content"]: + queue_top = oa.msg_queue[from_user].pop(0) + queue_content = queue_top["content"] + + response_xml = xml_template.format( + to_user=from_user, + from_user=to_user, + create_time=int(time.time()), + content=queue_content + ) + return response_xml + + else: + response_xml = xml_template.format( + to_user=from_user, + from_user=to_user, + create_time=int(time.time()), + content="AI正在思考中,请发送任意内容获取回答。" + ) + + message_data = await self.get_message(xml_msg) + if message_data: + event = OAEvent.from_payload(message_data) + if event: + await self._handle_message(event) + + return response_xml + + except Exception as e: + traceback.print_exc() + + + + async def get_message(self, xml_msg: str): + + root = ET.fromstring(xml_msg) + + message_data = { + "ToUserName": root.find("ToUserName").text, + "FromUserName": root.find("FromUserName").text, + "CreateTime": int(root.find("CreateTime").text), + "MsgType": root.find("MsgType").text, + "Content": root.find("Content").text if root.find("Content") is not None else None, + "MsgId": int(root.find("MsgId").text) if root.find("MsgId") is not None else None, + } + + return message_data + + + + async def run_task(self, host: str, port: int, *args, **kwargs): + """ + 启动 Quart 应用。 + """ + await self.app.run_task(host=host, port=port, *args, **kwargs) + + + + def on_message(self, msg_type: str): + """ + 注册消息类型处理器。 + """ + def decorator(func: Callable[[OAEvent], None]): + if msg_type not in self._message_handlers: + self._message_handlers[msg_type] = [] + self._message_handlers[msg_type].append(func) + return func + return decorator + + async def _handle_message(self, event: OAEvent): + """ + 处理消息事件。 + """ + + msg_type = event.type + if msg_type in self._message_handlers: + for handler in self._message_handlers[msg_type]: + await handler(event) + + + diff --git a/pkg/platform/sources/officialaccount.py b/pkg/platform/sources/officialaccount.py index 1e2c5ce6..b308f517 100644 --- a/pkg/platform/sources/officialaccount.py +++ b/pkg/platform/sources/officialaccount.py @@ -7,14 +7,14 @@ import datetime from pkg.core import app from pkg.platform.adapter import MessagePlatformAdapter from pkg.platform.types import events as platform_events, message as platform_message - +from collections import deque from libs.official_account_api.oaevent import OAEvent from pkg.platform.adapter import MessagePlatformAdapter from pkg.platform.types import events as platform_events, message as platform_message from libs.official_account_api.api import OAClient +from libs.official_account_api.api import OAClientForLongerResponse from pkg.core import app from .. import adapter -from ...pipeline.longtext.strategies import forward from ...core import app from ..types import message as platform_message from ..types import events as platform_events @@ -24,6 +24,7 @@ from ...command.errors import ParamNotEnoughError # 生成的ai回答 generated_content = {} +msg_queue = {} class OAMessageConverter(adapter.MessageConverter): @staticmethod @@ -86,17 +87,29 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter): "EncodingAESKey", "AppSecret", "AppID", + "Mode", ] missing_keys = [key for key in required_keys if key not in config] if missing_keys: raise ParamNotEnoughError("微信公众号缺少相关配置项,请查看文档或联系管理员") - self.bot = OAClient( - token=config['token'], - EncodingAESKey=config['EncodingAESKey'], - Appsecret=config['AppSecret'], - AppID=config['AppID'], - ) + # 模式1为15s内回复,模式2为超过15s回复 + + if self.config['Mode'] == 1: + self.bot = OAClient( + token=config['token'], + EncodingAESKey=config['EncodingAESKey'], + Appsecret=config['AppSecret'], + AppID=config['AppID'], + ) + if self.config['Mode'] == 2: + self.bot = OAClientForLongerResponse( + token=config['token'], + EncodingAESKey=config['EncodingAESKey'], + Appsecret=config['AppSecret'], + AppID=config['AppID'], + ) + async def reply_message(self, message_source: platform_events.FriendMessage, message: platform_message.MessageChain, quote_origin: bool = False): global generated_content @@ -107,6 +120,20 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter): generated_content[message_source.message_chain.message_id] = content + from_user = message_source.sender.id + + + if from_user not in msg_queue: + msg_queue[from_user] = [] + + msg_queue[from_user].append( + { + "msg_id":message_source.message_chain.message_id, + "content":content, + } + ) + + async def send_message( self, target_type: str, target_id: str, message: platform_message.MessageChain ): diff --git a/templates/platform.json b/templates/platform.json index a0fb8ce2..051649e4 100644 --- a/templates/platform.json +++ b/templates/platform.json @@ -77,6 +77,7 @@ "EncodingAESKey":"", "AppID":"", "AppSecret":"", + "Mode":1, "host": "0.0.0.0", "port": 2287 },