From 05c1fdaa9eb1f0f9d62c53285d75555385839d06 Mon Sep 17 00:00:00 2001 From: wangcham Date: Mon, 10 Feb 2025 06:08:59 -0500 Subject: [PATCH] =?UTF-8?q?feat:=20add=20adapter=20for=20=E5=BE=AE?= =?UTF-8?q?=E4=BF=A1=E5=85=AC=E4=BC=97=E5=8F=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- libs/official_account_api/__init__.py | 0 libs/official_account_api/api.py | 175 ++++++++++++++++++++++++ libs/official_account_api/oaevent.py | 167 ++++++++++++++++++++++ libs/wecom_api/api.py | 3 +- pkg/platform/manager.py | 2 +- pkg/platform/sources/officialaccount.py | 155 +++++++++++++++++++++ pkg/platform/sources/qqofficial.py | 1 - pkg/platform/sources/wecom.py | 1 - 8 files changed, 500 insertions(+), 4 deletions(-) create mode 100644 libs/official_account_api/__init__.py create mode 100644 libs/official_account_api/api.py create mode 100644 libs/official_account_api/oaevent.py create mode 100644 pkg/platform/sources/officialaccount.py diff --git a/libs/official_account_api/__init__.py b/libs/official_account_api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/official_account_api/api.py b/libs/official_account_api/api.py new file mode 100644 index 00000000..92ba47e5 --- /dev/null +++ b/libs/official_account_api/api.py @@ -0,0 +1,175 @@ +# 微信公众号的加解密算法与企业微信一样,所以直接使用企业微信的加解密算法文件 +import time +import traceback +from ..wecom_api.WXBizMsgCrypt3 import WXBizMsgCrypt +import xml.etree.ElementTree as ET +from quart import Quart,request +import hashlib +from typing import Callable, Dict, Any +from .oaevent import OAEvent +import httpx + +import asyncio +import time +import xml.etree.ElementTree as ET + + + +xml_template = """ + + + + {create_time} + + + +""" + +class OAClient(): + + def __init__(self,token:str,EncodingAESKey:str,AppID:str,Appsecret:str): + self.token = token + self.aes = EncodingAESKey + self.appid = AppID + self.appsecret = Appsecret + self.base_url = 'https://api.weixin.qq.com' + self.access_token = '' + self.app = Quart(__name__) + self.app.add_url_rule('/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST']) + self._message_handlers = { + "example":[], + } + self.access_token_expiry_time = None + self.msg_id_map = {} + + async def handle_callback_request(self): + + try: + # 每隔100毫秒查询是否生成ai回答 + start_time = time.time() + signature = request.args.get("signature", "") + timestamp = request.args.get("timestamp", "") + nonce = request.args.get("nonce", "") + echostr = request.args.get("echostr", "") + msg_signature = request.args.get("msg_signature","") + if msg_signature is None: + raise Exception("msg_signature不在请求体中") + + if request.method == 'GET': + # 校验签名 + check_str = "".join(sorted([self.token, timestamp, nonce])) + check_signature = hashlib.sha1(check_str.encode("utf-8")).hexdigest() + + if check_signature == signature: + return echostr # 验证成功返回echostr + else: + raise Exception("拒绝请求") + elif request.method == "POST": + encryt_msg = await request.data + wxcpt = WXBizMsgCrypt(self.token,self.aes,self.appid) + ret,xml_msg = wxcpt.DecryptMsg(encryt_msg,msg_signature,timestamp,nonce) + xml_msg = xml_msg.decode('utf-8') + + if ret != 0: + raise Exception("消息解密失败") + + message_data = await self.get_message(xml_msg) + if message_data : + event = OAEvent.from_payload(message_data) + if event: + await self._handle_message(event) + + root = ET.fromstring(xml_msg) + from_user = root.find("FromUserName").text # 发送者 + to_user = root.find("ToUserName").text # 机器人 + + from pkg.platform.sources import officialaccount + + timeout = 4.80 + interval = 0.1 + while True: + content = officialaccount.generated_content.pop(message_data["MsgId"], None) + if content: + response_xml = xml_template.format( + to_user=from_user, + from_user=to_user, + create_time=int(time.time()), + content = content + ) + + return response_xml + + if time.time() - start_time >= timeout: + break + + await asyncio.sleep(interval) + + if self.msg_id_map.get(message_data["MsgId"], 1) == 3: + + response_xml = xml_template.format( + to_user=from_user, + from_user=to_user, + create_time=int(time.time()), + content = "请求失效:暂不支持公众号超过15秒的请求,如有需求,请联系 LangBot 团队。" + ) + return response_xml + + except Exception as e: + traceback.print_exc() + + + async def get_message(self, xml_msg: str): + + root = ET.fromstring(xml_msg) + + message_data = { + "ToUserName": root.find("ToUserName").text, + "FromUserName": root.find("FromUserName").text, + "CreateTime": int(root.find("CreateTime").text), + "MsgType": root.find("MsgType").text, + "Content": root.find("Content").text if root.find("Content") is not None else None, + "MsgId": int(root.find("MsgId").text) if root.find("MsgId") is not None else None, + } + + return message_data + + + async def run_task(self, host: str, port: int, *args, **kwargs): + """ + 启动 Quart 应用。 + """ + await self.app.run_task(host=host, port=port, *args, **kwargs) + + + def on_message(self, msg_type: str): + """ + 注册消息类型处理器。 + """ + def decorator(func: Callable[[OAEvent], None]): + if msg_type not in self._message_handlers: + self._message_handlers[msg_type] = [] + self._message_handlers[msg_type].append(func) + return func + return decorator + + async def _handle_message(self, event: OAEvent): + """ + 处理消息事件。 + """ + message_id = event.message_id + if message_id in self.msg_id_map.keys(): + self.msg_id_map[message_id] += 1 + return + + self.msg_id_map[message_id] = 1 + msg_type = event.type + if msg_type in self._message_handlers: + for handler in self._message_handlers[msg_type]: + await handler(event) + + + + + + + diff --git a/libs/official_account_api/oaevent.py b/libs/official_account_api/oaevent.py new file mode 100644 index 00000000..ebbccd7e --- /dev/null +++ b/libs/official_account_api/oaevent.py @@ -0,0 +1,167 @@ +from typing import Dict, Any, Optional + + +class OAEvent(dict): + """ + 封装从微信公众号收到的事件数据对象(字典),提供属性以获取其中的字段。 + + 除 `type` 和 `detail_type` 属性对于任何事件都有效外,其它属性是否存在(若不存在则返回 `None`)依事件类型不同而不同。 + """ + + @staticmethod + def from_payload(payload: Dict[str, Any]) -> Optional["OAEvent"]: + """ + 从微信公众号事件数据构造 `WecomEvent` 对象。 + + Args: + payload (Dict[str, Any]): 解密后的微信事件数据。 + + Returns: + Optional[OAEvent]: 如果事件数据合法,则返回 OAEvent 对象;否则返回 None。 + """ + try: + event = OAEvent(payload) + _ = event.type, event.detail_type # 确保必须字段存在 + return event + except KeyError: + return None + + @property + def type(self) -> str: + """ + 事件类型,例如 "message"、"event"、"text" 等。 + + Returns: + str: 事件类型。 + """ + return self.get("MsgType", "") + + @property + def picurl(self) -> str: + """ + 图片链接 + """ + return self.get("PicUrl","") + + @property + def detail_type(self) -> str: + """ + 事件详细类型,依 `type` 的不同而不同。例如: + - 消息事件: "text", "image", "voice", 等 + - 事件通知: "subscribe", "unsubscribe", "click", 等 + + Returns: + str: 事件详细类型。 + """ + if self.type == "event": + return self.get("Event", "") + return self.type + + @property + def name(self) -> str: + """ + 事件名,对于消息事件是 `type.detail_type`,对于其他事件是 `event_type`。 + + Returns: + str: 事件名。 + """ + return f"{self.type}.{self.detail_type}" + + @property + def user_id(self) -> Optional[str]: + """ + 发送方账号 + """ + return self.get("FromUserName") + + + @property + def receiver_id(self) -> Optional[str]: + """ + 接收者 ID,例如机器人自身的公众号微信 ID。 + + Returns: + Optional[str]: 接收者 ID。 + """ + return self.get("ToUserName") + + @property + def message_id(self) -> Optional[str]: + """ + 消息 ID,仅在消息类型事件中存在。 + + Returns: + Optional[str]: 消息 ID。 + """ + return self.get("MsgId") + + @property + def message(self) -> Optional[str]: + """ + 消息内容,仅在消息类型事件中存在。 + + Returns: + Optional[str]: 消息内容。 + """ + return self.get("Content") + + @property + def media_id(self) -> Optional[str]: + """ + 媒体文件 ID,仅在图片、语音等消息类型中存在。 + + Returns: + Optional[str]: 媒体文件 ID。 + """ + return self.get("MediaId") + + @property + def timestamp(self) -> Optional[int]: + """ + 事件发生的时间戳。 + + Returns: + Optional[int]: 时间戳。 + """ + return self.get("CreateTime") + + @property + def event_key(self) -> Optional[str]: + """ + 事件的 Key 值,例如点击菜单时的 `EventKey`。 + + Returns: + Optional[str]: 事件 Key。 + """ + return self.get("EventKey") + + def __getattr__(self, key: str) -> Optional[Any]: + """ + 允许通过属性访问数据中的任意字段。 + + Args: + key (str): 字段名。 + + Returns: + Optional[Any]: 字段值。 + """ + return self.get(key) + + def __setattr__(self, key: str, value: Any) -> None: + """ + 允许通过属性设置数据中的任意字段。 + + Args: + key (str): 字段名。 + value (Any): 字段值。 + """ + self[key] = value + + def __repr__(self) -> str: + """ + 生成事件对象的字符串表示。 + + Returns: + str: 字符串表示。 + """ + return f"" diff --git a/libs/wecom_api/api.py b/libs/wecom_api/api.py index e3376dd0..69f92e08 100644 --- a/libs/wecom_api/api.py +++ b/libs/wecom_api/api.py @@ -171,6 +171,7 @@ class WecomClient(): elif request.method == "POST": encrypt_msg = await request.data ret, xml_msg = self.wxcpt.DecryptMsg(encrypt_msg, msg_signature, timestamp, nonce) + # print(xml_msg) if ret != 0: raise Exception(f"消息解密失败,错误码: {ret}") @@ -228,7 +229,7 @@ class WecomClient(): if message_data["MsgType"] == "image": message_data["MediaId"] = root.find("MediaId").text if root.find("MediaId") is not None else None message_data["PicUrl"] = root.find("PicUrl").text if root.find("PicUrl") is not None else None - + return message_data @staticmethod diff --git a/pkg/platform/manager.py b/pkg/platform/manager.py index 85302ca4..c70417d2 100644 --- a/pkg/platform/manager.py +++ b/pkg/platform/manager.py @@ -39,7 +39,7 @@ class PlatformManager: async def initialize(self): - from .sources import nakuru, aiocqhttp, qqbotpy, qqofficial, wecom, lark, discord, gewechat + from .sources import nakuru, aiocqhttp, qqbotpy, qqofficial, wecom, lark, discord, gewechat, officialaccount async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessageSourceAdapter): diff --git a/pkg/platform/sources/officialaccount.py b/pkg/platform/sources/officialaccount.py new file mode 100644 index 00000000..856388f2 --- /dev/null +++ b/pkg/platform/sources/officialaccount.py @@ -0,0 +1,155 @@ +from __future__ import annotations +import typing +import asyncio +import traceback +import time +import datetime +from pkg.core import app +from pkg.platform.adapter import MessageSourceAdapter +from pkg.platform.types import events as platform_events, message as platform_message + +import aiocqhttp +import aiohttp +from libs.official_account_api.oaevent import OAEvent +from pkg.platform.adapter import MessageSourceAdapter +from pkg.platform.types import events as platform_events, message as platform_message +from libs.official_account_api.api import OAClient +from pkg.core import app +from .. import adapter +from ...pipeline.longtext.strategies import forward +from ...core import app +from ..types import message as platform_message +from ..types import events as platform_events +from ..types import entities as platform_entities +from ...command.errors import ParamNotEnoughError + + +# 生成的ai回答 +generated_content = {} + +class OAMessageConverter(adapter.MessageConverter): + @staticmethod + async def yiri2target(message_chain: platform_message.MessageChain): + for msg in message_chain: + if type(msg) is platform_message.Plain: + return msg.text + + + @staticmethod + async def target2yiri(message:str,message_id =-1): + yiri_msg_list = [] + yiri_msg_list.append( + platform_message.Source(id=message_id, time=datetime.datetime.now()) + ) + + yiri_msg_list.append(platform_message.Plain(text=message)) + chain = platform_message.MessageChain(yiri_msg_list) + + return chain + + +class OAEventConverter(adapter.EventConverter): + @staticmethod + async def target2yiri(event:OAEvent): + if event.type == "text": + yiri_chain = await OAMessageConverter.target2yiri( + event.message, event.message_id + ) + + friend = platform_entities.Friend( + id=event.user_id, + nickname=str(event.user_id), + remark="", + ) + + return platform_events.FriendMessage( + sender=friend, message_chain=yiri_chain, time=event.timestamp, source_platform_object=event + ) + else: + return None + +@adapter.adapter_class("officialaccount") +class OfficialAccountAdapter(adapter.MessageSourceAdapter): + + bot : OAClient + ap : app.Application + bot_account_id: str + message_converter: OAMessageConverter = OAMessageConverter() + event_converter: OAEventConverter = OAEventConverter() + config: dict + + + def __init__(self, config: dict, ap: app.Application): + self.config = config + + self.ap = ap + + required_keys = [ + "token", + "EncodingAESKey", + "AppSecret", + "AppID", + ] + missing_keys = [key for key in required_keys if key not in config] + if missing_keys: + raise ParamNotEnoughError("企业微信缺少相关配置项,请查看文档或联系管理员") + + self.bot = OAClient( + token=config['token'], + EncodingAESKey=config['EncodingAESKey'], + Appsecret=config['AppSecret'], + AppID=config['AppID'], + ) + + async def reply_message(self, message_source: platform_events.FriendMessage, message: platform_message.MessageChain, quote_origin: bool = False): + global generated_content + + content = await OAMessageConverter.yiri2target( + message + ) + + generated_content[message_source.message_chain.message_id] = content + + async def send_message( + self, target_type: str, target_id: str, message: platform_message.MessageChain + ): + pass + + + def register_listener(self, event_type: type, callback: typing.Callable[[platform_events.Event, MessageSourceAdapter], None]): + async def on_message(event: OAEvent): + self.bot_account_id = event.receiver_id + try: + return await callback( + await self.event_converter.target2yiri(event), self + ) + except: + traceback.print_exc() + + if event_type == platform_events.FriendMessage: + self.bot.on_message("text")(on_message) + elif event_type == platform_events.GroupMessage: + pass + + async def run_async(self): + async def shutdown_trigger_placeholder(): + while True: + await asyncio.sleep(1) + + await self.bot.run_task( + host=self.config["host"], + port=self.config["port"], + shutdown_trigger=shutdown_trigger_placeholder, + ) + + async def kill(self) -> bool: + return False + + async def unregister_listener( + self, + event_type: type, + callback: typing.Callable[[platform_events.Event, MessageSourceAdapter], None], + ): + return super().unregister_listener(event_type, callback) + + \ No newline at end of file diff --git a/pkg/platform/sources/qqofficial.py b/pkg/platform/sources/qqofficial.py index f41e84db..924e7ba0 100644 --- a/pkg/platform/sources/qqofficial.py +++ b/pkg/platform/sources/qqofficial.py @@ -47,7 +47,6 @@ class QQOfficialMessageConverter(adapter.MessageConverter): yiri_msg_list.append( platform_message.Image(base64=base64_url) ) - message = '' yiri_msg_list.append(platform_message.Plain(text=message)) chain = platform_message.MessageChain(yiri_msg_list) return chain diff --git a/pkg/platform/sources/wecom.py b/pkg/platform/sources/wecom.py index 38de84e9..64b08abe 100644 --- a/pkg/platform/sources/wecom.py +++ b/pkg/platform/sources/wecom.py @@ -154,7 +154,6 @@ class WecomeAdapter(adapter.MessageSourceAdapter): message_converter: WecomMessageConverter = WecomMessageConverter() event_converter: WecomEventConverter = WecomEventConverter() config: dict - ap: app.Application def __init__(self, config: dict, ap: app.Application): self.config = config