diff --git a/libs/wecom_api/api.py b/libs/wecom_api/api.py index 298d32bb..0d02ffcd 100644 --- a/libs/wecom_api/api.py +++ b/libs/wecom_api/api.py @@ -1,11 +1,14 @@ from quart import request from .WXBizMsgCrypt3 import WXBizMsgCrypt - +import base64 +import binascii import httpx from quart import Quart import xml.etree.ElementTree as ET from typing import Callable, Dict, Any from .wecomevent import WecomEvent +from pkg.platform.types import events as platform_events, message as platform_message +import aiofiles class WecomClient(): @@ -42,7 +45,6 @@ class WecomClient(): else: raise Exception(f"未获取access token: {data}") - async def get_users(self): if not self.check_access_token_for_contacts(): self.access_token_for_contacts = await self.get_access_token(self.secret_for_contacts) @@ -88,6 +90,30 @@ class WecomClient(): data = response.json() if data['errcode'] != 0: raise Exception("Failed to send message: "+str(data)) + + async def send_image(self,user_id:str,agent_id:int,media_id:str): + if not await self.check_access_token(): + self.access_token = await self.get_access_token(self.secret) + url = self.base_url+'/media/upload?access_token='+self.access_token + async with httpx.AsyncClient() as client: + params = { + "touser" : user_id, + "toparty" : "", + "totag":"", + "agentid" : agent_id, + "msgtype" : "image", + "image" : { + "media_id" : media_id, + }, + "safe":0, + "enable_id_trans": 0, + "enable_duplicate_check": 0, + "duplicate_check_interval": 1800 + } + response = await client.post(url,json=params) + data = response.json() + if data['errcode'] != 0: + raise Exception("Failed to send image: "+str(data)) async def send_private_msg(self,user_id:str, agent_id:int,content:str): if not await self.check_access_token(): @@ -188,13 +214,92 @@ class WecomClient(): "MsgId": int(root.find("MsgId").text) if root.find("MsgId") is not None else None, "AgentID": int(root.find("AgentID").text) if root.find("AgentID") is not None else None, } - return message_data - - - - - - - + if message_data["MsgType"] == "image": + message_data["MediaId"] = root.find("MediaId").text if root.find("MediaId") is not None else None + message_data["PicUrl"] = root.find("PicUrl").text if root.find("PicUrl") is not None else None - \ No newline at end of file + return message_data + + @staticmethod + async def get_image_type(image_bytes: bytes) -> str: + """ + 通过图片的magic numbers判断图片类型 + """ + magic_numbers = { + b'\xFF\xD8\xFF': 'jpg', + b'\x89\x50\x4E\x47': 'png', + b'\x47\x49\x46': 'gif', + b'\x42\x4D': 'bmp', + b'\x00\x00\x01\x00': 'ico' + } + + for magic, ext in magic_numbers.items(): + if image_bytes.startswith(magic): + return ext + return 'jpg' # 默认返回jpg + + + async def upload_to_work(self, image: platform_message.Image): + """ + 获取 media_id + """ + if not await self.check_access_token(): + self.access_token = await self.get_access_token(self.secret) + + url = self.base_url + '/media/upload?access_token=' + self.access_token + '&type=file' + file_bytes = None + file_name = "uploaded_file.txt" + + # 获取文件的二进制数据 + if image.path: + async with aiofiles.open(image.path, 'rb') as f: + file_bytes = await f.read() + file_name = image.path.split('/')[-1] + elif image.url: + file_bytes = await self.download_image_to_bytes(image.url) + file_name = image.url.split('/')[-1] + elif image.base64: + try: + base64_data = image.base64 + if ',' in base64_data: + base64_data = base64_data.split(',', 1)[1] + padding = 4 - (len(base64_data) % 4) if len(base64_data) % 4 else 0 + padded_base64 = base64_data + '=' * padding + file_bytes = base64.b64decode(padded_base64) + except binascii.Error as e: + raise ValueError(f"Invalid base64 string: {str(e)}") + else: + raise ValueError("image对象出错") + + # 设置 multipart/form-data 格式的文件 + boundary = "-------------------------acebdf13572468" + headers = { + 'Content-Type': f'multipart/form-data; boundary={boundary}' + } + body = ( + f"--{boundary}\r\n" + f"Content-Disposition: form-data; name=\"media\"; filename=\"{file_name}\"; filelength={len(file_bytes)}\r\n" + f"Content-Type: application/octet-stream\r\n\r\n" + ).encode('utf-8') + file_bytes + f"\r\n--{boundary}--\r\n".encode('utf-8') + + # 上传文件 + async with httpx.AsyncClient() as client: + response = await client.post(url, headers=headers, content=body) + data = response.json() + if data.get('errcode', 0) != 0: + raise Exception("failed to upload file") + + return data.get('media_id') + + + async def download_image_to_bytes(self,url:str) -> bytes: + async with httpx.AsyncClient() as client: + response = await client.get(url) + response.raise_for_status() + return response.content + + #进行media_id的获取 + async def get_media_id(self, image: platform_message.Image): + + media_id = await self.upload_to_work(image=image) + return media_id diff --git a/libs/wecom_api/wecomevent.py b/libs/wecom_api/wecomevent.py index d5e02808..3606cdf5 100644 --- a/libs/wecom_api/wecomevent.py +++ b/libs/wecom_api/wecomevent.py @@ -35,6 +35,13 @@ class WecomEvent(dict): str: 事件类型。 """ return self.get("MsgType", "") + + @property + def picurl(self) -> str: + """ + 图片链接 + """ + return self.get("PicUrl") @property def detail_type(self) -> str: diff --git a/pkg/platform/sources/wecom.py b/pkg/platform/sources/wecom.py index e47cfd6f..dda22816 100644 --- a/pkg/platform/sources/wecom.py +++ b/pkg/platform/sources/wecom.py @@ -12,7 +12,6 @@ from pkg.platform.adapter import MessageSourceAdapter from pkg.platform.types import events as platform_events, message as platform_message from libs.wecom_api.wecomevent import WecomEvent from pkg.core import app - from .. import adapter from ...pipeline.longtext.strategies import forward from ...core import app @@ -20,24 +19,54 @@ from ..types import message as platform_message from ..types import events as platform_events from ..types import entities as platform_entities from ...command.errors import ParamNotEnoughError - +from ...utils import image class WecomMessageConverter(adapter.MessageConverter): - @staticmethod - async def yiri2target(message_chain:platform_message.MessageChain): - content='' - for msg in message_chain: - if type(msg) is platform_message.Plain: - content+=msg.text - - return content - @staticmethod - async def target2yiri(message:str,message_id:int = -1): + async def yiri2target( + message_chain: platform_message.MessageChain, bot: WecomClient + ): + content_list = [] + + [ + { + "type": "text", + "content": "text", + }, + { + "type": "image", + "media_id": "media_id", + } + ] + + for msg in message_chain: + if type(msg) is platform_message.Plain: + content_list.append({ + "type": "text", + "content": msg.text, + }) + elif type(msg) is platform_message.Image: + content_list.append({ + "type": "image", + "media_id": await bot.get_media_id(msg), + }) + elif type(msg) is platform_message.Forward: + for node in msg.node_list: + content_list.extend((await WecomMessageConverter.yiri2target(node.message_chain, bot))) + else: + content_list.append({ + "type": "text", + "content": str(msg), + }) + + return content_list + + @staticmethod + async def target2yiri(message: str, message_id: int = -1): yiri_msg_list = [] yiri_msg_list.append( - platform_message.Source(id = message_id,time = datetime.datetime.now()) + platform_message.Source(id=message_id, time=datetime.datetime.now()) ) yiri_msg_list.append(platform_message.Plain(text=message)) @@ -45,31 +74,46 @@ class WecomMessageConverter(adapter.MessageConverter): return chain + @staticmethod + async def target2yiri_image(picurl: str, message_id: int = -1): + yiri_msg_list = [] + yiri_msg_list.append( + platform_message.Source(id=message_id, time=datetime.datetime.now()) + ) + image_base64, image_format = await image.get_wecom_image_base64(pic_url=picurl) + yiri_msg_list.append(platform_message.Image(base64=f"data:image/{image_format};base64,{image_base64}")) + chain = platform_message.MessageChain(yiri_msg_list) + return chain class WecomEventConverter: - @staticmethod - async def yiri2target(event:platform_events.Event,bot_account_id:int) -> WecomEvent: - content = await WecomMessageConverter.yiri2target(event.message_chain) - if type(event) is platform_events.GroupMessage: - pass - - if type(event) is platform_events.FriendMessage: - payload = { - "MsgType": "text", - "Content": content, - "FromUserName": event.sender.id, - "ToUserName": bot_account_id, - "CreateTime": int(datetime.datetime.now().timestamp()), - "AgentID": event.sender.nickname - } + @staticmethod + async def yiri2target( + event: platform_events.Event, bot_account_id: int, bot: WecomClient + ) -> WecomEvent: + # only for extracting user information + + if type(event) is platform_events.GroupMessage: + pass + + if type(event) is platform_events.FriendMessage: + + payload = { + "MsgType": "text", + "Content": '', + "FromUserName": event.sender.id, + "ToUserName": bot_account_id, + "CreateTime": int(datetime.datetime.now().timestamp()), + "AgentID": event.sender.nickname, + } wecom_event = WecomEvent.from_payload(payload=payload) if not wecom_event: raise ValueError("无法从 message_data 构造 WecomEvent 对象") + return wecom_event - + @staticmethod async def target2yiri(event: WecomEvent): """ @@ -82,112 +126,133 @@ class WecomEventConverter: platform_events.FriendMessage: 转换后的 FriendMessage 对象。 """ # 转换消息链 - yiri_chain = await WecomMessageConverter.target2yiri( - event.message, event.message_id - ) - - # 判断消息类型并进行转换 - # if event.message_type == "private": 默认消息都是从好友发出 - - friend = platform_entities.Friend( - id=event.user_id, - nickname=str(event.agent_id), - remark="", - ) - - return platform_events.FriendMessage( - sender=friend, - message_chain=yiri_chain, - time=event.timestamp + if event.type == "text": + yiri_chain = await WecomMessageConverter.target2yiri( + event.message, event.message_id + ) + + friend = platform_entities.Friend( + id=event.user_id, + nickname=str(event.agent_id), + remark="", + ) + + return platform_events.FriendMessage( + sender=friend, message_chain=yiri_chain, time=event.timestamp + ) + elif event.type == "image": + friend = platform_entities.Friend( + id=event.user_id, + nickname=str(event.agent_id), + remark="", + ) + + yiri_chain = await WecomMessageConverter.target2yiri_image( + picurl=event.picurl, message_id=event.message_id + ) + + return platform_events.FriendMessage( + sender=friend, message_chain=yiri_chain, time=event.timestamp ) - @adapter.adapter_class("wecom") class WecomeAdapter(adapter.MessageSourceAdapter): - bot:WecomClient - ap:app.Application - bot_account_id:str - message_converter:WecomMessageConverter = WecomMessageConverter() - event_converter:WecomEventConverter = WecomEventConverter() - config:dict - ap:app.Application + bot: WecomClient + ap: app.Application + bot_account_id: str + message_converter: WecomMessageConverter = WecomMessageConverter() + event_converter: WecomEventConverter = WecomEventConverter() + config: dict + ap: app.Application - def __init__(self, config: dict, ap:app.Application): + def __init__(self, config: dict, ap: app.Application): self.config = config - #这里需要对config里的内容换成企业微信的config。是config:corpid,token...... + self.ap = ap - - required_keys = ["corpid","secret","token","EncodingAESKey","contacts_secret"] + + required_keys = [ + "corpid", + "secret", + "token", + "EncodingAESKey", + "contacts_secret", + ] missing_keys = [key for key in required_keys if key not in config] if missing_keys: raise ParamNotEnoughError("企业微信缺少相关配置项,请查看文档或联系管理员") self.bot = WecomClient( - corpid=config['corpid'], - secret=config['secret'], - token=config['token'], - EncodingAESKey=config['EncodingAESKey'], - contacts_secret=config['contacts_secret'] + corpid=config["corpid"], + secret=config["secret"], + token=config["token"], + EncodingAESKey=config["EncodingAESKey"], + contacts_secret=config["contacts_secret"], ) - - async def reply_message(self,message_source:platform_events.MessageEvent,message:platform_message.MessageChain, - quote_origin:bool=False, + + async def reply_message( + self, + message_source: platform_events.MessageEvent, + message: platform_message.MessageChain, + quote_origin: bool = False, ): - Wecom_event = await WecomEventConverter.yiri2target(message_source,self.bot_account_id) - Wecom_msg = await WecomMessageConverter.yiri2target(message) - # message_converter传回一个消息str - user_id = Wecom_event.user_id - agent_id = Wecom_event.agent_id - return await self.bot.send_private_msg(user_id=user_id,agent_id=agent_id,content=Wecom_msg) + Wecom_event = await WecomEventConverter.yiri2target( + message_source, self.bot_account_id, self.bot + ) + content_list = await WecomMessageConverter.yiri2target(message, self.bot) - async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): + for content in content_list: + if content["type"] == "text": + await self.bot.send_private_msg(Wecom_event.user_id, Wecom_event.agent_id, content["content"]) + elif content["type"] == "image": + await self.bot.send_image(Wecom_event.user_id, Wecom_event.agent_id, content["media_id"]) + + async def send_message( + self, target_type: str, target_id: str, message: platform_message.MessageChain + ): pass - - def register_listener( - self, - event_type:typing.Type[platform_events.Event], - callback:typing.Callable[[platform_events.Event,adapter.MessageSourceAdapter],None], - + self, + event_type: typing.Type[platform_events.Event], + callback: typing.Callable[ + [platform_events.Event, adapter.MessageSourceAdapter], None + ], ): - async def on_message(event:WecomEvent): + async def on_message(event: WecomEvent): self.bot_account_id = event.receiver_id try: - return await callback(await self.event_converter.target2yiri(event),self) + return await callback( + await self.event_converter.target2yiri(event), self + ) except: traceback.print_exc() + if event_type == platform_events.FriendMessage: self.bot.on_message("text")(on_message) + self.bot.on_message("image")(on_message) elif event_type == platform_events.GroupMessage: pass - + async def run_async(self): async def shutdown_trigger_placeholder(): while True: await asyncio.sleep(1) - await self.bot.run_task(host=self.config['host'],port=self.config['port'],shutdown_trigger=shutdown_trigger_placeholder) + await self.bot.run_task( + host=self.config["host"], + port=self.config["port"], + shutdown_trigger=shutdown_trigger_placeholder, + ) async def kill(self) -> bool: return False - - async def unregister_listener(self, event_type: type, callback: typing.Callable[[platform_events.Event, MessageSourceAdapter], None]): + + async def unregister_listener( + self, + event_type: type, + callback: typing.Callable[[platform_events.Event, MessageSourceAdapter], None], + ): return super().unregister_listener(event_type, callback) - - - - - - - - - - - - - - diff --git a/pkg/utils/image.py b/pkg/utils/image.py index 06885175..6f769b26 100644 --- a/pkg/utils/image.py +++ b/pkg/utils/image.py @@ -7,6 +7,31 @@ import ssl import aiohttp import PIL.Image +async def get_wecom_image_base64(pic_url: str) -> tuple[str, str]: + """ + 下载企业微信图片并转换为 base64 + :param pic_url: 企业微信图片URL + :return: (base64_str, image_format) + """ + async with aiohttp.ClientSession() as session: + async with session.get(pic_url) as response: + if response.status != 200: + raise Exception(f"Failed to download image: {response.status}") + + # 读取图片数据 + image_data = await response.read() + + # 获取图片格式 + content_type = response.headers.get('Content-Type', '') + image_format = content_type.split('/')[-1] # 例如 'image/jpeg' -> 'jpeg' + + # 转换为 base64 + import base64 + image_base64 = base64.b64encode(image_data).decode('utf-8') + + return image_base64, image_format + + def get_qq_image_downloadable_url(image_url: str) -> tuple[str, dict]: """获取QQ图片的下载链接"""