mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-05 05:16:03 +00:00
style: introduce ruff as linter and formatter (#1356)
* style: remove necessary imports * style: fix F841 * style: fix F401 * style: fix F811 * style: fix E402 * style: fix E721 * style: fix E722 * style: fix E722 * style: fix F541 * style: ruff format * style: all passed * style: add ruff in deps * style: more ignores in ruff.toml * style: add pre-commit
This commit is contained in:
committed by
GitHub
parent
09e70d70e9
commit
209f16af76
@@ -17,7 +17,7 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
|
||||
|
||||
bot_account_id: int
|
||||
"""机器人账号ID,需要在初始化时设置"""
|
||||
|
||||
|
||||
config: dict
|
||||
|
||||
ap: app.Application
|
||||
@@ -32,14 +32,11 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
|
||||
self.config = config
|
||||
self.ap = ap
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
message: platform_message.MessageChain
|
||||
async def send_message(
|
||||
self, target_type: str, target_id: str, message: platform_message.MessageChain
|
||||
):
|
||||
"""主动发送消息
|
||||
|
||||
|
||||
Args:
|
||||
target_type (str): 目标类型,`person`或`group`
|
||||
target_id (str): 目标ID
|
||||
@@ -51,7 +48,7 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False
|
||||
quote_origin: bool = False,
|
||||
):
|
||||
"""回复消息
|
||||
|
||||
@@ -69,23 +66,27 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_message.Event],
|
||||
callback: typing.Callable[[platform_message.Event, MessagePlatformAdapter], None]
|
||||
callback: typing.Callable[
|
||||
[platform_message.Event, MessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
"""注册事件监听器
|
||||
|
||||
|
||||
Args:
|
||||
event_type (typing.Type[platform.types.Event]): 事件类型
|
||||
callback (typing.Callable[[platform.types.Event], None]): 回调函数,接收一个参数,为事件
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_message.Event],
|
||||
callback: typing.Callable[[platform_message.Event, MessagePlatformAdapter], None]
|
||||
callback: typing.Callable[
|
||||
[platform_message.Event, MessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
"""注销事件监听器
|
||||
|
||||
|
||||
Args:
|
||||
event_type (typing.Type[platform.types.Event]): 事件类型
|
||||
callback (typing.Callable[[platform.types.Event], None]): 回调函数,接收一个参数,为事件
|
||||
@@ -98,7 +99,7 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
|
||||
|
||||
async def kill(self) -> bool:
|
||||
"""关闭适配器
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否成功关闭,热重载时若此函数返回False则不会重载MessageSource底层
|
||||
"""
|
||||
@@ -107,6 +108,7 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
|
||||
|
||||
class MessageConverter:
|
||||
"""消息链转换器基类"""
|
||||
|
||||
@staticmethod
|
||||
def yiri2target(message_chain: platform_message.MessageChain):
|
||||
"""将源平台消息链转换为目标平台消息链
|
||||
|
||||
@@ -1,23 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import asyncio
|
||||
import traceback
|
||||
import sqlalchemy
|
||||
|
||||
from .sources import qqofficial
|
||||
|
||||
# FriendMessage, Image, MessageChain, Plain
|
||||
from . import adapter as msadapter
|
||||
|
||||
from ..core import app, entities as core_entities, taskmgr
|
||||
from ..plugin import events
|
||||
from .types import message as platform_message
|
||||
from .types import events as platform_events
|
||||
from .types import entities as platform_entities
|
||||
|
||||
from ..discover import engine
|
||||
|
||||
@@ -25,6 +18,7 @@ from ..entity.persistence import bot as persistence_bot
|
||||
|
||||
# 处理 3.4 移除了 YiriMirai 之后,插件的兼容性问题
|
||||
from . import types as mirai
|
||||
|
||||
sys.modules['mirai'] = mirai
|
||||
|
||||
|
||||
@@ -43,7 +37,12 @@ class RuntimeBot:
|
||||
|
||||
task_context: taskmgr.TaskContext
|
||||
|
||||
def __init__(self, ap: app.Application, bot_entity: persistence_bot.Bot, adapter: msadapter.MessagePlatformAdapter):
|
||||
def __init__(
|
||||
self,
|
||||
ap: app.Application,
|
||||
bot_entity: persistence_bot.Bot,
|
||||
adapter: msadapter.MessagePlatformAdapter,
|
||||
):
|
||||
self.ap = ap
|
||||
self.bot_entity = bot_entity
|
||||
self.enable = bot_entity.enable
|
||||
@@ -51,9 +50,10 @@ class RuntimeBot:
|
||||
self.task_context = taskmgr.TaskContext()
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessagePlatformAdapter):
|
||||
|
||||
async def on_friend_message(
|
||||
event: platform_events.FriendMessage,
|
||||
adapter: msadapter.MessagePlatformAdapter,
|
||||
):
|
||||
await self.ap.query_pool.add_query(
|
||||
bot_uuid=self.bot_entity.uuid,
|
||||
launcher_type=core_entities.LauncherTypes.PERSON,
|
||||
@@ -64,8 +64,10 @@ class RuntimeBot:
|
||||
adapter=adapter,
|
||||
)
|
||||
|
||||
async def on_group_message(event: platform_events.GroupMessage, adapter: msadapter.MessagePlatformAdapter):
|
||||
|
||||
async def on_group_message(
|
||||
event: platform_events.GroupMessage,
|
||||
adapter: msadapter.MessagePlatformAdapter,
|
||||
):
|
||||
await self.ap.query_pool.add_query(
|
||||
bot_uuid=self.bot_entity.uuid,
|
||||
launcher_type=core_entities.LauncherTypes.GROUP,
|
||||
@@ -76,17 +78,10 @@ class RuntimeBot:
|
||||
adapter=adapter,
|
||||
)
|
||||
|
||||
self.adapter.register_listener(
|
||||
platform_events.FriendMessage,
|
||||
on_friend_message
|
||||
)
|
||||
self.adapter.register_listener(
|
||||
platform_events.GroupMessage,
|
||||
on_group_message
|
||||
)
|
||||
self.adapter.register_listener(platform_events.FriendMessage, on_friend_message)
|
||||
self.adapter.register_listener(platform_events.GroupMessage, on_group_message)
|
||||
|
||||
async def run(self):
|
||||
|
||||
async def exception_wrapper():
|
||||
try:
|
||||
self.task_context.set_current_action('Running...')
|
||||
@@ -98,16 +93,19 @@ class RuntimeBot:
|
||||
return
|
||||
self.task_context.set_current_action('Exited with error.')
|
||||
self.task_context.log(f'平台适配器运行出错: {e}')
|
||||
self.task_context.log(f"Traceback: {traceback.format_exc()}")
|
||||
self.task_context.log(f'Traceback: {traceback.format_exc()}')
|
||||
self.ap.logger.error(f'平台适配器运行出错: {e}')
|
||||
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
self.ap.logger.debug(f'Traceback: {traceback.format_exc()}')
|
||||
|
||||
self.task_wrapper = self.ap.task_mgr.create_task(
|
||||
exception_wrapper(),
|
||||
kind="platform-adapter",
|
||||
name=f"platform-adapter-{self.adapter.__class__.__name__}",
|
||||
kind='platform-adapter',
|
||||
name=f'platform-adapter-{self.adapter.__class__.__name__}',
|
||||
context=self.task_context,
|
||||
scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM]
|
||||
scopes=[
|
||||
core_entities.LifecycleControlScope.APPLICATION,
|
||||
core_entities.LifecycleControlScope.PLATFORM,
|
||||
],
|
||||
)
|
||||
|
||||
async def shutdown(self):
|
||||
@@ -118,7 +116,6 @@ class RuntimeBot:
|
||||
|
||||
# 控制QQ消息输入输出的类
|
||||
class PlatformManager:
|
||||
|
||||
# ====== 4.0 ======
|
||||
ap: app.Application = None
|
||||
|
||||
@@ -129,18 +126,20 @@ class PlatformManager:
|
||||
adapter_dict: dict[str, type[msadapter.MessagePlatformAdapter]]
|
||||
|
||||
def __init__(self, ap: app.Application = None):
|
||||
|
||||
self.ap = ap
|
||||
self.bots = []
|
||||
self.adapter_components = []
|
||||
self.adapter_dict = {}
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
self.adapter_components = self.ap.discover.get_components_by_kind('MessagePlatformAdapter')
|
||||
async def initialize(self):
|
||||
self.adapter_components = self.ap.discover.get_components_by_kind(
|
||||
'MessagePlatformAdapter'
|
||||
)
|
||||
adapter_dict: dict[str, type[msadapter.MessagePlatformAdapter]] = {}
|
||||
for component in self.adapter_components:
|
||||
adapter_dict[component.metadata.name] = component.get_python_component_class()
|
||||
adapter_dict[component.metadata.name] = (
|
||||
component.get_python_component_class()
|
||||
)
|
||||
self.adapter_dict = adapter_dict
|
||||
|
||||
await self.load_bots_from_db()
|
||||
@@ -158,12 +157,15 @@ class PlatformManager:
|
||||
)
|
||||
|
||||
bots = result.all()
|
||||
|
||||
|
||||
for bot in bots:
|
||||
# load all bots here, enable or disable will be handled in runtime
|
||||
await self.load_bot(bot)
|
||||
|
||||
async def load_bot(self, bot_entity: persistence_bot.Bot | sqlalchemy.Row[persistence_bot.Bot] | dict) -> RuntimeBot:
|
||||
async def load_bot(
|
||||
self,
|
||||
bot_entity: persistence_bot.Bot | sqlalchemy.Row[persistence_bot.Bot] | dict,
|
||||
) -> RuntimeBot:
|
||||
"""加载机器人"""
|
||||
if isinstance(bot_entity, sqlalchemy.Row):
|
||||
bot_entity = persistence_bot.Bot(**bot_entity._mapping)
|
||||
@@ -171,14 +173,11 @@ class PlatformManager:
|
||||
bot_entity = persistence_bot.Bot(**bot_entity)
|
||||
|
||||
adapter_inst = self.adapter_dict[bot_entity.adapter](
|
||||
bot_entity.adapter_config,
|
||||
self.ap
|
||||
bot_entity.adapter_config, self.ap
|
||||
)
|
||||
|
||||
runtime_bot = RuntimeBot(
|
||||
ap=self.ap,
|
||||
bot_entity=bot_entity,
|
||||
adapter=adapter_inst
|
||||
ap=self.ap, bot_entity=bot_entity, adapter=adapter_inst
|
||||
)
|
||||
|
||||
await runtime_bot.initialize()
|
||||
@@ -186,7 +185,7 @@ class PlatformManager:
|
||||
self.bots.append(runtime_bot)
|
||||
|
||||
return runtime_bot
|
||||
|
||||
|
||||
async def get_bot_by_uuid(self, bot_uuid: str) -> RuntimeBot | None:
|
||||
for bot in self.bots:
|
||||
if bot.bot_entity.uuid == bot_uuid:
|
||||
@@ -202,24 +201,28 @@ class PlatformManager:
|
||||
return
|
||||
|
||||
def get_available_adapters_info(self) -> list[dict]:
|
||||
return [
|
||||
component.to_plain_dict()
|
||||
for component in self.adapter_components
|
||||
]
|
||||
return [component.to_plain_dict() for component in self.adapter_components]
|
||||
|
||||
def get_available_adapter_info_by_name(self, name: str) -> dict | None:
|
||||
for component in self.adapter_components:
|
||||
if component.metadata.name == name:
|
||||
return component.to_plain_dict()
|
||||
return None
|
||||
|
||||
def get_available_adapter_manifest_by_name(self, name: str) -> engine.Component | None:
|
||||
|
||||
def get_available_adapter_manifest_by_name(
|
||||
self, name: str
|
||||
) -> engine.Component | None:
|
||||
for component in self.adapter_components:
|
||||
if component.metadata.name == name:
|
||||
return component
|
||||
return None
|
||||
|
||||
async def write_back_config(self, adapter_name: str, adapter_inst: msadapter.MessagePlatformAdapter, config: dict):
|
||||
async def write_back_config(
|
||||
self,
|
||||
adapter_name: str,
|
||||
adapter_inst: msadapter.MessagePlatformAdapter,
|
||||
config: dict,
|
||||
):
|
||||
# index = -2
|
||||
|
||||
# for i, adapter in enumerate(self.adapters):
|
||||
@@ -251,7 +254,7 @@ class PlatformManager:
|
||||
# TODO implement this
|
||||
pass
|
||||
|
||||
async def run(self):
|
||||
async def run(self):
|
||||
# This method will only be called when the application launching
|
||||
for bot in self.bots:
|
||||
if bot.enable:
|
||||
|
||||
@@ -2,24 +2,23 @@ from __future__ import annotations
|
||||
import typing
|
||||
import asyncio
|
||||
import traceback
|
||||
import time
|
||||
import datetime
|
||||
|
||||
import aiocqhttp
|
||||
import aiohttp
|
||||
|
||||
from .. import adapter
|
||||
from ...pipeline.longtext.strategies import forward
|
||||
from ...core import app
|
||||
from ..types import message as platform_message
|
||||
from ..types import events as platform_events
|
||||
from ..types import entities as platform_entities
|
||||
from ...utils import image
|
||||
|
||||
class AiocqhttpMessageConverter(adapter.MessageConverter):
|
||||
|
||||
class AiocqhttpMessageConverter(adapter.MessageConverter):
|
||||
@staticmethod
|
||||
async def yiri2target(message_chain: platform_message.MessageChain) -> typing.Tuple[list, int, datetime.datetime]:
|
||||
async def yiri2target(
|
||||
message_chain: platform_message.MessageChain,
|
||||
) -> typing.Tuple[list, int, datetime.datetime]:
|
||||
msg_list = aiocqhttp.Message()
|
||||
|
||||
msg_id = 0
|
||||
@@ -35,7 +34,7 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
|
||||
arg = ''
|
||||
if msg.base64:
|
||||
arg = msg.base64
|
||||
msg_list.append(aiocqhttp.MessageSegment.image(f"base64://{arg}"))
|
||||
msg_list.append(aiocqhttp.MessageSegment.image(f'base64://{arg}'))
|
||||
elif msg.url:
|
||||
arg = msg.url
|
||||
msg_list.append(aiocqhttp.MessageSegment.image(arg))
|
||||
@@ -45,12 +44,12 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
|
||||
elif type(msg) is platform_message.At:
|
||||
msg_list.append(aiocqhttp.MessageSegment.at(msg.target))
|
||||
elif type(msg) is platform_message.AtAll:
|
||||
msg_list.append(aiocqhttp.MessageSegment.at("all"))
|
||||
msg_list.append(aiocqhttp.MessageSegment.at('all'))
|
||||
elif type(msg) is platform_message.Voice:
|
||||
arg = ''
|
||||
if msg.base64:
|
||||
arg = msg.base64
|
||||
msg_list.append(aiocqhttp.MessageSegment.record(f"base64://{arg}"))
|
||||
msg_list.append(aiocqhttp.MessageSegment.record(f'base64://{arg}'))
|
||||
elif msg.url:
|
||||
arg = msg.url
|
||||
msg_list.append(aiocqhttp.MessageSegment.record(arg))
|
||||
@@ -58,10 +57,15 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
|
||||
arg = msg.path
|
||||
msg_list.append(aiocqhttp.MessageSegment.record(msg.path))
|
||||
elif type(msg) is platform_message.Forward:
|
||||
|
||||
for node in msg.node_list:
|
||||
msg_list.extend((await AiocqhttpMessageConverter.yiri2target(node.message_chain))[0])
|
||||
|
||||
msg_list.extend(
|
||||
(
|
||||
await AiocqhttpMessageConverter.yiri2target(
|
||||
node.message_chain
|
||||
)
|
||||
)[0]
|
||||
)
|
||||
|
||||
else:
|
||||
msg_list.append(aiocqhttp.MessageSegment.text(str(msg)))
|
||||
|
||||
@@ -78,20 +82,26 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
|
||||
)
|
||||
|
||||
for msg in message:
|
||||
if msg.type == "at":
|
||||
if msg.data["qq"] == "all":
|
||||
if msg.type == 'at':
|
||||
if msg.data['qq'] == 'all':
|
||||
yiri_msg_list.append(platform_message.AtAll())
|
||||
else:
|
||||
yiri_msg_list.append(
|
||||
platform_message.At(
|
||||
target=msg.data["qq"],
|
||||
target=msg.data['qq'],
|
||||
)
|
||||
)
|
||||
elif msg.type == "text":
|
||||
yiri_msg_list.append(platform_message.Plain(text=msg.data["text"]))
|
||||
elif msg.type == "image":
|
||||
image_base64, image_format = await image.qq_image_url_to_base64(msg.data['url'])
|
||||
yiri_msg_list.append(platform_message.Image(base64=f"data:image/{image_format};base64,{image_base64}"))
|
||||
elif msg.type == 'text':
|
||||
yiri_msg_list.append(platform_message.Plain(text=msg.data['text']))
|
||||
elif msg.type == 'image':
|
||||
image_base64, image_format = await image.qq_image_url_to_base64(
|
||||
msg.data['url']
|
||||
)
|
||||
yiri_msg_list.append(
|
||||
platform_message.Image(
|
||||
base64=f'data:image/{image_format};base64,{image_base64}'
|
||||
)
|
||||
)
|
||||
|
||||
chain = platform_message.MessageChain(yiri_msg_list)
|
||||
|
||||
@@ -99,7 +109,6 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
|
||||
|
||||
|
||||
class AiocqhttpEventConverter(adapter.EventConverter):
|
||||
|
||||
@staticmethod
|
||||
async def yiri2target(event: platform_events.MessageEvent, bot_account_id: int):
|
||||
return event.source_platform_object
|
||||
@@ -110,49 +119,50 @@ class AiocqhttpEventConverter(adapter.EventConverter):
|
||||
event.message, event.message_id
|
||||
)
|
||||
|
||||
if event.message_type == "group":
|
||||
permission = "MEMBER"
|
||||
if event.message_type == 'group':
|
||||
permission = 'MEMBER'
|
||||
|
||||
if "role" in event.sender:
|
||||
if event.sender["role"] == "admin":
|
||||
permission = "ADMINISTRATOR"
|
||||
elif event.sender["role"] == "owner":
|
||||
permission = "OWNER"
|
||||
if 'role' in event.sender:
|
||||
if event.sender['role'] == 'admin':
|
||||
permission = 'ADMINISTRATOR'
|
||||
elif event.sender['role'] == 'owner':
|
||||
permission = 'OWNER'
|
||||
converted_event = platform_events.GroupMessage(
|
||||
sender=platform_entities.GroupMember(
|
||||
id=event.sender["user_id"], # message_seq 放哪?
|
||||
member_name=event.sender["nickname"],
|
||||
id=event.sender['user_id'], # message_seq 放哪?
|
||||
member_name=event.sender['nickname'],
|
||||
permission=permission,
|
||||
group=platform_entities.Group(
|
||||
id=event.group_id,
|
||||
name=event.sender["nickname"],
|
||||
name=event.sender['nickname'],
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title=event.sender["title"] if "title" in event.sender else "",
|
||||
special_title=event.sender['title']
|
||||
if 'title' in event.sender
|
||||
else '',
|
||||
join_timestamp=0,
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0,
|
||||
),
|
||||
message_chain=yiri_chain,
|
||||
time=event.time,
|
||||
source_platform_object=event
|
||||
source_platform_object=event,
|
||||
)
|
||||
return converted_event
|
||||
elif event.message_type == "private":
|
||||
elif event.message_type == 'private':
|
||||
return platform_events.FriendMessage(
|
||||
sender=platform_entities.Friend(
|
||||
id=event.sender["user_id"],
|
||||
nickname=event.sender["nickname"],
|
||||
remark="",
|
||||
id=event.sender['user_id'],
|
||||
nickname=event.sender['nickname'],
|
||||
remark='',
|
||||
),
|
||||
message_chain=yiri_chain,
|
||||
time=event.time,
|
||||
source_platform_object=event
|
||||
source_platform_object=event,
|
||||
)
|
||||
|
||||
|
||||
class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
|
||||
|
||||
bot: aiocqhttp.CQHttp
|
||||
|
||||
bot_account_id: int
|
||||
@@ -170,14 +180,14 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
|
||||
async def shutdown_trigger_placeholder():
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
self.config['shutdown_trigger'] = shutdown_trigger_placeholder
|
||||
|
||||
self.ap = ap
|
||||
|
||||
if "access-token" in config:
|
||||
self.bot = aiocqhttp.CQHttp(access_token=config["access-token"])
|
||||
del self.config["access-token"]
|
||||
if 'access-token' in config:
|
||||
self.bot = aiocqhttp.CQHttp(access_token=config['access-token'])
|
||||
del self.config['access-token']
|
||||
else:
|
||||
self.bot = aiocqhttp.CQHttp()
|
||||
|
||||
@@ -186,9 +196,9 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
|
||||
):
|
||||
aiocq_msg = (await AiocqhttpMessageConverter.yiri2target(message))[0]
|
||||
|
||||
if target_type == "group":
|
||||
if target_type == 'group':
|
||||
await self.bot.send_group_msg(group_id=int(target_id), message=aiocq_msg)
|
||||
elif target_type == "person":
|
||||
elif target_type == 'person':
|
||||
await self.bot.send_private_msg(user_id=int(target_id), message=aiocq_msg)
|
||||
|
||||
async def reply_message(
|
||||
@@ -196,16 +206,17 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
|
||||
message_source: platform_events.MessageEvent,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
):
|
||||
aiocq_event = await AiocqhttpEventConverter.yiri2target(message_source, self.bot_account_id)
|
||||
):
|
||||
aiocq_event = await AiocqhttpEventConverter.yiri2target(
|
||||
message_source, self.bot_account_id
|
||||
)
|
||||
aiocq_msg = (await AiocqhttpMessageConverter.yiri2target(message))[0]
|
||||
if quote_origin:
|
||||
aiocq_msg = aiocqhttp.MessageSegment.reply(aiocq_event.message_id) + aiocq_msg
|
||||
aiocq_msg = (
|
||||
aiocqhttp.MessageSegment.reply(aiocq_event.message_id) + aiocq_msg
|
||||
)
|
||||
|
||||
return await self.bot.send(
|
||||
aiocq_event,
|
||||
aiocq_msg
|
||||
)
|
||||
return await self.bot.send(aiocq_event, aiocq_msg)
|
||||
|
||||
async def is_muted(self, group_id: int) -> bool:
|
||||
return False
|
||||
@@ -213,24 +224,30 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, adapter.MessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
async def on_message(event: aiocqhttp.Event):
|
||||
self.bot_account_id = event.self_id
|
||||
try:
|
||||
return await callback(await self.event_converter.target2yiri(event), self)
|
||||
except:
|
||||
return await callback(
|
||||
await self.event_converter.target2yiri(event), self
|
||||
)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
if event_type == platform_events.GroupMessage:
|
||||
self.bot.on_message("group")(on_message)
|
||||
self.bot.on_message('group')(on_message)
|
||||
elif event_type == platform_events.FriendMessage:
|
||||
self.bot.on_message("private")(on_message)
|
||||
self.bot.on_message('private')(on_message)
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, adapter.MessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
return super().unregister_listener(event_type, callback)
|
||||
|
||||
|
||||
@@ -1,37 +1,30 @@
|
||||
|
||||
import traceback
|
||||
import typing
|
||||
from libs.dingtalk_api.dingtalkevent import DingTalkEvent
|
||||
from pkg.platform.types import message as platform_message
|
||||
from pkg.platform.adapter import MessagePlatformAdapter
|
||||
from pkg.platform.types import events as platform_events, message as platform_message
|
||||
from pkg.core import app
|
||||
from .. import adapter
|
||||
from ...pipeline.longtext.strategies import forward
|
||||
from ...core import app
|
||||
from ..types import message as platform_message
|
||||
from ..types import events as platform_events
|
||||
from ..types import entities as platform_entities
|
||||
from ...command.errors import ParamNotEnoughError
|
||||
from libs.dingtalk_api.api import DingTalkClient
|
||||
import datetime
|
||||
|
||||
|
||||
class DingTalkMessageConverter(adapter.MessageConverter):
|
||||
|
||||
@staticmethod
|
||||
async def yiri2target(
|
||||
message_chain:platform_message.MessageChain
|
||||
):
|
||||
async def yiri2target(message_chain: platform_message.MessageChain):
|
||||
for msg in message_chain:
|
||||
if type(msg) is platform_message.Plain:
|
||||
return msg.text
|
||||
|
||||
@staticmethod
|
||||
async def target2yiri(event:DingTalkEvent, bot_name:str):
|
||||
async def target2yiri(event: DingTalkEvent, bot_name: str):
|
||||
yiri_msg_list = []
|
||||
yiri_msg_list.append(
|
||||
platform_message.Source(id = event.incoming_message.message_id,time=datetime.datetime.now())
|
||||
platform_message.Source(
|
||||
id=event.incoming_message.message_id, time=datetime.datetime.now()
|
||||
)
|
||||
)
|
||||
|
||||
for atUser in event.incoming_message.at_users:
|
||||
@@ -39,7 +32,7 @@ class DingTalkMessageConverter(adapter.MessageConverter):
|
||||
yiri_msg_list.append(platform_message.At(target=bot_name))
|
||||
|
||||
if event.content:
|
||||
text_content = event.content.replace("@"+bot_name, '')
|
||||
text_content = event.content.replace('@' + bot_name, '')
|
||||
yiri_msg_list.append(platform_message.Plain(text=text_content))
|
||||
if event.picture:
|
||||
yiri_msg_list.append(platform_message.Image(base64=event.picture))
|
||||
@@ -47,60 +40,51 @@ class DingTalkMessageConverter(adapter.MessageConverter):
|
||||
yiri_msg_list.append(platform_message.Voice(base64=event.audio))
|
||||
|
||||
chain = platform_message.MessageChain(yiri_msg_list)
|
||||
|
||||
|
||||
return chain
|
||||
|
||||
|
||||
class DingTalkEventConverter(adapter.EventConverter):
|
||||
|
||||
@staticmethod
|
||||
async def yiri2target(
|
||||
event:platform_events.MessageEvent
|
||||
):
|
||||
async def yiri2target(event: platform_events.MessageEvent):
|
||||
return event.source_platform_object
|
||||
|
||||
@staticmethod
|
||||
async def target2yiri(
|
||||
event:DingTalkEvent,
|
||||
bot_name:str
|
||||
):
|
||||
|
||||
async def target2yiri(event: DingTalkEvent, bot_name: str):
|
||||
message_chain = await DingTalkMessageConverter.target2yiri(event, bot_name)
|
||||
|
||||
|
||||
if event.conversation == 'FriendMessage':
|
||||
|
||||
return platform_events.FriendMessage(
|
||||
sender=platform_entities.Friend(
|
||||
id=event.incoming_message.sender_id,
|
||||
nickname = event.incoming_message.sender_nick,
|
||||
remark=""
|
||||
nickname=event.incoming_message.sender_nick,
|
||||
remark='',
|
||||
),
|
||||
message_chain = message_chain,
|
||||
time = event.incoming_message.create_at,
|
||||
message_chain=message_chain,
|
||||
time=event.incoming_message.create_at,
|
||||
source_platform_object=event,
|
||||
)
|
||||
elif event.conversation == 'GroupMessage':
|
||||
sender = platform_entities.GroupMember(
|
||||
id = event.incoming_message.sender_id,
|
||||
id=event.incoming_message.sender_id,
|
||||
member_name=event.incoming_message.sender_nick,
|
||||
permission= 'MEMBER',
|
||||
group = platform_entities.Group(
|
||||
id = event.incoming_message.conversation_id,
|
||||
name = event.incoming_message.conversation_title,
|
||||
permission=platform_entities.Permission.Member
|
||||
permission='MEMBER',
|
||||
group=platform_entities.Group(
|
||||
id=event.incoming_message.conversation_id,
|
||||
name=event.incoming_message.conversation_title,
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title='',
|
||||
join_timestamp=0,
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0
|
||||
mute_time_remaining=0,
|
||||
)
|
||||
time = event.incoming_message.create_at
|
||||
return platform_events.GroupMessage(
|
||||
sender =sender,
|
||||
message_chain = message_chain,
|
||||
time = time,
|
||||
source_platform_object=event
|
||||
sender=sender,
|
||||
message_chain=message_chain,
|
||||
time=time,
|
||||
source_platform_object=event,
|
||||
)
|
||||
|
||||
|
||||
@@ -112,28 +96,28 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
|
||||
event_converter: DingTalkEventConverter = DingTalkEventConverter()
|
||||
config: dict
|
||||
|
||||
def __init__(self,config:dict,ap:app.Application):
|
||||
def __init__(self, config: dict, ap: app.Application):
|
||||
self.config = config
|
||||
self.ap = ap
|
||||
required_keys = [
|
||||
"client_id",
|
||||
"client_secret",
|
||||
"robot_name",
|
||||
"robot_code",
|
||||
'client_id',
|
||||
'client_secret',
|
||||
'robot_name',
|
||||
'robot_code',
|
||||
]
|
||||
missing_keys = [key for key in required_keys if key not in config]
|
||||
if missing_keys:
|
||||
raise ParamNotEnoughError("钉钉缺少相关配置项,请查看文档或联系管理员")
|
||||
raise Exception('钉钉缺少相关配置项,请查看文档或联系管理员')
|
||||
|
||||
self.bot_account_id = self.config['robot_name']
|
||||
|
||||
self.bot_account_id = self.config["robot_name"]
|
||||
|
||||
self.bot = DingTalkClient(
|
||||
client_id=config["client_id"],
|
||||
client_secret=config["client_secret"],
|
||||
robot_name = config["robot_name"],
|
||||
robot_code=config["robot_code"]
|
||||
client_id=config['client_id'],
|
||||
client_secret=config['client_secret'],
|
||||
robot_name=config['robot_name'],
|
||||
robot_code=config['robot_code'],
|
||||
)
|
||||
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
@@ -146,17 +130,16 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
|
||||
incoming_message = event.incoming_message
|
||||
|
||||
content = await DingTalkMessageConverter.yiri2target(message)
|
||||
await self.bot.send_message(content,incoming_message)
|
||||
|
||||
await self.bot.send_message(content, incoming_message)
|
||||
|
||||
async def send_message(
|
||||
self, target_type: str, target_id: str, message: platform_message.MessageChain
|
||||
):
|
||||
content = await DingTalkMessageConverter.yiri2target(message)
|
||||
if target_type == 'person':
|
||||
await self.bot.send_proactive_message_to_one(target_id,content)
|
||||
await self.bot.send_proactive_message_to_one(target_id, content)
|
||||
if target_type == 'group':
|
||||
await self.bot.send_proactive_message_to_group(target_id,content)
|
||||
await self.bot.send_proactive_message_to_group(target_id, content)
|
||||
|
||||
def register_listener(
|
||||
self,
|
||||
@@ -168,15 +151,18 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
|
||||
async def on_message(event: DingTalkEvent):
|
||||
try:
|
||||
return await callback(
|
||||
await self.event_converter.target2yiri(event, self.config["robot_name"]), self
|
||||
await self.event_converter.target2yiri(
|
||||
event, self.config['robot_name']
|
||||
),
|
||||
self,
|
||||
)
|
||||
except:
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
if event_type == platform_events.FriendMessage:
|
||||
self.bot.on_message("FriendMessage")(on_message)
|
||||
self.bot.on_message('FriendMessage')(on_message)
|
||||
elif event_type == platform_events.GroupMessage:
|
||||
self.bot.on_message("GroupMessage")(on_message)
|
||||
self.bot.on_message('GroupMessage')(on_message)
|
||||
|
||||
async def run_async(self):
|
||||
await self.bot.start()
|
||||
@@ -187,7 +173,8 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
|
||||
async def unregister_listener(
|
||||
self,
|
||||
event_type: type,
|
||||
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, MessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
return super().unregister_listener(event_type, callback)
|
||||
|
||||
|
||||
@@ -3,39 +3,32 @@ from __future__ import annotations
|
||||
import discord
|
||||
|
||||
import typing
|
||||
import asyncio
|
||||
import traceback
|
||||
import time
|
||||
import re
|
||||
import base64
|
||||
import uuid
|
||||
import json
|
||||
import os
|
||||
import datetime
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .. import adapter
|
||||
from ...pipeline.longtext.strategies import forward
|
||||
from ...core import app
|
||||
from ..types import message as platform_message
|
||||
from ..types import events as platform_events
|
||||
from ..types import entities as platform_entities
|
||||
from ...utils import image
|
||||
|
||||
|
||||
class DiscordMessageConverter(adapter.MessageConverter):
|
||||
|
||||
@staticmethod
|
||||
async def yiri2target(
|
||||
message_chain: platform_message.MessageChain
|
||||
message_chain: platform_message.MessageChain,
|
||||
) -> typing.Tuple[str, typing.List[discord.File]]:
|
||||
for ele in message_chain:
|
||||
if isinstance(ele, platform_message.At):
|
||||
message_chain.remove(ele)
|
||||
break
|
||||
|
||||
text_string = ""
|
||||
text_string = ''
|
||||
image_files = []
|
||||
|
||||
for ele in message_chain:
|
||||
@@ -49,46 +42,49 @@ class DiscordMessageConverter(adapter.MessageConverter):
|
||||
async with session.get(ele.url) as response:
|
||||
image_bytes = await response.read()
|
||||
elif ele.path:
|
||||
with open(ele.path, "rb") as f:
|
||||
with open(ele.path, 'rb') as f:
|
||||
image_bytes = f.read()
|
||||
|
||||
image_files.append(discord.File(fp=image_bytes, filename=f"{uuid.uuid4()}.png"))
|
||||
image_files.append(
|
||||
discord.File(fp=image_bytes, filename=f'{uuid.uuid4()}.png')
|
||||
)
|
||||
elif isinstance(ele, platform_message.Plain):
|
||||
text_string += ele.text
|
||||
elif isinstance(ele, platform_message.Forward):
|
||||
for node in ele.node_list:
|
||||
text_string, image_files = await DiscordMessageConverter.yiri2target(node.message_chain)
|
||||
(
|
||||
text_string,
|
||||
image_files,
|
||||
) = await DiscordMessageConverter.yiri2target(node.message_chain)
|
||||
text_string += text_string
|
||||
image_files.extend(image_files)
|
||||
|
||||
return text_string, image_files
|
||||
|
||||
@staticmethod
|
||||
async def target2yiri(
|
||||
message: discord.Message
|
||||
) -> platform_message.MessageChain:
|
||||
async def target2yiri(message: discord.Message) -> platform_message.MessageChain:
|
||||
lb_msg_list = []
|
||||
|
||||
msg_create_time = datetime.datetime.fromtimestamp(
|
||||
int(message.created_at.timestamp())
|
||||
)
|
||||
|
||||
lb_msg_list.append(
|
||||
platform_message.Source(id=message.id, time=msg_create_time)
|
||||
)
|
||||
lb_msg_list.append(platform_message.Source(id=message.id, time=msg_create_time))
|
||||
|
||||
element_list = []
|
||||
|
||||
def text_element_recur(text_ele: str) -> list[platform_message.MessageComponent]:
|
||||
if text_ele == "":
|
||||
def text_element_recur(
|
||||
text_ele: str,
|
||||
) -> list[platform_message.MessageComponent]:
|
||||
if text_ele == '':
|
||||
return []
|
||||
|
||||
# <@1234567890>
|
||||
# @everyone
|
||||
# @here
|
||||
at_pattern = re.compile(r"(@everyone|@here|<@[\d]+>)")
|
||||
at_pattern = re.compile(r'(@everyone|@here|<@[\d]+>)')
|
||||
at_matches = at_pattern.findall(text_ele)
|
||||
|
||||
|
||||
if len(at_matches) > 0:
|
||||
mid_at = at_matches[0]
|
||||
|
||||
@@ -96,18 +92,19 @@ class DiscordMessageConverter(adapter.MessageConverter):
|
||||
|
||||
mid_at_component = []
|
||||
|
||||
if mid_at == "@everyone" or mid_at == "@here":
|
||||
if mid_at == '@everyone' or mid_at == '@here':
|
||||
mid_at_component.append(platform_message.AtAll())
|
||||
else:
|
||||
mid_at_component.append(platform_message.At(target=mid_at[2:-1]))
|
||||
|
||||
return text_element_recur(text_split[0]) + \
|
||||
mid_at_component + \
|
||||
text_element_recur(text_split[1])
|
||||
return (
|
||||
text_element_recur(text_split[0])
|
||||
+ mid_at_component
|
||||
+ text_element_recur(text_split[1])
|
||||
)
|
||||
else:
|
||||
return [platform_message.Plain(text=text_ele)]
|
||||
|
||||
|
||||
element_list.extend(text_element_recur(message.content))
|
||||
|
||||
# attachments
|
||||
@@ -115,28 +112,27 @@ class DiscordMessageConverter(adapter.MessageConverter):
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(attachment.url) as response:
|
||||
image_data = await response.read()
|
||||
image_base64 = base64.b64encode(image_data).decode("utf-8")
|
||||
image_format = response.headers["Content-Type"]
|
||||
element_list.append(platform_message.Image(base64=f"data:{image_format};base64,{image_base64}"))
|
||||
image_base64 = base64.b64encode(image_data).decode('utf-8')
|
||||
image_format = response.headers['Content-Type']
|
||||
element_list.append(
|
||||
platform_message.Image(
|
||||
base64=f'data:{image_format};base64,{image_base64}'
|
||||
)
|
||||
)
|
||||
|
||||
return platform_message.MessageChain(element_list)
|
||||
|
||||
|
||||
class DiscordEventConverter(adapter.EventConverter):
|
||||
|
||||
@staticmethod
|
||||
async def yiri2target(
|
||||
event: platform_events.Event
|
||||
) -> discord.Message:
|
||||
async def yiri2target(event: platform_events.Event) -> discord.Message:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
async def target2yiri(
|
||||
event: discord.Message
|
||||
) -> platform_events.Event:
|
||||
async def target2yiri(event: discord.Message) -> platform_events.Event:
|
||||
message_chain = await DiscordMessageConverter.target2yiri(event)
|
||||
|
||||
if type(event.channel) == discord.DMChannel:
|
||||
if isinstance(event.channel, discord.DMChannel):
|
||||
return platform_events.FriendMessage(
|
||||
sender=platform_entities.Friend(
|
||||
id=event.author.id,
|
||||
@@ -147,7 +143,7 @@ class DiscordEventConverter(adapter.EventConverter):
|
||||
time=event.created_at.timestamp(),
|
||||
source_platform_object=event,
|
||||
)
|
||||
elif type(event.channel) == discord.TextChannel:
|
||||
elif isinstance(event.channel, discord.TextChannel):
|
||||
return platform_events.GroupMessage(
|
||||
sender=platform_entities.GroupMember(
|
||||
id=event.author.id,
|
||||
@@ -158,7 +154,7 @@ class DiscordEventConverter(adapter.EventConverter):
|
||||
name=event.channel.name,
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title="",
|
||||
special_title='',
|
||||
join_timestamp=0,
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0,
|
||||
@@ -170,7 +166,6 @@ class DiscordEventConverter(adapter.EventConverter):
|
||||
|
||||
|
||||
class DiscordAdapter(adapter.MessagePlatformAdapter):
|
||||
|
||||
bot: discord.Client
|
||||
|
||||
bot_account_id: str # 用于在流水线中识别at是否是本bot,直接以bot_name作为标识
|
||||
@@ -191,12 +186,11 @@ class DiscordAdapter(adapter.MessagePlatformAdapter):
|
||||
self.config = config
|
||||
self.ap = ap
|
||||
|
||||
self.bot_account_id = self.config["client_id"]
|
||||
self.bot_account_id = self.config['client_id']
|
||||
|
||||
adapter_self = self
|
||||
|
||||
class MyClient(discord.Client):
|
||||
|
||||
async def on_message(self: discord.Client, message: discord.Message):
|
||||
if message.author.id == self.user.id or message.author.bot:
|
||||
return
|
||||
@@ -209,11 +203,11 @@ class DiscordAdapter(adapter.MessagePlatformAdapter):
|
||||
|
||||
args = {}
|
||||
|
||||
if os.getenv("http_proxy"):
|
||||
args["proxy"] = os.getenv("http_proxy")
|
||||
if os.getenv('http_proxy'):
|
||||
args['proxy'] = os.getenv('http_proxy')
|
||||
|
||||
self.bot = MyClient(intents=intents, **args)
|
||||
|
||||
|
||||
async def send_message(
|
||||
self, target_type: str, target_id: str, message: platform_message.MessageChain
|
||||
):
|
||||
@@ -229,17 +223,17 @@ class DiscordAdapter(adapter.MessagePlatformAdapter):
|
||||
assert isinstance(message_source.source_platform_object, discord.Message)
|
||||
|
||||
args = {
|
||||
"content": msg_to_send,
|
||||
'content': msg_to_send,
|
||||
}
|
||||
|
||||
if len(image_files) > 0:
|
||||
args["files"] = image_files
|
||||
args['files'] = image_files
|
||||
|
||||
if quote_origin:
|
||||
args["reference"] = message_source.source_platform_object
|
||||
args['reference'] = message_source.source_platform_object
|
||||
|
||||
if message.has(platform_message.At):
|
||||
args["mention_author"] = True
|
||||
args['mention_author'] = True
|
||||
|
||||
await message_source.source_platform_object.channel.send(**args)
|
||||
|
||||
@@ -249,20 +243,24 @@ class DiscordAdapter(adapter.MessagePlatformAdapter):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, adapter.MessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
self.listeners[event_type] = callback
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, adapter.MessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
self.listeners.pop(event_type)
|
||||
|
||||
async def run_async(self):
|
||||
async with self.bot:
|
||||
await self.bot.start(self.config["token"], reconnect=True)
|
||||
await self.bot.start(self.config['token'], reconnect=True)
|
||||
|
||||
async def kill(self) -> bool:
|
||||
await self.bot.close()
|
||||
|
||||
@@ -8,18 +8,13 @@ import traceback
|
||||
import time
|
||||
import re
|
||||
import base64
|
||||
import uuid
|
||||
import json
|
||||
import os
|
||||
import copy
|
||||
import datetime
|
||||
import threading
|
||||
|
||||
import quart
|
||||
import aiohttp
|
||||
|
||||
from .. import adapter
|
||||
from ...pipeline.longtext.strategies import forward
|
||||
from ...core import app
|
||||
from ..types import message as platform_message
|
||||
from ..types import events as platform_events
|
||||
@@ -29,109 +24,123 @@ import xml.etree.ElementTree as ET
|
||||
|
||||
|
||||
class GewechatMessageConverter(adapter.MessageConverter):
|
||||
|
||||
def __init__(self, config: dict):
|
||||
self.config = config
|
||||
|
||||
@staticmethod
|
||||
async def yiri2target(
|
||||
message_chain: platform_message.MessageChain
|
||||
) -> list[dict]:
|
||||
async def yiri2target(message_chain: platform_message.MessageChain) -> list[dict]:
|
||||
content_list = []
|
||||
for component in message_chain:
|
||||
if isinstance(component, platform_message.At):
|
||||
content_list.append({"type": "at", "target": component.target})
|
||||
content_list.append({'type': 'at', 'target': component.target})
|
||||
elif isinstance(component, platform_message.Plain):
|
||||
content_list.append({"type": "text", "content": component.text})
|
||||
content_list.append({'type': 'text', 'content': component.text})
|
||||
elif isinstance(component, platform_message.Image):
|
||||
if not component.url:
|
||||
pass
|
||||
content_list.append({"type": "image", "image": component.url})
|
||||
|
||||
content_list.append({'type': 'image', 'image': component.url})
|
||||
|
||||
elif isinstance(component, platform_message.Voice):
|
||||
content_list.append({"type": "voice", "url": component.url, "length": component.length})
|
||||
content_list.append(
|
||||
{'type': 'voice', 'url': component.url, 'length': component.length}
|
||||
)
|
||||
elif isinstance(component, platform_message.Forward):
|
||||
for node in component.node_list:
|
||||
content_list.extend(await GewechatMessageConverter.yiri2target(node.message_chain))
|
||||
content_list.extend(
|
||||
await GewechatMessageConverter.yiri2target(node.message_chain)
|
||||
)
|
||||
|
||||
return content_list
|
||||
|
||||
async def target2yiri(
|
||||
self,
|
||||
message: dict,
|
||||
bot_account_id: str
|
||||
self, message: dict, bot_account_id: str
|
||||
) -> platform_message.MessageChain:
|
||||
|
||||
|
||||
|
||||
if message["Data"]["MsgType"] == 1:
|
||||
if message['Data']['MsgType'] == 1:
|
||||
# 检查消息开头,如果有 wxid_sbitaz0mt65n22:\n 则删掉
|
||||
regex = re.compile(r"^wxid_.*:")
|
||||
regex = re.compile(r'^wxid_.*:')
|
||||
# print(message)
|
||||
|
||||
line_split = message["Data"]["Content"]["string"].split("\n")
|
||||
line_split = message['Data']['Content']['string'].split('\n')
|
||||
|
||||
if len(line_split) > 0 and regex.match(line_split[0]):
|
||||
message["Data"]["Content"]["string"] = "\n".join(line_split[1:])
|
||||
|
||||
message['Data']['Content']['string'] = '\n'.join(line_split[1:])
|
||||
|
||||
# 正则表达式模式,匹配'@'后跟任意数量的非空白字符
|
||||
pattern = r'@\S+'
|
||||
at_string = f"@{bot_account_id}"
|
||||
at_string = f'@{bot_account_id}'
|
||||
content_list = []
|
||||
if at_string in message["Data"]["Content"]["string"]:
|
||||
if at_string in message['Data']['Content']['string']:
|
||||
content_list.append(platform_message.At(target=bot_account_id))
|
||||
content_list.append(platform_message.Plain(message["Data"]["Content"]["string"].replace(at_string, '', 1)))
|
||||
content_list.append(
|
||||
platform_message.Plain(
|
||||
message['Data']['Content']['string'].replace(at_string, '', 1)
|
||||
)
|
||||
)
|
||||
# 更优雅的替换改名后@机器人,仅仅限于单独AT的情况
|
||||
elif "PushContent" in message['Data'] and '在群聊中@了你' in message["Data"]["PushContent"]:
|
||||
if '@所有人' in message["Data"]["Content"]["string"]: # at全员时候传入atll不当作at自己
|
||||
elif (
|
||||
'PushContent' in message['Data']
|
||||
and '在群聊中@了你' in message['Data']['PushContent']
|
||||
):
|
||||
if (
|
||||
'@所有人' in message['Data']['Content']['string']
|
||||
): # at全员时候传入atll不当作at自己
|
||||
content_list.append(platform_message.AtAll())
|
||||
else:
|
||||
content_list.append(platform_message.At(target=bot_account_id))
|
||||
content_list.append(platform_message.Plain(re.sub(pattern, '', message["Data"]["Content"]["string"])))
|
||||
content_list.append(
|
||||
platform_message.Plain(
|
||||
re.sub(pattern, '', message['Data']['Content']['string'])
|
||||
)
|
||||
)
|
||||
else:
|
||||
content_list = [platform_message.Plain(message["Data"]["Content"]["string"])]
|
||||
content_list = [
|
||||
platform_message.Plain(message['Data']['Content']['string'])
|
||||
]
|
||||
|
||||
return platform_message.MessageChain(content_list)
|
||||
|
||||
elif message["Data"]["MsgType"] == 3:
|
||||
image_xml = message["Data"]["Content"]["string"]
|
||||
if not image_xml:
|
||||
return platform_message.MessageChain([
|
||||
platform_message.Plain(text="[图片内容为空]")
|
||||
])
|
||||
|
||||
elif message['Data']['MsgType'] == 3:
|
||||
image_xml = message['Data']['Content']['string']
|
||||
if not image_xml:
|
||||
return platform_message.MessageChain(
|
||||
[platform_message.Plain(text='[图片内容为空]')]
|
||||
)
|
||||
|
||||
try:
|
||||
base64_str, image_format = await image.get_gewechat_image_base64(
|
||||
gewechat_url=self.config["gewechat_url"],
|
||||
gewechat_file_url=self.config["gewechat_file_url"],
|
||||
app_id=self.config["app_id"],
|
||||
gewechat_url=self.config['gewechat_url'],
|
||||
gewechat_file_url=self.config['gewechat_file_url'],
|
||||
app_id=self.config['app_id'],
|
||||
xml_content=image_xml,
|
||||
token=self.config["token"],
|
||||
token=self.config['token'],
|
||||
image_type=2,
|
||||
)
|
||||
|
||||
return platform_message.MessageChain([
|
||||
platform_message.Image(
|
||||
base64=f"data:image/{image_format};base64,{base64_str}"
|
||||
)
|
||||
])
|
||||
return platform_message.MessageChain(
|
||||
[
|
||||
platform_message.Image(
|
||||
base64=f'data:image/{image_format};base64,{base64_str}'
|
||||
)
|
||||
]
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"处理图片消息失败: {str(e)}")
|
||||
return platform_message.MessageChain([
|
||||
platform_message.Plain(text=f"[图片处理失败]")
|
||||
])
|
||||
elif message["Data"]["MsgType"] == 34:
|
||||
audio_base64 = message["Data"]["ImgBuf"]["buffer"]
|
||||
print(f'处理图片消息失败: {str(e)}')
|
||||
return platform_message.MessageChain(
|
||||
[platform_message.Plain(text='[图片处理失败]')]
|
||||
)
|
||||
elif message['Data']['MsgType'] == 34:
|
||||
audio_base64 = message['Data']['ImgBuf']['buffer']
|
||||
return platform_message.MessageChain(
|
||||
[platform_message.Voice(base64=f"data:audio/silk;base64,{audio_base64}")]
|
||||
[
|
||||
platform_message.Voice(
|
||||
base64=f'data:audio/silk;base64,{audio_base64}'
|
||||
)
|
||||
]
|
||||
)
|
||||
elif message["Data"]["MsgType"] == 49:
|
||||
elif message['Data']['MsgType'] == 49:
|
||||
# 支持微信聊天记录的消息类型,将 XML 内容转换为 MessageChain 传递
|
||||
try:
|
||||
content = message["Data"]["Content"]["string"]
|
||||
content = message['Data']['Content']['string']
|
||||
# 有三种可能的消息结构weid开头,私聊直接<?xml>和直接<msg>
|
||||
if content.startswith('wxid'):
|
||||
xml_list = content.split('\n')[2:]
|
||||
@@ -145,140 +154,145 @@ class GewechatMessageConverter(adapter.MessageConverter):
|
||||
content_data = ET.fromstring(xml_data)
|
||||
# print(xml_data)
|
||||
# 拿到细分消息类型,按照gewe接口中描述
|
||||
'''
|
||||
"""
|
||||
小程序:33/36
|
||||
引用消息:57
|
||||
转账消息:2000
|
||||
红包消息:2001
|
||||
视频号消息:51
|
||||
'''
|
||||
"""
|
||||
appmsg_data = content_data.find('.//appmsg')
|
||||
data_type = appmsg_data.find('.//type').text
|
||||
if data_type == '57':
|
||||
user_data = appmsg_data.find('.//title').text # 拿到用户消息
|
||||
quote_data = appmsg_data.find('.//refermsg').find('.//content').text # 引用原文
|
||||
sender_id = appmsg_data.find('.//refermsg').find('.//chatusr').text # 引用用户id
|
||||
quote_data = (
|
||||
appmsg_data.find('.//refermsg').find('.//content').text
|
||||
) # 引用原文
|
||||
sender_id = (
|
||||
appmsg_data.find('.//refermsg').find('.//chatusr').text
|
||||
) # 引用用户id
|
||||
from_name = message['Data']['FromUserName']['string']
|
||||
message_list =[]
|
||||
if message['Wxid'] == sender_id and from_name.endswith('@chatroom'): # 因为引用机制暂时无法响应用户,所以当引用用户是机器人是构建一个at激活机器人
|
||||
message_list = []
|
||||
if (
|
||||
message['Wxid'] == sender_id and from_name.endswith('@chatroom')
|
||||
): # 因为引用机制暂时无法响应用户,所以当引用用户是机器人是构建一个at激活机器人
|
||||
message_list.append(platform_message.At(target=bot_account_id))
|
||||
message_list.append(platform_message.Quote(
|
||||
message_list.append(
|
||||
platform_message.Quote(
|
||||
sender_id=sender_id,
|
||||
origin=platform_message.MessageChain(
|
||||
[platform_message.Plain(quote_data)]
|
||||
)))
|
||||
),
|
||||
)
|
||||
)
|
||||
message_list.append(platform_message.Plain(user_data))
|
||||
return platform_message.MessageChain(message_list)
|
||||
elif data_type == '51':
|
||||
return platform_message.MessageChain(
|
||||
[platform_message.Plain(text=f'[视频号消息]')]
|
||||
[platform_message.Plain(text='[视频号消息]')]
|
||||
)
|
||||
# print(content_data)
|
||||
elif data_type == '2000':
|
||||
return platform_message.MessageChain(
|
||||
[platform_message.Plain(text=f'[转账消息]')]
|
||||
[platform_message.Plain(text='[转账消息]')]
|
||||
)
|
||||
elif data_type == '2001':
|
||||
return platform_message.MessageChain(
|
||||
[platform_message.Plain(text=f'[红包消息]')]
|
||||
[platform_message.Plain(text='[红包消息]')]
|
||||
)
|
||||
elif data_type == '5':
|
||||
return platform_message.MessageChain(
|
||||
[platform_message.Plain(text=f'[公众号消息]')]
|
||||
[platform_message.Plain(text='[公众号消息]')]
|
||||
)
|
||||
elif data_type == '33' or data_type == '36':
|
||||
return platform_message.MessageChain(
|
||||
[platform_message.Plain(text=f'[小程序消息]')]
|
||||
[platform_message.Plain(text='[小程序消息]')]
|
||||
)
|
||||
# print(data_type.text)
|
||||
else:
|
||||
|
||||
|
||||
try:
|
||||
content_bytes = content.encode('utf-8')
|
||||
decoded_content = base64.b64decode(content_bytes)
|
||||
return platform_message.MessageChain(
|
||||
[platform_message.Unknown(content=decoded_content)]
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return platform_message.MessageChain(
|
||||
[platform_message.Plain(text=content)]
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error processing type 49 message: {str(e)}")
|
||||
print(f'Error processing type 49 message: {str(e)}')
|
||||
return platform_message.MessageChain(
|
||||
[platform_message.Plain(text="[无法解析的消息]")]
|
||||
[platform_message.Plain(text='[无法解析的消息]')]
|
||||
)
|
||||
|
||||
class GewechatEventConverter(adapter.EventConverter):
|
||||
|
||||
class GewechatEventConverter(adapter.EventConverter):
|
||||
def __init__(self, config: dict):
|
||||
self.config = config
|
||||
self.message_converter = GewechatMessageConverter(config)
|
||||
|
||||
@staticmethod
|
||||
async def yiri2target(
|
||||
event: platform_events.MessageEvent
|
||||
) -> dict:
|
||||
async def yiri2target(event: platform_events.MessageEvent) -> dict:
|
||||
pass
|
||||
|
||||
async def target2yiri(
|
||||
self,
|
||||
event: dict,
|
||||
bot_account_id: str
|
||||
self, event: dict, bot_account_id: str
|
||||
) -> platform_events.MessageEvent:
|
||||
# print(event)
|
||||
# 排除自己发消息回调回答问题
|
||||
if event['Wxid'] == event['Data']['FromUserName']['string']:
|
||||
return None
|
||||
# 排除公众号以及微信团队消息
|
||||
if event['Data']['FromUserName']['string'].startswith('gh_')\
|
||||
or event['Data']['FromUserName']['string'].startswith('weixin'):
|
||||
if event['Data']['FromUserName']['string'].startswith('gh_') or event['Data'][
|
||||
'FromUserName'
|
||||
]['string'].startswith('weixin'):
|
||||
return None
|
||||
message_chain = await self.message_converter.target2yiri(copy.deepcopy(event), bot_account_id)
|
||||
message_chain = await self.message_converter.target2yiri(
|
||||
copy.deepcopy(event), bot_account_id
|
||||
)
|
||||
|
||||
if not message_chain:
|
||||
return None
|
||||
|
||||
if '@chatroom' in event["Data"]["FromUserName"]["string"]:
|
||||
|
||||
if '@chatroom' in event['Data']['FromUserName']['string']:
|
||||
# 找出开头的 wxid_ 字符串,以:结尾
|
||||
sender_wxid = event["Data"]["Content"]["string"].split(":")[0]
|
||||
sender_wxid = event['Data']['Content']['string'].split(':')[0]
|
||||
|
||||
return platform_events.GroupMessage(
|
||||
sender=platform_entities.GroupMember(
|
||||
id=sender_wxid,
|
||||
member_name=event["Data"]["FromUserName"]["string"],
|
||||
member_name=event['Data']['FromUserName']['string'],
|
||||
permission=platform_entities.Permission.Member,
|
||||
group=platform_entities.Group(
|
||||
id=event["Data"]["FromUserName"]["string"],
|
||||
name=event["Data"]["FromUserName"]["string"],
|
||||
id=event['Data']['FromUserName']['string'],
|
||||
name=event['Data']['FromUserName']['string'],
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title="",
|
||||
special_title='',
|
||||
join_timestamp=0,
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0,
|
||||
),
|
||||
message_chain=message_chain,
|
||||
time=event["Data"]["CreateTime"],
|
||||
time=event['Data']['CreateTime'],
|
||||
source_platform_object=event,
|
||||
)
|
||||
else:
|
||||
return platform_events.FriendMessage(
|
||||
sender=platform_entities.Friend(
|
||||
id=event["Data"]["FromUserName"]["string"],
|
||||
nickname=event["Data"]["FromUserName"]["string"],
|
||||
id=event['Data']['FromUserName']['string'],
|
||||
nickname=event['Data']['FromUserName']['string'],
|
||||
remark='',
|
||||
),
|
||||
message_chain=message_chain,
|
||||
time=event["Data"]["CreateTime"],
|
||||
time=event['Data']['CreateTime'],
|
||||
source_platform_object=event,
|
||||
)
|
||||
|
||||
|
||||
class GeWeChatAdapter(adapter.MessagePlatformAdapter):
|
||||
|
||||
name: str = "gewechat" # 定义适配器名称
|
||||
name: str = 'gewechat' # 定义适配器名称
|
||||
|
||||
bot: gewechat_client.GewechatClient
|
||||
quart_app: quart.Quart
|
||||
@@ -296,7 +310,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
|
||||
typing.Type[platform_events.Event],
|
||||
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
] = {}
|
||||
|
||||
|
||||
def __init__(self, config: dict, ap: app.Application):
|
||||
self.config = config
|
||||
self.ap = ap
|
||||
@@ -310,21 +324,21 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
|
||||
async def gewechat_callback():
|
||||
data = await quart.request.json
|
||||
# print(json.dumps(data, indent=4, ensure_ascii=False))
|
||||
|
||||
|
||||
if 'data' in data:
|
||||
data['Data'] = data['data']
|
||||
if 'type_name' in data:
|
||||
data['TypeName'] = data['type_name']
|
||||
# print(json.dumps(data, indent=4, ensure_ascii=False))
|
||||
|
||||
|
||||
if 'testMsg' in data:
|
||||
return 'ok'
|
||||
elif 'TypeName' in data and data['TypeName'] == 'AddMsg':
|
||||
try:
|
||||
|
||||
event = await self.event_converter.target2yiri(data.copy(), self.bot_account_id)
|
||||
except Exception as e:
|
||||
event = await self.event_converter.target2yiri(
|
||||
data.copy(), self.bot_account_id
|
||||
)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
if event.__class__ in self.listeners:
|
||||
@@ -333,65 +347,67 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
|
||||
return 'ok'
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
message: platform_message.MessageChain
|
||||
self, target_type: str, target_id: str, message: platform_message.MessageChain
|
||||
):
|
||||
geweap_msg = await self.message_converter.yiri2target(message)
|
||||
# 此处加上群消息at处理
|
||||
ats = [item["target"] for item in geweap_msg if item["type"] == "at"]
|
||||
|
||||
ats = [item['target'] for item in geweap_msg if item['type'] == 'at']
|
||||
|
||||
for msg in geweap_msg:
|
||||
# at主动发送消息
|
||||
if msg['type'] == 'text':
|
||||
if ats:
|
||||
member_info = self.bot.get_chatroom_member_detail(
|
||||
self.config["app_id"],
|
||||
target_id,
|
||||
ats[::-1]
|
||||
)["data"]
|
||||
self.config['app_id'], target_id, ats[::-1]
|
||||
)['data']
|
||||
|
||||
for member in member_info:
|
||||
msg['content'] = f'@{member["nickName"]} {msg["content"]}'
|
||||
self.bot.post_text(app_id=self.config['app_id'], to_wxid=target_id, content=msg['content'],
|
||||
ats=",".join(ats))
|
||||
self.bot.post_text(
|
||||
app_id=self.config['app_id'],
|
||||
to_wxid=target_id,
|
||||
content=msg['content'],
|
||||
ats=','.join(ats),
|
||||
)
|
||||
|
||||
elif msg['type'] == 'image':
|
||||
|
||||
self.bot.post_image(app_id=self.config['app_id'], to_wxid=target_id, img_url=msg["image"])
|
||||
|
||||
|
||||
self.bot.post_image(
|
||||
app_id=self.config['app_id'],
|
||||
to_wxid=target_id,
|
||||
img_url=msg['image'],
|
||||
)
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False
|
||||
quote_origin: bool = False,
|
||||
):
|
||||
content_list = await self.message_converter.yiri2target(message)
|
||||
|
||||
ats = [item["target"] for item in content_list if item["type"] == "at"]
|
||||
ats = [item['target'] for item in content_list if item['type'] == 'at']
|
||||
|
||||
for msg in content_list:
|
||||
if msg["type"] == "text":
|
||||
|
||||
if msg['type'] == 'text':
|
||||
if ats:
|
||||
member_info = self.bot.get_chatroom_member_detail(
|
||||
self.config["app_id"],
|
||||
message_source.source_platform_object["Data"]["FromUserName"]["string"],
|
||||
ats[::-1]
|
||||
)["data"]
|
||||
self.config['app_id'],
|
||||
message_source.source_platform_object['Data']['FromUserName'][
|
||||
'string'
|
||||
],
|
||||
ats[::-1],
|
||||
)['data']
|
||||
|
||||
for member in member_info:
|
||||
msg['content'] = f'@{member["nickName"]} {msg["content"]}'
|
||||
|
||||
self.bot.post_text(
|
||||
app_id=self.config["app_id"],
|
||||
to_wxid=message_source.source_platform_object["Data"]["FromUserName"]["string"],
|
||||
content=msg["content"],
|
||||
ats=",".join(ats)
|
||||
app_id=self.config['app_id'],
|
||||
to_wxid=message_source.source_platform_object['Data'][
|
||||
'FromUserName'
|
||||
]['string'],
|
||||
content=msg['content'],
|
||||
ats=','.join(ats),
|
||||
)
|
||||
|
||||
async def is_muted(self, group_id: int) -> bool:
|
||||
@@ -400,51 +416,57 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None]
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, adapter.MessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
self.listeners[event_type] = callback
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None]
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, adapter.MessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
pass
|
||||
|
||||
async def run_async(self):
|
||||
|
||||
if not self.config["token"]:
|
||||
if not self.config['token']:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.config['gewechat_url']}/v2/api/tools/getTokenId",
|
||||
json={"app_id": self.config["app_id"]}
|
||||
f'{self.config["gewechat_url"]}/v2/api/tools/getTokenId',
|
||||
json={'app_id': self.config['app_id']},
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(f"获取gewechat token失败: {await response.text()}")
|
||||
self.config["token"] = (await response.json())["data"]
|
||||
raise Exception(
|
||||
f'获取gewechat token失败: {await response.text()}'
|
||||
)
|
||||
self.config['token'] = (await response.json())['data']
|
||||
|
||||
self.bot = gewechat_client.GewechatClient(
|
||||
f"{self.config['gewechat_url']}/v2/api",
|
||||
self.config["token"]
|
||||
f'{self.config["gewechat_url"]}/v2/api', self.config['token']
|
||||
)
|
||||
|
||||
app_id, error_msg = self.bot.login(self.config["app_id"])
|
||||
app_id, error_msg = self.bot.login(self.config['app_id'])
|
||||
if error_msg:
|
||||
raise Exception(f"Gewechat 登录失败: {error_msg}")
|
||||
raise Exception(f'Gewechat 登录失败: {error_msg}')
|
||||
|
||||
self.config["app_id"] = app_id
|
||||
self.config['app_id'] = app_id
|
||||
|
||||
self.ap.logger.info(f"Gewechat 登录成功,app_id: {app_id}")
|
||||
self.ap.logger.info(f'Gewechat 登录成功,app_id: {app_id}')
|
||||
|
||||
await self.ap.platform_mgr.write_back_config('gewechat', self, self.config)
|
||||
|
||||
# 获取 nickname
|
||||
profile = self.bot.get_profile(self.config["app_id"])
|
||||
self.bot_account_id = profile["data"]["nickName"]
|
||||
profile = self.bot.get_profile(self.config['app_id'])
|
||||
self.bot_account_id = profile['data']['nickName']
|
||||
|
||||
def thread_set_callback():
|
||||
time.sleep(3)
|
||||
ret = self.bot.set_callback(self.config["token"], self.config["callback_url"])
|
||||
ret = self.bot.set_callback(
|
||||
self.config['token'], self.config['callback_url']
|
||||
)
|
||||
print('设置 Gewechat 回调:', ret)
|
||||
|
||||
threading.Thread(target=thread_set_callback).start()
|
||||
@@ -455,7 +477,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
|
||||
|
||||
await self.quart_app.run_task(
|
||||
host='0.0.0.0',
|
||||
port=self.config["port"],
|
||||
port=self.config['port'],
|
||||
shutdown_trigger=shutdown_trigger_placeholder,
|
||||
)
|
||||
|
||||
|
||||
@@ -5,56 +5,53 @@ import lark_oapi
|
||||
import typing
|
||||
import asyncio
|
||||
import traceback
|
||||
import time
|
||||
import re
|
||||
import base64
|
||||
import uuid
|
||||
import json
|
||||
import datetime
|
||||
import hashlib
|
||||
import base64
|
||||
from Crypto.Cipher import AES
|
||||
|
||||
import aiohttp
|
||||
import lark_oapi.ws.exception
|
||||
import quart
|
||||
from flask import jsonify
|
||||
from lark_oapi.api.im.v1 import *
|
||||
from lark_oapi.api.verification.v1 import GetVerificationRequest
|
||||
|
||||
from .. import adapter
|
||||
from ...pipeline.longtext.strategies import forward
|
||||
from ...core import app
|
||||
from ..types import message as platform_message
|
||||
from ..types import events as platform_events
|
||||
from ..types import entities as platform_entities
|
||||
from ...utils import image
|
||||
|
||||
|
||||
class AESCipher(object):
|
||||
class AESCipher(object):
|
||||
def __init__(self, key):
|
||||
self.bs = AES.block_size
|
||||
self.key=hashlib.sha256(AESCipher.str_to_bytes(key)).digest()
|
||||
self.key = hashlib.sha256(AESCipher.str_to_bytes(key)).digest()
|
||||
|
||||
@staticmethod
|
||||
def str_to_bytes(data):
|
||||
u_type = type(b"".decode('utf8'))
|
||||
u_type = type(b''.decode('utf8'))
|
||||
if isinstance(data, u_type):
|
||||
return data.encode('utf8')
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def _unpad(s):
|
||||
return s[:-ord(s[len(s) - 1:])]
|
||||
return s[: -ord(s[len(s) - 1 :])]
|
||||
|
||||
def decrypt(self, enc):
|
||||
iv = enc[:AES.block_size]
|
||||
iv = enc[: AES.block_size]
|
||||
cipher = AES.new(self.key, AES.MODE_CBC, iv)
|
||||
return self._unpad(cipher.decrypt(enc[AES.block_size:]))
|
||||
return self._unpad(cipher.decrypt(enc[AES.block_size :]))
|
||||
|
||||
def decrypt_string(self, enc):
|
||||
enc = base64.b64decode(enc)
|
||||
return self.decrypt(enc).decode('utf8')
|
||||
return self.decrypt(enc).decode('utf8')
|
||||
|
||||
|
||||
class LarkMessageConverter(adapter.MessageConverter):
|
||||
|
||||
@staticmethod
|
||||
async def yiri2target(
|
||||
message_chain: platform_message.MessageChain, api_client: lark_oapi.Client
|
||||
@@ -65,15 +62,14 @@ class LarkMessageConverter(adapter.MessageConverter):
|
||||
|
||||
for msg in message_chain:
|
||||
if isinstance(msg, platform_message.Plain):
|
||||
pending_paragraph.append({"tag": "md", "text": msg.text})
|
||||
pending_paragraph.append({'tag': 'md', 'text': msg.text})
|
||||
elif isinstance(msg, platform_message.At):
|
||||
pending_paragraph.append(
|
||||
{"tag": "at", "user_id": msg.target, "style": []}
|
||||
{'tag': 'at', 'user_id': msg.target, 'style': []}
|
||||
)
|
||||
elif isinstance(msg, platform_message.AtAll):
|
||||
pending_paragraph.append({"tag": "at", "user_id": "all", "style": []})
|
||||
pending_paragraph.append({'tag': 'at', 'user_id': 'all', 'style': []})
|
||||
elif isinstance(msg, platform_message.Image):
|
||||
|
||||
image_bytes = None
|
||||
|
||||
if msg.base64:
|
||||
@@ -83,14 +79,14 @@ class LarkMessageConverter(adapter.MessageConverter):
|
||||
async with session.get(msg.url) as response:
|
||||
image_bytes = await response.read()
|
||||
elif msg.path:
|
||||
with open(msg.path, "rb") as f:
|
||||
with open(msg.path, 'rb') as f:
|
||||
image_bytes = f.read()
|
||||
|
||||
request: CreateImageRequest = (
|
||||
CreateImageRequest.builder()
|
||||
.request_body(
|
||||
CreateImageRequestBody.builder()
|
||||
.image_type("message")
|
||||
.image_type('message')
|
||||
.image(image_bytes)
|
||||
.build()
|
||||
)
|
||||
@@ -103,7 +99,7 @@ class LarkMessageConverter(adapter.MessageConverter):
|
||||
|
||||
if not response.success():
|
||||
raise Exception(
|
||||
f"client.im.v1.image.create failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}"
|
||||
f'client.im.v1.image.create failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}'
|
||||
)
|
||||
|
||||
image_key = response.data.image_key
|
||||
@@ -112,15 +108,19 @@ class LarkMessageConverter(adapter.MessageConverter):
|
||||
message_elements.append(
|
||||
[
|
||||
{
|
||||
"tag": "img",
|
||||
"image_key": image_key,
|
||||
'tag': 'img',
|
||||
'image_key': image_key,
|
||||
}
|
||||
]
|
||||
)
|
||||
pending_paragraph = []
|
||||
elif isinstance(msg, platform_message.Forward):
|
||||
for node in msg.node_list:
|
||||
message_elements.extend(await LarkMessageConverter.yiri2target(node.message_chain, api_client))
|
||||
message_elements.extend(
|
||||
await LarkMessageConverter.yiri2target(
|
||||
node.message_chain, api_client
|
||||
)
|
||||
)
|
||||
|
||||
if pending_paragraph:
|
||||
message_elements.append(pending_paragraph)
|
||||
@@ -144,15 +144,15 @@ class LarkMessageConverter(adapter.MessageConverter):
|
||||
platform_message.Source(id=message.message_id, time=msg_create_time)
|
||||
)
|
||||
|
||||
if message.message_type == "text":
|
||||
if message.message_type == 'text':
|
||||
element_list = []
|
||||
|
||||
def text_element_recur(text_ele: dict) -> list[dict]:
|
||||
if text_ele["text"] == "":
|
||||
if text_ele['text'] == '':
|
||||
return []
|
||||
|
||||
at_pattern = re.compile(r"@_user_[\d]+")
|
||||
at_matches = at_pattern.findall(text_ele["text"])
|
||||
at_pattern = re.compile(r'@_user_[\d]+')
|
||||
at_matches = at_pattern.findall(text_ele['text'])
|
||||
|
||||
name_mapping = {}
|
||||
for mathc in at_matches:
|
||||
@@ -165,7 +165,7 @@ class LarkMessageConverter(adapter.MessageConverter):
|
||||
return [text_ele]
|
||||
|
||||
# 只处理第一个,剩下的递归处理
|
||||
text_split = text_ele["text"].split(list(name_mapping.keys())[0])
|
||||
text_split = text_ele['text'].split(list(name_mapping.keys())[0])
|
||||
|
||||
new_list = []
|
||||
|
||||
@@ -173,58 +173,58 @@ class LarkMessageConverter(adapter.MessageConverter):
|
||||
right_text = text_split[1]
|
||||
|
||||
new_list.extend(
|
||||
text_element_recur({"tag": "text", "text": left_text, "style": []})
|
||||
text_element_recur({'tag': 'text', 'text': left_text, 'style': []})
|
||||
)
|
||||
|
||||
new_list.append(
|
||||
{
|
||||
"tag": "at",
|
||||
"user_id": list(name_mapping.keys())[0],
|
||||
"user_name": name_mapping[list(name_mapping.keys())[0]],
|
||||
"style": [],
|
||||
'tag': 'at',
|
||||
'user_id': list(name_mapping.keys())[0],
|
||||
'user_name': name_mapping[list(name_mapping.keys())[0]],
|
||||
'style': [],
|
||||
}
|
||||
)
|
||||
|
||||
new_list.extend(
|
||||
text_element_recur({"tag": "text", "text": right_text, "style": []})
|
||||
text_element_recur({'tag': 'text', 'text': right_text, 'style': []})
|
||||
)
|
||||
|
||||
return new_list
|
||||
|
||||
element_list = text_element_recur(
|
||||
{"tag": "text", "text": message_content["text"], "style": []}
|
||||
{'tag': 'text', 'text': message_content['text'], 'style': []}
|
||||
)
|
||||
|
||||
message_content = {"title": "", "content": element_list}
|
||||
message_content = {'title': '', 'content': element_list}
|
||||
|
||||
elif message.message_type == "post":
|
||||
elif message.message_type == 'post':
|
||||
new_list = []
|
||||
|
||||
for ele in message_content["content"]:
|
||||
for ele in message_content['content']:
|
||||
if type(ele) is dict:
|
||||
new_list.append(ele)
|
||||
elif type(ele) is list:
|
||||
new_list.extend(ele)
|
||||
|
||||
message_content["content"] = new_list
|
||||
elif message.message_type == "image":
|
||||
message_content["content"] = [
|
||||
{"tag": "img", "image_key": message_content["image_key"], "style": []}
|
||||
message_content['content'] = new_list
|
||||
elif message.message_type == 'image':
|
||||
message_content['content'] = [
|
||||
{'tag': 'img', 'image_key': message_content['image_key'], 'style': []}
|
||||
]
|
||||
|
||||
for ele in message_content["content"]:
|
||||
if ele["tag"] == "text":
|
||||
lb_msg_list.append(platform_message.Plain(text=ele["text"]))
|
||||
elif ele["tag"] == "at":
|
||||
lb_msg_list.append(platform_message.At(target=ele["user_name"]))
|
||||
elif ele["tag"] == "img":
|
||||
image_key = ele["image_key"]
|
||||
for ele in message_content['content']:
|
||||
if ele['tag'] == 'text':
|
||||
lb_msg_list.append(platform_message.Plain(text=ele['text']))
|
||||
elif ele['tag'] == 'at':
|
||||
lb_msg_list.append(platform_message.At(target=ele['user_name']))
|
||||
elif ele['tag'] == 'img':
|
||||
image_key = ele['image_key']
|
||||
|
||||
request: GetMessageResourceRequest = (
|
||||
GetMessageResourceRequest.builder()
|
||||
.message_id(message.message_id)
|
||||
.file_key(image_key)
|
||||
.type("image")
|
||||
.type('image')
|
||||
.build()
|
||||
)
|
||||
|
||||
@@ -234,17 +234,17 @@ class LarkMessageConverter(adapter.MessageConverter):
|
||||
|
||||
if not response.success():
|
||||
raise Exception(
|
||||
f"client.im.v1.message_resource.get failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}"
|
||||
f'client.im.v1.message_resource.get failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}'
|
||||
)
|
||||
|
||||
image_bytes = response.file.read()
|
||||
image_base64 = base64.b64encode(image_bytes).decode()
|
||||
|
||||
image_format = response.raw.headers["content-type"]
|
||||
image_format = response.raw.headers['content-type']
|
||||
|
||||
lb_msg_list.append(
|
||||
platform_message.Image(
|
||||
base64=f"data:{image_format};base64,{image_base64}"
|
||||
base64=f'data:{image_format};base64,{image_base64}'
|
||||
)
|
||||
)
|
||||
|
||||
@@ -252,7 +252,6 @@ class LarkMessageConverter(adapter.MessageConverter):
|
||||
|
||||
|
||||
class LarkEventConverter(adapter.EventConverter):
|
||||
|
||||
@staticmethod
|
||||
async def yiri2target(
|
||||
event: platform_events.MessageEvent,
|
||||
@@ -267,17 +266,17 @@ class LarkEventConverter(adapter.EventConverter):
|
||||
event.event.message, api_client
|
||||
)
|
||||
|
||||
if event.event.message.chat_type == "p2p":
|
||||
if event.event.message.chat_type == 'p2p':
|
||||
return platform_events.FriendMessage(
|
||||
sender=platform_entities.Friend(
|
||||
id=event.event.sender.sender_id.open_id,
|
||||
nickname=event.event.sender.sender_id.union_id,
|
||||
remark="",
|
||||
remark='',
|
||||
),
|
||||
message_chain=message_chain,
|
||||
time=event.event.message.create_time,
|
||||
)
|
||||
elif event.event.message.chat_type == "group":
|
||||
elif event.event.message.chat_type == 'group':
|
||||
return platform_events.GroupMessage(
|
||||
sender=platform_entities.GroupMember(
|
||||
id=event.event.sender.sender_id.open_id,
|
||||
@@ -285,10 +284,10 @@ class LarkEventConverter(adapter.EventConverter):
|
||||
permission=platform_entities.Permission.Member,
|
||||
group=platform_entities.Group(
|
||||
id=event.event.message.chat_id,
|
||||
name="",
|
||||
name='',
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title="",
|
||||
special_title='',
|
||||
join_timestamp=0,
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0,
|
||||
@@ -299,7 +298,6 @@ class LarkEventConverter(adapter.EventConverter):
|
||||
|
||||
|
||||
class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
|
||||
bot: lark_oapi.ws.Client
|
||||
api_client: lark_oapi.Client
|
||||
|
||||
@@ -333,17 +331,15 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
data = cipher.decrypt_string(data['encrypt'])
|
||||
data = json.loads(data)
|
||||
|
||||
type = data.get("type")
|
||||
if type is None :
|
||||
type = data.get('type')
|
||||
if type is None:
|
||||
context = EventContext(data)
|
||||
type = context.header.event_type
|
||||
|
||||
|
||||
if 'url_verification' == type:
|
||||
print(data.get("challenge"))
|
||||
print(data.get('challenge'))
|
||||
# todo 验证verification token
|
||||
return {
|
||||
"challenge": data.get("challenge")
|
||||
}
|
||||
return {'challenge': data.get('challenge')}
|
||||
context = EventContext(data)
|
||||
type = context.header.event_type
|
||||
p2v1 = P2ImMessageReceiveV1()
|
||||
@@ -355,20 +351,21 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
p2v1.schema = context.schema
|
||||
if 'im.message.receive_v1' == type:
|
||||
try:
|
||||
event = await self.event_converter.target2yiri(p2v1, self.api_client)
|
||||
except Exception as e:
|
||||
event = await self.event_converter.target2yiri(
|
||||
p2v1, self.api_client
|
||||
)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
if event.__class__ in self.listeners:
|
||||
await self.listeners[event.__class__](event, self)
|
||||
|
||||
return {"code": 200, "message": "ok"}
|
||||
except Exception as e:
|
||||
return {'code': 200, 'message': 'ok'}
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
return {"code": 500, "message": "error"}
|
||||
return {'code': 500, 'message': 'error'}
|
||||
|
||||
async def on_message(event: lark_oapi.im.v1.P2ImMessageReceiveV1):
|
||||
|
||||
lb_event = await self.event_converter.target2yiri(event, self.api_client)
|
||||
|
||||
await self.listeners[type(lb_event)](lb_event, self)
|
||||
@@ -377,20 +374,20 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
asyncio.create_task(on_message(event))
|
||||
|
||||
event_handler = (
|
||||
lark_oapi.EventDispatcherHandler.builder("", "")
|
||||
lark_oapi.EventDispatcherHandler.builder('', '')
|
||||
.register_p2_im_message_receive_v1(sync_on_message)
|
||||
.build()
|
||||
)
|
||||
|
||||
self.bot_account_id = config["bot_name"]
|
||||
self.bot_account_id = config['bot_name']
|
||||
|
||||
self.bot = lark_oapi.ws.Client(
|
||||
config["app_id"], config["app_secret"], event_handler=event_handler
|
||||
config['app_id'], config['app_secret'], event_handler=event_handler
|
||||
)
|
||||
self.api_client = (
|
||||
lark_oapi.Client.builder()
|
||||
.app_id(config["app_id"])
|
||||
.app_secret(config["app_secret"])
|
||||
.app_id(config['app_id'])
|
||||
.app_secret(config['app_secret'])
|
||||
.build()
|
||||
)
|
||||
|
||||
@@ -405,7 +402,6 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
):
|
||||
|
||||
# 不再需要了,因为message_id已经被包含到message_chain中
|
||||
# lark_event = await self.event_converter.yiri2target(message_source)
|
||||
lark_message = await self.message_converter.yiri2target(
|
||||
@@ -413,9 +409,9 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
)
|
||||
|
||||
final_content = {
|
||||
"zh_cn": {
|
||||
"title": "",
|
||||
"content": lark_message,
|
||||
'zh_cn': {
|
||||
'title': '',
|
||||
'content': lark_message,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -425,7 +421,7 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
.request_body(
|
||||
ReplyMessageRequestBody.builder()
|
||||
.content(json.dumps(final_content))
|
||||
.msg_type("post")
|
||||
.msg_type('post')
|
||||
.reply_in_thread(False)
|
||||
.uuid(str(uuid.uuid4()))
|
||||
.build()
|
||||
@@ -439,7 +435,7 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
|
||||
if not response.success():
|
||||
raise Exception(
|
||||
f"client.im.v1.message.reply failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}"
|
||||
f'client.im.v1.message.reply failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}'
|
||||
)
|
||||
|
||||
async def is_muted(self, group_id: int) -> bool:
|
||||
@@ -479,6 +475,7 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
else:
|
||||
raise e
|
||||
else:
|
||||
|
||||
async def shutdown_trigger_placeholder():
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
@@ -488,5 +485,6 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
port=port,
|
||||
shutdown_trigger=shutdown_trigger_placeholder,
|
||||
)
|
||||
|
||||
async def kill(self) -> bool:
|
||||
return False
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
import asyncio
|
||||
import typing
|
||||
import traceback
|
||||
import logging
|
||||
|
||||
|
||||
import nakuru
|
||||
@@ -19,6 +18,7 @@ from ...platform.types import events as platform_events
|
||||
|
||||
class NakuruProjectMessageConverter(adapter_model.MessageConverter):
|
||||
"""消息转换器"""
|
||||
|
||||
@staticmethod
|
||||
def yiri2target(message_chain: platform_message.MessageChain) -> list:
|
||||
msg_list = []
|
||||
@@ -29,10 +29,12 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
|
||||
elif type(message_chain) is str:
|
||||
msg_list = [platform_message.Plain(message_chain)]
|
||||
else:
|
||||
raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain)))
|
||||
|
||||
raise Exception(
|
||||
'Unknown message type: ' + str(message_chain) + str(type(message_chain))
|
||||
)
|
||||
|
||||
nakuru_msg_list = []
|
||||
|
||||
|
||||
# 遍历并转换
|
||||
for component in msg_list:
|
||||
if type(component) is platform_message.Plain:
|
||||
@@ -61,33 +63,43 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
|
||||
# 遍历并转换
|
||||
for yiri_forward_node in yiri_forward_node_list:
|
||||
try:
|
||||
content_list = NakuruProjectMessageConverter.yiri2target(yiri_forward_node.message_chain)
|
||||
content_list = NakuruProjectMessageConverter.yiri2target(
|
||||
yiri_forward_node.message_chain
|
||||
)
|
||||
nakuru_forward_node = nkc.Node(
|
||||
name=yiri_forward_node.sender_name,
|
||||
uin=yiri_forward_node.sender_id,
|
||||
time=int(yiri_forward_node.time.timestamp()) if yiri_forward_node.time is not None else None,
|
||||
content=content_list
|
||||
time=int(yiri_forward_node.time.timestamp())
|
||||
if yiri_forward_node.time is not None
|
||||
else None,
|
||||
content=content_list,
|
||||
)
|
||||
nakuru_forward_node_list.append(nakuru_forward_node)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
nakuru_msg_list.append(nakuru_forward_node_list)
|
||||
else:
|
||||
nakuru_msg_list.append(nkc.Plain(str(component)))
|
||||
|
||||
|
||||
return nakuru_msg_list
|
||||
|
||||
@staticmethod
|
||||
def target2yiri(message_chain: typing.Any, message_id: int = -1) -> platform_message.MessageChain:
|
||||
def target2yiri(
|
||||
message_chain: typing.Any, message_id: int = -1
|
||||
) -> platform_message.MessageChain:
|
||||
"""将Yiri的消息链转换为YiriMirai的消息链"""
|
||||
assert type(message_chain) is list
|
||||
|
||||
yiri_msg_list = []
|
||||
import datetime
|
||||
|
||||
# 添加Source组件以标记message_id等信息
|
||||
yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now()))
|
||||
yiri_msg_list.append(
|
||||
platform_message.Source(id=message_id, time=datetime.datetime.now())
|
||||
)
|
||||
for component in message_chain:
|
||||
if type(component) is nkc.Plain:
|
||||
yiri_msg_list.append(platform_message.Plain(text=component.text))
|
||||
@@ -106,6 +118,7 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
|
||||
|
||||
class NakuruProjectEventConverter(adapter_model.EventConverter):
|
||||
"""事件转换器"""
|
||||
|
||||
@staticmethod
|
||||
def yiri2target(event: typing.Type[platform_events.Event]):
|
||||
if event is platform_events.GroupMessage:
|
||||
@@ -113,28 +126,30 @@ class NakuruProjectEventConverter(adapter_model.EventConverter):
|
||||
elif event is platform_events.FriendMessage:
|
||||
return nakuru.FriendMessage
|
||||
else:
|
||||
raise Exception("未支持转换的事件类型: " + str(event))
|
||||
raise Exception('未支持转换的事件类型: ' + str(event))
|
||||
|
||||
@staticmethod
|
||||
def target2yiri(event: typing.Any) -> platform_events.Event:
|
||||
yiri_chain = NakuruProjectMessageConverter.target2yiri(event.message, event.message_id)
|
||||
yiri_chain = NakuruProjectMessageConverter.target2yiri(
|
||||
event.message, event.message_id
|
||||
)
|
||||
if type(event) is nakuru.FriendMessage: # 私聊消息事件
|
||||
return platform_events.FriendMessage(
|
||||
sender=platform_entities.Friend(
|
||||
id=event.sender.user_id,
|
||||
nickname=event.sender.nickname,
|
||||
remark=event.sender.nickname
|
||||
remark=event.sender.nickname,
|
||||
),
|
||||
message_chain=yiri_chain,
|
||||
time=event.time
|
||||
time=event.time,
|
||||
)
|
||||
elif type(event) is nakuru.GroupMessage: # 群聊消息事件
|
||||
permission = "MEMBER"
|
||||
permission = 'MEMBER'
|
||||
|
||||
if event.sender.role == "admin":
|
||||
permission = "ADMINISTRATOR"
|
||||
elif event.sender.role == "owner":
|
||||
permission = "OWNER"
|
||||
if event.sender.role == 'admin':
|
||||
permission = 'ADMINISTRATOR'
|
||||
elif event.sender.role == 'owner':
|
||||
permission = 'OWNER'
|
||||
|
||||
return platform_events.GroupMessage(
|
||||
sender=platform_entities.GroupMember(
|
||||
@@ -144,7 +159,7 @@ class NakuruProjectEventConverter(adapter_model.EventConverter):
|
||||
group=platform_entities.Group(
|
||||
id=event.group_id,
|
||||
name=event.sender.nickname,
|
||||
permission=platform_entities.Permission.Member
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title=event.sender.title,
|
||||
join_timestamp=0,
|
||||
@@ -152,14 +167,15 @@ class NakuruProjectEventConverter(adapter_model.EventConverter):
|
||||
mute_time_remaining=0,
|
||||
),
|
||||
message_chain=yiri_chain,
|
||||
time=event.time
|
||||
time=event.time,
|
||||
)
|
||||
else:
|
||||
raise Exception("未支持转换的事件类型: " + str(event))
|
||||
raise Exception('未支持转换的事件类型: ' + str(event))
|
||||
|
||||
|
||||
class NakuruAdapter(adapter_model.MessagePlatformAdapter):
|
||||
"""nakuru-project适配器"""
|
||||
|
||||
bot: nakuru.CQHTTP
|
||||
bot_account_id: int
|
||||
|
||||
@@ -186,12 +202,14 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
message: typing.Union[platform_message.MessageChain, list],
|
||||
converted: bool = False
|
||||
converted: bool = False,
|
||||
):
|
||||
task = None
|
||||
|
||||
converted_msg = self.message_converter.yiri2target(message) if not converted else message
|
||||
|
||||
converted_msg = (
|
||||
self.message_converter.yiri2target(message) if not converted else message
|
||||
)
|
||||
|
||||
# 检查是否有转发消息
|
||||
has_forward = False
|
||||
for msg in converted_msg:
|
||||
@@ -200,19 +218,19 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
|
||||
converted_msg = msg
|
||||
break
|
||||
if has_forward:
|
||||
if target_type == "group":
|
||||
if target_type == 'group':
|
||||
task = self.bot.sendGroupForwardMessage(int(target_id), converted_msg)
|
||||
elif target_type == "person":
|
||||
elif target_type == 'person':
|
||||
task = self.bot.sendPrivateForwardMessage(int(target_id), converted_msg)
|
||||
else:
|
||||
raise Exception("Unknown target type: " + target_type)
|
||||
raise Exception('Unknown target type: ' + target_type)
|
||||
else:
|
||||
if target_type == "group":
|
||||
if target_type == 'group':
|
||||
task = self.bot.sendGroupMessage(int(target_id), converted_msg)
|
||||
elif target_type == "person":
|
||||
elif target_type == 'person':
|
||||
task = self.bot.sendFriendMessage(int(target_id), converted_msg)
|
||||
else:
|
||||
raise Exception("Unknown target type: " + target_type)
|
||||
raise Exception('Unknown target type: ' + target_type)
|
||||
|
||||
await task
|
||||
|
||||
@@ -220,45 +238,45 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False
|
||||
quote_origin: bool = False,
|
||||
):
|
||||
message = self.message_converter.yiri2target(message)
|
||||
if quote_origin:
|
||||
# 在前方添加引用组件
|
||||
message.insert(0, nkc.Reply(
|
||||
message.insert(
|
||||
0,
|
||||
nkc.Reply(
|
||||
id=message_source.message_chain.message_id,
|
||||
)
|
||||
),
|
||||
)
|
||||
if type(message_source) is platform_events.GroupMessage:
|
||||
await self.send_message(
|
||||
"group",
|
||||
message_source.sender.group.id,
|
||||
message,
|
||||
converted=True
|
||||
'group', message_source.sender.group.id, message, converted=True
|
||||
)
|
||||
elif type(message_source) is platform_events.FriendMessage:
|
||||
await self.send_message(
|
||||
"person",
|
||||
message_source.sender.id,
|
||||
message,
|
||||
converted=True
|
||||
'person', message_source.sender.id, message, converted=True
|
||||
)
|
||||
else:
|
||||
raise Exception("Unknown message source type: " + str(type(message_source)))
|
||||
raise Exception('Unknown message source type: ' + str(type(message_source)))
|
||||
|
||||
def is_muted(self, group_id: int) -> bool:
|
||||
import time
|
||||
|
||||
# 检查是否被禁言
|
||||
group_member_info = asyncio.run(self.bot.getGroupMemberInfo(group_id, self.bot_account_id))
|
||||
group_member_info = asyncio.run(
|
||||
self.bot.getGroupMemberInfo(group_id, self.bot_account_id)
|
||||
)
|
||||
return group_member_info.shut_up_timestamp > int(time.time())
|
||||
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None]
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, adapter_model.MessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
try:
|
||||
|
||||
source_cls = NakuruProjectEventConverter.yiri2target(event_type)
|
||||
|
||||
# 包装函数
|
||||
@@ -268,9 +286,9 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
|
||||
# 将包装函数和原函数的对应关系存入列表
|
||||
self.listener_list.append(
|
||||
{
|
||||
"event_type": event_type,
|
||||
"callable": callback,
|
||||
"wrapper": listener_wrapper,
|
||||
'event_type': event_type,
|
||||
'callable': callback,
|
||||
'wrapper': listener_wrapper,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -283,7 +301,9 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None]
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, adapter_model.MessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
nakuru_event_name = self.event_converter.yiri2target(event_type).__name__
|
||||
|
||||
@@ -292,13 +312,16 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
|
||||
# 从本对象的监听器列表中查找并删除
|
||||
target_wrapper = None
|
||||
for listener in self.listener_list:
|
||||
if listener["event_type"] == event_type and listener["callable"] == callback:
|
||||
target_wrapper = listener["wrapper"]
|
||||
if (
|
||||
listener['event_type'] == event_type
|
||||
and listener['callable'] == callback
|
||||
):
|
||||
target_wrapper = listener['wrapper']
|
||||
self.listener_list.remove(listener)
|
||||
break
|
||||
|
||||
if target_wrapper is None:
|
||||
raise Exception("未找到对应的监听器")
|
||||
raise Exception('未找到对应的监听器')
|
||||
|
||||
for func in self.bot.event[nakuru_event_name]:
|
||||
if func.callable != target_wrapper:
|
||||
@@ -309,23 +332,30 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
|
||||
async def run_async(self):
|
||||
try:
|
||||
import requests
|
||||
|
||||
resp = requests.get(
|
||||
url="http://{}:{}/get_login_info".format(self.cfg['host'], self.cfg['http_port']),
|
||||
url='http://{}:{}/get_login_info'.format(
|
||||
self.cfg['host'], self.cfg['http_port']
|
||||
),
|
||||
headers={
|
||||
'Authorization': "Bearer " + self.cfg['token'] if 'token' in self.cfg else ""
|
||||
'Authorization': 'Bearer ' + self.cfg['token']
|
||||
if 'token' in self.cfg
|
||||
else ''
|
||||
},
|
||||
timeout=5,
|
||||
proxies=None
|
||||
proxies=None,
|
||||
)
|
||||
if resp.status_code == 403:
|
||||
raise Exception("go-cqhttp拒绝访问,请检查配置文件中nakuru适配器的配置")
|
||||
raise Exception('go-cqhttp拒绝访问,请检查配置文件中nakuru适配器的配置')
|
||||
self.bot_account_id = int(resp.json()['data']['user_id'])
|
||||
except Exception as e:
|
||||
raise Exception("获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确")
|
||||
except Exception:
|
||||
raise Exception(
|
||||
'获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确'
|
||||
)
|
||||
await self.bot._run()
|
||||
self.ap.logger.info("运行 Nakuru 适配器")
|
||||
self.ap.logger.info('运行 Nakuru 适配器')
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def kill(self) -> bool:
|
||||
return False
|
||||
return False
|
||||
|
||||
@@ -4,20 +4,13 @@ import asyncio
|
||||
import traceback
|
||||
|
||||
import datetime
|
||||
from pkg.core import app
|
||||
from pkg.platform.adapter import MessagePlatformAdapter
|
||||
from pkg.platform.types import events as platform_events, message as platform_message
|
||||
from collections import deque
|
||||
from libs.official_account_api.oaevent import OAEvent
|
||||
from pkg.platform.adapter import MessagePlatformAdapter
|
||||
from pkg.platform.types import events as platform_events, message as platform_message
|
||||
from libs.official_account_api.api import OAClient
|
||||
from libs.official_account_api.api import OAClientForLongerResponse
|
||||
from pkg.core import app
|
||||
from .. import adapter
|
||||
from ...core import app
|
||||
from ..types import message as platform_message
|
||||
from ..types import events as platform_events
|
||||
from ..types import entities as platform_entities
|
||||
from ...command.errors import ParamNotEnoughError
|
||||
|
||||
@@ -28,10 +21,9 @@ class OAMessageConverter(adapter.MessageConverter):
|
||||
for msg in message_chain:
|
||||
if type(msg) is platform_message.Plain:
|
||||
return msg.text
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def target2yiri(message:str,message_id =-1):
|
||||
async def target2yiri(message: str, message_id=-1):
|
||||
yiri_msg_list = []
|
||||
yiri_msg_list.append(
|
||||
platform_message.Source(id=message_id, time=datetime.datetime.now())
|
||||
@@ -41,12 +33,12 @@ class OAMessageConverter(adapter.MessageConverter):
|
||||
chain = platform_message.MessageChain(yiri_msg_list)
|
||||
|
||||
return chain
|
||||
|
||||
|
||||
|
||||
class OAEventConverter(adapter.EventConverter):
|
||||
@staticmethod
|
||||
async def target2yiri(event:OAEvent):
|
||||
if event.type == "text":
|
||||
async def target2yiri(event: OAEvent):
|
||||
if event.type == 'text':
|
||||
yiri_chain = await OAMessageConverter.target2yiri(
|
||||
event.message, event.message_id
|
||||
)
|
||||
@@ -54,91 +46,101 @@ class OAEventConverter(adapter.EventConverter):
|
||||
friend = platform_entities.Friend(
|
||||
id=event.user_id,
|
||||
nickname=str(event.user_id),
|
||||
remark="",
|
||||
remark='',
|
||||
)
|
||||
|
||||
return platform_events.FriendMessage(
|
||||
sender=friend, message_chain=yiri_chain, time=event.timestamp, source_platform_object=event
|
||||
sender=friend,
|
||||
message_chain=yiri_chain,
|
||||
time=event.timestamp,
|
||||
source_platform_object=event,
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
class OfficialAccountAdapter(adapter.MessagePlatformAdapter):
|
||||
|
||||
bot : OAClient | OAClientForLongerResponse
|
||||
ap : app.Application
|
||||
class OfficialAccountAdapter(adapter.MessagePlatformAdapter):
|
||||
bot: OAClient | OAClientForLongerResponse
|
||||
ap: app.Application
|
||||
bot_account_id: str
|
||||
message_converter: OAMessageConverter = OAMessageConverter()
|
||||
event_converter: OAEventConverter = OAEventConverter()
|
||||
config: dict
|
||||
|
||||
|
||||
def __init__(self, config: dict, ap: app.Application):
|
||||
self.config = config
|
||||
|
||||
|
||||
self.ap = ap
|
||||
|
||||
required_keys = [
|
||||
"token",
|
||||
"EncodingAESKey",
|
||||
"AppSecret",
|
||||
"AppID",
|
||||
"Mode",
|
||||
'token',
|
||||
'EncodingAESKey',
|
||||
'AppSecret',
|
||||
'AppID',
|
||||
'Mode',
|
||||
]
|
||||
missing_keys = [key for key in required_keys if key not in config]
|
||||
if missing_keys:
|
||||
raise ParamNotEnoughError("微信公众号缺少相关配置项,请查看文档或联系管理员")
|
||||
|
||||
|
||||
if self.config['Mode'] == "drop":
|
||||
raise ParamNotEnoughError(
|
||||
'微信公众号缺少相关配置项,请查看文档或联系管理员'
|
||||
)
|
||||
|
||||
if self.config['Mode'] == 'drop':
|
||||
self.bot = OAClient(
|
||||
token=config['token'],
|
||||
EncodingAESKey=config['EncodingAESKey'],
|
||||
Appsecret=config['AppSecret'],
|
||||
AppID=config['AppID'],
|
||||
AppID=config['AppID'],
|
||||
)
|
||||
elif self.config['Mode'] == "passive":
|
||||
elif self.config['Mode'] == 'passive':
|
||||
self.bot = OAClientForLongerResponse(
|
||||
token=config['token'],
|
||||
EncodingAESKey=config['EncodingAESKey'],
|
||||
Appsecret=config['AppSecret'],
|
||||
AppID=config['AppID'],
|
||||
LoadingMessage=config['LoadingMessage']
|
||||
AppID=config['AppID'],
|
||||
LoadingMessage=config['LoadingMessage'],
|
||||
)
|
||||
else:
|
||||
raise KeyError("请设置微信公众号通信模式")
|
||||
raise KeyError('请设置微信公众号通信模式')
|
||||
|
||||
|
||||
async def reply_message(self, message_source: platform_events.FriendMessage, message: platform_message.MessageChain, quote_origin: bool = False):
|
||||
|
||||
content = await OAMessageConverter.yiri2target(
|
||||
message
|
||||
)
|
||||
if type(self.bot) == OAClient:
|
||||
await self.bot.set_message(message_source.message_chain.message_id,content)
|
||||
if type(self.bot) == OAClientForLongerResponse:
|
||||
async def reply_message(
|
||||
self,
|
||||
message_source: platform_events.FriendMessage,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
):
|
||||
content = await OAMessageConverter.yiri2target(message)
|
||||
if isinstance(self.bot, OAClient):
|
||||
await self.bot.set_message(message_source.message_chain.message_id, content)
|
||||
elif isinstance(self.bot, OAClientForLongerResponse):
|
||||
from_user = message_source.sender.id
|
||||
await self.bot.set_message(from_user,message_source.message_chain.message_id,content)
|
||||
await self.bot.set_message(
|
||||
from_user, message_source.message_chain.message_id, content
|
||||
)
|
||||
|
||||
|
||||
async def send_message(
|
||||
self, target_type: str, target_id: str, message: platform_message.MessageChain
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
def register_listener(self, event_type: type, callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None]):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: type,
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, MessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
async def on_message(event: OAEvent):
|
||||
self.bot_account_id = event.receiver_id
|
||||
try:
|
||||
return await callback(
|
||||
await self.event_converter.target2yiri(event), self
|
||||
)
|
||||
except:
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
if event_type == platform_events.FriendMessage:
|
||||
self.bot.on_message("text")(on_message)
|
||||
self.bot.on_message('text')(on_message)
|
||||
elif event_type == platform_events.GroupMessage:
|
||||
pass
|
||||
|
||||
@@ -148,8 +150,8 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter):
|
||||
await asyncio.sleep(1)
|
||||
|
||||
await self.bot.run_task(
|
||||
host=self.config["host"],
|
||||
port=self.config["port"],
|
||||
host=self.config['host'],
|
||||
port=self.config['port'],
|
||||
shutdown_trigger=shutdown_trigger_placeholder,
|
||||
)
|
||||
|
||||
@@ -159,8 +161,8 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter):
|
||||
async def unregister_listener(
|
||||
self,
|
||||
event_type: type,
|
||||
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, MessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
return super().unregister_listener(event_type, callback)
|
||||
|
||||
|
||||
@@ -22,12 +22,20 @@ from ...platform.types import message as platform_message
|
||||
class OfficialGroupMessage(platform_events.GroupMessage):
|
||||
pass
|
||||
|
||||
|
||||
class OfficialFriendMessage(platform_events.FriendMessage):
|
||||
pass
|
||||
|
||||
|
||||
event_handler_mapping = {
|
||||
platform_events.GroupMessage: ["on_at_message_create", "on_group_at_message_create"],
|
||||
platform_events.FriendMessage: ["on_direct_message_create", "on_c2c_message_create"],
|
||||
platform_events.GroupMessage: [
|
||||
'on_at_message_create',
|
||||
'on_group_at_message_create',
|
||||
],
|
||||
platform_events.FriendMessage: [
|
||||
'on_direct_message_create',
|
||||
'on_c2c_message_create',
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@@ -53,9 +61,10 @@ def char_to_value(char):
|
||||
return ord(char) - ord('0')
|
||||
elif 'A' <= char <= 'Z':
|
||||
return ord(char) - ord('A') + 10
|
||||
|
||||
|
||||
return ord(char) - ord('a') + 36
|
||||
|
||||
|
||||
def digest(s: str) -> int:
|
||||
"""计算字符串的hash值。"""
|
||||
# 取末尾的8位
|
||||
@@ -69,19 +78,24 @@ def digest(s: str) -> int:
|
||||
|
||||
return number
|
||||
|
||||
K = typing.TypeVar("K")
|
||||
V = typing.TypeVar("V")
|
||||
|
||||
K = typing.TypeVar('K')
|
||||
V = typing.TypeVar('V')
|
||||
|
||||
|
||||
class OpenIDMapping(typing.Generic[K, V]):
|
||||
|
||||
map: dict[K, V]
|
||||
|
||||
dump_func: typing.Callable
|
||||
|
||||
digest_func: typing.Callable[[K], V]
|
||||
|
||||
def __init__(self, map: dict[K, V], dump_func: typing.Callable, digest_func: typing.Callable[[K], V] = digest):
|
||||
def __init__(
|
||||
self,
|
||||
map: dict[K, V],
|
||||
dump_func: typing.Callable,
|
||||
digest_func: typing.Callable[[K], V] = digest,
|
||||
):
|
||||
self.map = map
|
||||
|
||||
self.dump_func = dump_func
|
||||
@@ -104,12 +118,11 @@ class OpenIDMapping(typing.Generic[K, V]):
|
||||
|
||||
def getkey(self, value: V) -> K:
|
||||
return list(self.map.keys())[list(self.map.values()).index(value)]
|
||||
|
||||
|
||||
def save_openid(self, key: K) -> V:
|
||||
|
||||
if key in self.map:
|
||||
return self.map[key]
|
||||
|
||||
|
||||
value = self.digest_func(key)
|
||||
|
||||
self.map[key] = value
|
||||
@@ -135,7 +148,7 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
|
||||
msg_list = [platform_message.Plain(text=message_chain)]
|
||||
else:
|
||||
raise Exception(
|
||||
"Unknown message type: " + str(message_chain) + str(type(message_chain))
|
||||
'Unknown message type: ' + str(message_chain) + str(type(message_chain))
|
||||
)
|
||||
|
||||
offcial_messages: list[dict] = []
|
||||
@@ -154,23 +167,23 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
|
||||
# 遍历并转换
|
||||
for component in msg_list:
|
||||
if type(component) is platform_message.Plain:
|
||||
offcial_messages.append({"type": "text", "content": component.text})
|
||||
offcial_messages.append({'type': 'text', 'content': component.text})
|
||||
elif type(component) is platform_message.Image:
|
||||
if component.url is not None:
|
||||
offcial_messages.append({"type": "image", "content": component.url})
|
||||
offcial_messages.append({'type': 'image', 'content': component.url})
|
||||
elif component.path is not None:
|
||||
offcial_messages.append(
|
||||
{"type": "file_image", "content": component.path}
|
||||
{'type': 'file_image', 'content': component.path}
|
||||
)
|
||||
elif type(component) is platform_message.At:
|
||||
offcial_messages.append({"type": "at", "content": ""})
|
||||
offcial_messages.append({'type': 'at', 'content': ''})
|
||||
elif type(component) is platform_message.AtAll:
|
||||
print(
|
||||
"上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。"
|
||||
'上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。'
|
||||
)
|
||||
elif type(component) is platform_message.Voice:
|
||||
print(
|
||||
"上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。"
|
||||
'上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。'
|
||||
)
|
||||
elif type(component) is forward.Forward:
|
||||
# 转发消息
|
||||
@@ -185,7 +198,7 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
|
||||
offcial_messages.extend(
|
||||
OfficialMessageConverter.yiri2target(message_chain)
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
@@ -194,7 +207,12 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
|
||||
|
||||
@staticmethod
|
||||
def extract_message_chain_from_obj(
|
||||
message: typing.Union[botpy_message.Message, botpy_message.DirectMessage, botpy_message.GroupMessage, botpy_message.C2CMessage],
|
||||
message: typing.Union[
|
||||
botpy_message.Message,
|
||||
botpy_message.DirectMessage,
|
||||
botpy_message.GroupMessage,
|
||||
botpy_message.C2CMessage,
|
||||
],
|
||||
message_id: str = None,
|
||||
bot_account_id: int = 0,
|
||||
) -> platform_message.MessageChain:
|
||||
@@ -210,7 +228,7 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
|
||||
if type(message) not in [botpy_message.DirectMessage, botpy_message.C2CMessage]:
|
||||
yiri_msg_list.append(platform_message.At(target=bot_account_id))
|
||||
|
||||
if hasattr(message, "mentions"):
|
||||
if hasattr(message, 'mentions'):
|
||||
for mention in message.mentions:
|
||||
if mention.bot:
|
||||
continue
|
||||
@@ -218,15 +236,15 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
|
||||
yiri_msg_list.append(platform_message.At(target=mention.id))
|
||||
|
||||
for attachment in message.attachments:
|
||||
if attachment.content_type.startswith("image"):
|
||||
if attachment.content_type.startswith('image'):
|
||||
yiri_msg_list.append(platform_message.Image(url=attachment.url))
|
||||
else:
|
||||
logging.warning(
|
||||
"不支持的附件类型:" + attachment.content_type + ",忽略此附件。"
|
||||
'不支持的附件类型:' + attachment.content_type + ',忽略此附件。'
|
||||
)
|
||||
|
||||
content = re.sub(r"<@!\d+>", "", str(message.content))
|
||||
if content.strip() != "":
|
||||
content = re.sub(r'<@!\d+>', '', str(message.content))
|
||||
if content.strip() != '':
|
||||
yiri_msg_list.append(platform_message.Plain(text=content))
|
||||
|
||||
chain = platform_message.MessageChain(yiri_msg_list)
|
||||
@@ -247,21 +265,25 @@ class OfficialEventConverter(adapter_model.EventConverter):
|
||||
return botpy_message.DirectMessage
|
||||
else:
|
||||
raise Exception(
|
||||
"未支持转换的事件类型(YiriMirai -> Official): " + str(event)
|
||||
'未支持转换的事件类型(YiriMirai -> Official): ' + str(event)
|
||||
)
|
||||
|
||||
def target2yiri(
|
||||
self,
|
||||
event: typing.Union[botpy_message.Message, botpy_message.DirectMessage, botpy_message.GroupMessage, botpy_message.C2CMessage],
|
||||
event: typing.Union[
|
||||
botpy_message.Message,
|
||||
botpy_message.DirectMessage,
|
||||
botpy_message.GroupMessage,
|
||||
botpy_message.C2CMessage,
|
||||
],
|
||||
) -> platform_events.Event:
|
||||
if isinstance(event, botpy_message.Message): # 频道内,转群聊事件
|
||||
permission = 'MEMBER'
|
||||
|
||||
if type(event) == botpy_message.Message: # 频道内,转群聊事件
|
||||
permission = "MEMBER"
|
||||
|
||||
if "2" in event.member.roles:
|
||||
permission = "ADMINISTRATOR"
|
||||
elif "4" in event.member.roles:
|
||||
permission = "OWNER"
|
||||
if '2' in event.member.roles:
|
||||
permission = 'ADMINISTRATOR'
|
||||
elif '4' in event.member.roles:
|
||||
permission = 'OWNER'
|
||||
|
||||
return platform_events.GroupMessage(
|
||||
sender=platform_entities.GroupMember(
|
||||
@@ -273,10 +295,10 @@ class OfficialEventConverter(adapter_model.EventConverter):
|
||||
name=event.author.username,
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title="",
|
||||
special_title='',
|
||||
join_timestamp=int(
|
||||
datetime.datetime.strptime(
|
||||
event.member.joined_at, "%Y-%m-%dT%H:%M:%S%z"
|
||||
event.member.joined_at, '%Y-%m-%dT%H:%M:%S%z'
|
||||
).timestamp()
|
||||
),
|
||||
last_speak_timestamp=datetime.datetime.now().timestamp(),
|
||||
@@ -287,11 +309,11 @@ class OfficialEventConverter(adapter_model.EventConverter):
|
||||
),
|
||||
time=int(
|
||||
datetime.datetime.strptime(
|
||||
event.timestamp, "%Y-%m-%dT%H:%M:%S%z"
|
||||
event.timestamp, '%Y-%m-%dT%H:%M:%S%z'
|
||||
).timestamp()
|
||||
),
|
||||
)
|
||||
elif type(event) == botpy_message.DirectMessage: # 频道私聊,转私聊事件
|
||||
elif isinstance(event, botpy_message.DirectMessage): # 频道私聊,转私聊事件
|
||||
return platform_events.FriendMessage(
|
||||
sender=platform_entities.Friend(
|
||||
id=event.guild_id,
|
||||
@@ -303,25 +325,24 @@ class OfficialEventConverter(adapter_model.EventConverter):
|
||||
),
|
||||
time=int(
|
||||
datetime.datetime.strptime(
|
||||
event.timestamp, "%Y-%m-%dT%H:%M:%S%z"
|
||||
event.timestamp, '%Y-%m-%dT%H:%M:%S%z'
|
||||
).timestamp()
|
||||
),
|
||||
)
|
||||
elif type(event) == botpy_message.GroupMessage: # 群聊,转群聊事件
|
||||
|
||||
elif isinstance(event, botpy_message.GroupMessage): # 群聊,转群聊事件
|
||||
author_member_id = event.author.member_openid
|
||||
|
||||
return OfficialGroupMessage(
|
||||
sender=platform_entities.GroupMember(
|
||||
id=author_member_id,
|
||||
member_name=author_member_id,
|
||||
permission="MEMBER",
|
||||
permission='MEMBER',
|
||||
group=platform_entities.Group(
|
||||
id=event.group_openid,
|
||||
name=author_member_id,
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title="",
|
||||
special_title='',
|
||||
join_timestamp=int(0),
|
||||
last_speak_timestamp=datetime.datetime.now().timestamp(),
|
||||
mute_time_remaining=0,
|
||||
@@ -331,12 +352,11 @@ class OfficialEventConverter(adapter_model.EventConverter):
|
||||
),
|
||||
time=int(
|
||||
datetime.datetime.strptime(
|
||||
event.timestamp, "%Y-%m-%dT%H:%M:%S%z"
|
||||
event.timestamp, '%Y-%m-%dT%H:%M:%S%z'
|
||||
).timestamp()
|
||||
),
|
||||
)
|
||||
elif type(event) == botpy_message.C2CMessage: # 私聊,转私聊事件
|
||||
|
||||
elif isinstance(event, botpy_message.C2CMessage): # 私聊,转私聊事件
|
||||
user_id_alter = event.author.user_openid
|
||||
|
||||
return OfficialFriendMessage(
|
||||
@@ -350,7 +370,7 @@ class OfficialEventConverter(adapter_model.EventConverter):
|
||||
),
|
||||
time=int(
|
||||
datetime.datetime.strptime(
|
||||
event.timestamp, "%Y-%m-%dT%H:%M:%S%z"
|
||||
event.timestamp, '%Y-%m-%dT%H:%M:%S%z'
|
||||
).timestamp()
|
||||
),
|
||||
)
|
||||
@@ -391,10 +411,10 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
|
||||
|
||||
switchs = {}
|
||||
|
||||
for intent in cfg["intents"]:
|
||||
for intent in cfg['intents']:
|
||||
switchs[intent] = True
|
||||
|
||||
del cfg["intents"]
|
||||
del cfg['intents']
|
||||
|
||||
intents = botpy.Intents(**switchs)
|
||||
|
||||
@@ -408,21 +428,21 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
|
||||
for msg in message_list:
|
||||
args = {}
|
||||
|
||||
if msg["type"] == "text":
|
||||
args["content"] = msg["content"]
|
||||
elif msg["type"] == "image":
|
||||
args["image"] = msg["content"]
|
||||
elif msg["type"] == "file_image":
|
||||
args["file_image"] = msg["content"]
|
||||
if msg['type'] == 'text':
|
||||
args['content'] = msg['content']
|
||||
elif msg['type'] == 'image':
|
||||
args['image'] = msg['content']
|
||||
elif msg['type'] == 'file_image':
|
||||
args['file_image'] = msg['content']
|
||||
else:
|
||||
continue
|
||||
|
||||
if target_type == "group":
|
||||
args["channel_id"] = str(target_id)
|
||||
if target_type == 'group':
|
||||
args['channel_id'] = str(target_id)
|
||||
|
||||
await self.bot.api.post_message(**args)
|
||||
elif target_type == "person":
|
||||
args["guild_id"] = str(target_id)
|
||||
elif target_type == 'person':
|
||||
args['guild_id'] = str(target_id)
|
||||
|
||||
await self.bot.api.post_dms(**args)
|
||||
|
||||
@@ -432,86 +452,82 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
):
|
||||
|
||||
message_list = self.message_converter.yiri2target(message)
|
||||
|
||||
for msg in message_list:
|
||||
args = {}
|
||||
|
||||
if msg["type"] == "text":
|
||||
args["content"] = msg["content"]
|
||||
elif msg["type"] == "image":
|
||||
args["image"] = msg["content"]
|
||||
elif msg["type"] == "file_image":
|
||||
args["file_image"] = msg["content"]
|
||||
if msg['type'] == 'text':
|
||||
args['content'] = msg['content']
|
||||
elif msg['type'] == 'image':
|
||||
args['image'] = msg['content']
|
||||
elif msg['type'] == 'file_image':
|
||||
args['file_image'] = msg['content']
|
||||
else:
|
||||
continue
|
||||
|
||||
if quote_origin:
|
||||
args["message_reference"] = botpy_message_type.Reference(
|
||||
args['message_reference'] = botpy_message_type.Reference(
|
||||
message_id=cached_message_ids[
|
||||
str(message_source.message_chain.message_id)
|
||||
]
|
||||
)
|
||||
|
||||
if type(message_source) == platform_events.GroupMessage:
|
||||
args["channel_id"] = str(message_source.sender.group.id)
|
||||
args["msg_id"] = cached_message_ids[
|
||||
if isinstance(message_source, platform_events.GroupMessage):
|
||||
args['channel_id'] = str(message_source.sender.group.id)
|
||||
args['msg_id'] = cached_message_ids[
|
||||
str(message_source.message_chain.message_id)
|
||||
]
|
||||
await self.bot.api.post_message(**args)
|
||||
elif type(message_source) == platform_events.FriendMessage:
|
||||
args["guild_id"] = str(message_source.sender.id)
|
||||
args["msg_id"] = cached_message_ids[
|
||||
elif isinstance(message_source, platform_events.FriendMessage):
|
||||
args['guild_id'] = str(message_source.sender.id)
|
||||
args['msg_id'] = cached_message_ids[
|
||||
str(message_source.message_chain.message_id)
|
||||
]
|
||||
await self.bot.api.post_dms(**args)
|
||||
elif type(message_source) == OfficialGroupMessage:
|
||||
|
||||
if "file_image" in args: # 暂不支持发送文件图片
|
||||
elif isinstance(message_source, OfficialGroupMessage):
|
||||
if 'file_image' in args: # 暂不支持发送文件图片
|
||||
continue
|
||||
|
||||
args["group_openid"] = message_source.sender.group.id
|
||||
args['group_openid'] = message_source.sender.group.id
|
||||
|
||||
if "image" in args:
|
||||
if 'image' in args:
|
||||
uploadMedia = await self.bot.api.post_group_file(
|
||||
group_openid=args["group_openid"],
|
||||
group_openid=args['group_openid'],
|
||||
file_type=1,
|
||||
url=str(args['image'])
|
||||
url=str(args['image']),
|
||||
)
|
||||
|
||||
del args['image']
|
||||
args['media'] = uploadMedia
|
||||
args['msg_type'] = 7
|
||||
|
||||
args["msg_id"] = cached_message_ids[
|
||||
args['msg_id'] = cached_message_ids[
|
||||
str(message_source.message_chain.message_id)
|
||||
]
|
||||
args["msg_seq"] = self.group_msg_seq
|
||||
args['msg_seq'] = self.group_msg_seq
|
||||
self.group_msg_seq += 1
|
||||
|
||||
await self.bot.api.post_group_message(**args)
|
||||
elif type(message_source) == OfficialFriendMessage:
|
||||
if "file_image" in args:
|
||||
elif isinstance(message_source, OfficialFriendMessage):
|
||||
if 'file_image' in args:
|
||||
continue
|
||||
args["openid"] = message_source.sender.id
|
||||
args['openid'] = message_source.sender.id
|
||||
|
||||
if "image" in args:
|
||||
if 'image' in args:
|
||||
uploadMedia = await self.bot.api.post_c2c_file(
|
||||
openid=args["openid"],
|
||||
file_type=1,
|
||||
url=str(args['image'])
|
||||
openid=args['openid'], file_type=1, url=str(args['image'])
|
||||
)
|
||||
|
||||
del args['image']
|
||||
args['media'] = uploadMedia
|
||||
args['msg_type'] = 7
|
||||
|
||||
args["msg_id"] = cached_message_ids[
|
||||
args['msg_id'] = cached_message_ids[
|
||||
str(message_source.message_chain.message_id)
|
||||
]
|
||||
|
||||
args["msg_seq"] = self.c2c_msg_seq
|
||||
args['msg_seq'] = self.c2c_msg_seq
|
||||
self.c2c_msg_seq += 1
|
||||
|
||||
await self.bot.api.post_c2c_message(**args)
|
||||
@@ -526,7 +542,6 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
|
||||
[platform_events.Event, adapter_model.MessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
|
||||
try:
|
||||
|
||||
async def wrapper(
|
||||
@@ -534,7 +549,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
|
||||
botpy_message.Message,
|
||||
botpy_message.DirectMessage,
|
||||
botpy_message.GroupMessage,
|
||||
]
|
||||
],
|
||||
):
|
||||
self.cached_official_messages[str(message.id)] = message
|
||||
await callback(self.event_converter.target2yiri(message), self)
|
||||
@@ -555,7 +570,6 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
|
||||
delattr(self.bot, event_handler_mapping[event_type])
|
||||
|
||||
async def run_async(self):
|
||||
|
||||
self.metadata = self.ap.adapter_qq_botpy_meta
|
||||
|
||||
self.message_converter = OfficialMessageConverter()
|
||||
@@ -563,7 +577,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
|
||||
|
||||
self.cfg['ret_coro'] = True
|
||||
|
||||
self.ap.logger.info("运行 QQ 官方适配器")
|
||||
self.ap.logger.info('运行 QQ 官方适配器')
|
||||
await (await self.bot.start(**self.cfg))
|
||||
|
||||
async def kill(self) -> bool:
|
||||
|
||||
@@ -7,12 +7,8 @@ import datetime
|
||||
|
||||
from pkg.platform.adapter import MessagePlatformAdapter
|
||||
from pkg.platform.types import events as platform_events, message as platform_message
|
||||
from pkg.core import app
|
||||
from .. import adapter
|
||||
|
||||
from ...core import app
|
||||
from ..types import message as platform_message
|
||||
from ..types import events as platform_events
|
||||
from ..types import entities as platform_entities
|
||||
from ...command.errors import ParamNotEnoughError
|
||||
from libs.qq_official_api.api import QQOfficialClient
|
||||
@@ -21,157 +17,164 @@ from ...utils import image
|
||||
|
||||
|
||||
class QQOfficialMessageConverter(adapter.MessageConverter):
|
||||
|
||||
@staticmethod
|
||||
async def yiri2target(message_chain: platform_message.MessageChain):
|
||||
content_list = []
|
||||
#只实现了发文字
|
||||
# 只实现了发文字
|
||||
for msg in message_chain:
|
||||
if type(msg) is platform_message.Plain:
|
||||
content_list.append({
|
||||
"type":"text",
|
||||
"content":msg.text,
|
||||
})
|
||||
|
||||
content_list.append(
|
||||
{
|
||||
'type': 'text',
|
||||
'content': msg.text,
|
||||
}
|
||||
)
|
||||
|
||||
return content_list
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def target2yiri(message:str,message_id:str,pic_url:str,content_type):
|
||||
async def target2yiri(message: str, message_id: str, pic_url: str, content_type):
|
||||
yiri_msg_list = []
|
||||
yiri_msg_list.append(
|
||||
platform_message.Source(id=message_id,time=datetime.datetime.now())
|
||||
platform_message.Source(id=message_id, time=datetime.datetime.now())
|
||||
)
|
||||
if pic_url is not None:
|
||||
base64_url = await image.get_qq_official_image_base64(pic_url=pic_url,content_type=content_type)
|
||||
yiri_msg_list.append(
|
||||
platform_message.Image(base64=base64_url)
|
||||
base64_url = await image.get_qq_official_image_base64(
|
||||
pic_url=pic_url, content_type=content_type
|
||||
)
|
||||
yiri_msg_list.append(platform_message.Image(base64=base64_url))
|
||||
|
||||
yiri_msg_list.append(platform_message.Plain(text=message))
|
||||
chain = platform_message.MessageChain(yiri_msg_list)
|
||||
return chain
|
||||
|
||||
|
||||
class QQOfficialEventConverter(adapter.EventConverter):
|
||||
@staticmethod
|
||||
async def yiri2target(event: platform_events.MessageEvent) -> QQOfficialEvent:
|
||||
return event.source_platform_object
|
||||
|
||||
@staticmethod
|
||||
async def yiri2target(event:platform_events.MessageEvent) -> QQOfficialEvent:
|
||||
return event.source_platform_object
|
||||
|
||||
@staticmethod
|
||||
async def target2yiri(event:QQOfficialEvent):
|
||||
async def target2yiri(event: QQOfficialEvent):
|
||||
"""
|
||||
QQ官方消息转换为LB对象
|
||||
"""
|
||||
yiri_chain = await QQOfficialMessageConverter.target2yiri(
|
||||
message=event.content,message_id=event.d_id,pic_url=event.attachments,content_type=event.content_type
|
||||
message=event.content,
|
||||
message_id=event.d_id,
|
||||
pic_url=event.attachments,
|
||||
content_type=event.content_type,
|
||||
)
|
||||
|
||||
|
||||
if event.t == 'C2C_MESSAGE_CREATE':
|
||||
friend = platform_entities.Friend(
|
||||
id = event.user_openid,
|
||||
nickname = event.t,
|
||||
remark = "",
|
||||
id=event.user_openid,
|
||||
nickname=event.t,
|
||||
remark='',
|
||||
)
|
||||
return platform_events.FriendMessage(
|
||||
sender = friend,message_chain = yiri_chain,time = int(
|
||||
sender=friend,
|
||||
message_chain=yiri_chain,
|
||||
time=int(
|
||||
datetime.datetime.strptime(
|
||||
event.timestamp, "%Y-%m-%dT%H:%M:%S%z"
|
||||
event.timestamp, '%Y-%m-%dT%H:%M:%S%z'
|
||||
).timestamp()
|
||||
),
|
||||
source_platform_object=event
|
||||
source_platform_object=event,
|
||||
)
|
||||
|
||||
|
||||
if event.t == 'DIRECT_MESSAGE_CREATE':
|
||||
friend = platform_entities.Friend(
|
||||
id = event.guild_id,
|
||||
nickname = event.t,
|
||||
remark = "",
|
||||
id=event.guild_id,
|
||||
nickname=event.t,
|
||||
remark='',
|
||||
)
|
||||
return platform_events.FriendMessage(
|
||||
sender = friend,message_chain = yiri_chain,
|
||||
source_platform_object=event
|
||||
sender=friend, message_chain=yiri_chain, source_platform_object=event
|
||||
)
|
||||
if event.t == 'GROUP_AT_MESSAGE_CREATE':
|
||||
yiri_chain.insert(0, platform_message.At(target="justbot"))
|
||||
yiri_chain.insert(0, platform_message.At(target='justbot'))
|
||||
|
||||
sender = platform_entities.GroupMember(
|
||||
id = event.group_openid,
|
||||
member_name= event.t,
|
||||
permission= 'MEMBER',
|
||||
group = platform_entities.Group(
|
||||
id = event.group_openid,
|
||||
name = 'MEMBER',
|
||||
permission= platform_entities.Permission.Member
|
||||
),
|
||||
special_title='',
|
||||
join_timestamp=0,
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0
|
||||
)
|
||||
time = int(
|
||||
datetime.datetime.strptime(
|
||||
event.timestamp, "%Y-%m-%dT%H:%M:%S%z"
|
||||
).timestamp()
|
||||
)
|
||||
return platform_events.GroupMessage(
|
||||
sender = sender,
|
||||
message_chain=yiri_chain,
|
||||
time = time,
|
||||
source_platform_object=event
|
||||
)
|
||||
if event.t =='AT_MESSAGE_CREATE':
|
||||
yiri_chain.insert(0, platform_message.At(target="justbot"))
|
||||
sender = platform_entities.GroupMember(
|
||||
id = event.channel_id,
|
||||
id=event.group_openid,
|
||||
member_name=event.t,
|
||||
permission= 'MEMBER',
|
||||
group = platform_entities.Group(
|
||||
id = event.channel_id,
|
||||
name = 'MEMBER',
|
||||
permission=platform_entities.Permission.Member
|
||||
permission='MEMBER',
|
||||
group=platform_entities.Group(
|
||||
id=event.group_openid,
|
||||
name='MEMBER',
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title='',
|
||||
join_timestamp=0,
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0
|
||||
mute_time_remaining=0,
|
||||
)
|
||||
time = int(
|
||||
datetime.datetime.strptime(
|
||||
event.timestamp, "%Y-%m-%dT%H:%M:%S%z"
|
||||
).timestamp()
|
||||
)
|
||||
datetime.datetime.strptime(
|
||||
event.timestamp, '%Y-%m-%dT%H:%M:%S%z'
|
||||
).timestamp()
|
||||
)
|
||||
return platform_events.GroupMessage(
|
||||
sender =sender,
|
||||
message_chain = yiri_chain,
|
||||
time = time,
|
||||
source_platform_object=event
|
||||
sender=sender,
|
||||
message_chain=yiri_chain,
|
||||
time=time,
|
||||
source_platform_object=event,
|
||||
)
|
||||
if event.t == 'AT_MESSAGE_CREATE':
|
||||
yiri_chain.insert(0, platform_message.At(target='justbot'))
|
||||
sender = platform_entities.GroupMember(
|
||||
id=event.channel_id,
|
||||
member_name=event.t,
|
||||
permission='MEMBER',
|
||||
group=platform_entities.Group(
|
||||
id=event.channel_id,
|
||||
name='MEMBER',
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title='',
|
||||
join_timestamp=0,
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0,
|
||||
)
|
||||
time = int(
|
||||
datetime.datetime.strptime(
|
||||
event.timestamp, '%Y-%m-%dT%H:%M:%S%z'
|
||||
).timestamp()
|
||||
)
|
||||
return platform_events.GroupMessage(
|
||||
sender=sender,
|
||||
message_chain=yiri_chain,
|
||||
time=time,
|
||||
source_platform_object=event,
|
||||
)
|
||||
|
||||
|
||||
class QQOfficialAdapter(adapter.MessagePlatformAdapter):
|
||||
bot:QQOfficialClient
|
||||
ap:app.Application
|
||||
config:dict
|
||||
bot_account_id:str
|
||||
bot: QQOfficialClient
|
||||
ap: app.Application
|
||||
config: dict
|
||||
bot_account_id: str
|
||||
message_converter: QQOfficialMessageConverter = QQOfficialMessageConverter()
|
||||
event_converter: QQOfficialEventConverter = QQOfficialEventConverter()
|
||||
|
||||
def __init__(self, config:dict, ap:app.Application):
|
||||
def __init__(self, config: dict, ap: app.Application):
|
||||
self.config = config
|
||||
self.ap = ap
|
||||
|
||||
required_keys = [
|
||||
"appid",
|
||||
"secret",
|
||||
'appid',
|
||||
'secret',
|
||||
]
|
||||
missing_keys = [key for key in required_keys if key not in config]
|
||||
if missing_keys:
|
||||
raise ParamNotEnoughError("QQ官方机器人缺少相关配置项,请查看文档或联系管理员")
|
||||
|
||||
raise ParamNotEnoughError(
|
||||
'QQ官方机器人缺少相关配置项,请查看文档或联系管理员'
|
||||
)
|
||||
|
||||
self.bot = QQOfficialClient(
|
||||
app_id=config["appid"],
|
||||
secret=config["secret"],
|
||||
token=config["token"],
|
||||
app_id=config['appid'],
|
||||
secret=config['secret'],
|
||||
token=config['token'],
|
||||
)
|
||||
|
||||
async def reply_message(
|
||||
@@ -186,31 +189,45 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter):
|
||||
|
||||
content_list = await QQOfficialMessageConverter.yiri2target(message)
|
||||
|
||||
#私聊消息
|
||||
# 私聊消息
|
||||
if qq_official_event.t == 'C2C_MESSAGE_CREATE':
|
||||
for content in content_list:
|
||||
if content["type"] == 'text':
|
||||
await self.bot.send_private_text_msg(qq_official_event.user_openid,content['content'],qq_official_event.d_id)
|
||||
if content['type'] == 'text':
|
||||
await self.bot.send_private_text_msg(
|
||||
qq_official_event.user_openid,
|
||||
content['content'],
|
||||
qq_official_event.d_id,
|
||||
)
|
||||
|
||||
#群聊消息
|
||||
# 群聊消息
|
||||
if qq_official_event.t == 'GROUP_AT_MESSAGE_CREATE':
|
||||
for content in content_list:
|
||||
if content["type"] == 'text':
|
||||
await self.bot.send_group_text_msg(qq_official_event.group_openid,content['content'],qq_official_event.d_id)
|
||||
|
||||
#频道群聊
|
||||
if content['type'] == 'text':
|
||||
await self.bot.send_group_text_msg(
|
||||
qq_official_event.group_openid,
|
||||
content['content'],
|
||||
qq_official_event.d_id,
|
||||
)
|
||||
|
||||
# 频道群聊
|
||||
if qq_official_event.t == 'AT_MESSAGE_CREATE':
|
||||
for content in content_list:
|
||||
if content["type"] == 'text':
|
||||
await self.bot.send_channle_group_text_msg(qq_official_event.channel_id,content['content'],qq_official_event.d_id)
|
||||
if content['type'] == 'text':
|
||||
await self.bot.send_channle_group_text_msg(
|
||||
qq_official_event.channel_id,
|
||||
content['content'],
|
||||
qq_official_event.d_id,
|
||||
)
|
||||
|
||||
#频道私聊
|
||||
# 频道私聊
|
||||
if qq_official_event.t == 'DIRECT_MESSAGE_CREATE':
|
||||
for content in content_list:
|
||||
if content["type"] == 'text':
|
||||
await self.bot.send_channle_private_text_msg(qq_official_event.guild_id,content['content'],qq_official_event.d_id)
|
||||
|
||||
|
||||
if content['type'] == 'text':
|
||||
await self.bot.send_channle_private_text_msg(
|
||||
qq_official_event.guild_id,
|
||||
content['content'],
|
||||
qq_official_event.d_id,
|
||||
)
|
||||
|
||||
async def send_message(
|
||||
self, target_type: str, target_id: str, message: platform_message.MessageChain
|
||||
@@ -224,22 +241,21 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter):
|
||||
[platform_events.Event, adapter.MessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
async def on_message(event:QQOfficialEvent):
|
||||
self.bot_account_id = "justbot"
|
||||
async def on_message(event: QQOfficialEvent):
|
||||
self.bot_account_id = 'justbot'
|
||||
try:
|
||||
return await callback(
|
||||
await self.event_converter.target2yiri(event),self
|
||||
await self.event_converter.target2yiri(event), self
|
||||
)
|
||||
except:
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
if event_type == platform_events.FriendMessage:
|
||||
self.bot.on_message("DIRECT_MESSAGE_CREATE")(on_message)
|
||||
self.bot.on_message("C2C_MESSAGE_CREATE")(on_message)
|
||||
elif event_type == platform_events.GroupMessage:
|
||||
self.bot.on_message("GROUP_AT_MESSAGE_CREATE")(on_message)
|
||||
self.bot.on_message("AT_MESSAGE_CREATE")(on_message)
|
||||
|
||||
if event_type == platform_events.FriendMessage:
|
||||
self.bot.on_message('DIRECT_MESSAGE_CREATE')(on_message)
|
||||
self.bot.on_message('C2C_MESSAGE_CREATE')(on_message)
|
||||
elif event_type == platform_events.GroupMessage:
|
||||
self.bot.on_message('GROUP_AT_MESSAGE_CREATE')(on_message)
|
||||
self.bot.on_message('AT_MESSAGE_CREATE')(on_message)
|
||||
|
||||
async def run_async(self):
|
||||
async def shutdown_trigger_placeholder():
|
||||
@@ -248,17 +264,18 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter):
|
||||
|
||||
await self.bot.run_task(
|
||||
host='0.0.0.0',
|
||||
port=self.config["port"],
|
||||
port=self.config['port'],
|
||||
shutdown_trigger=shutdown_trigger_placeholder,
|
||||
)
|
||||
|
||||
)
|
||||
|
||||
async def kill(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: type,
|
||||
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, MessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
return super().unregister_listener(event_type, callback)
|
||||
|
||||
|
||||
@@ -3,48 +3,33 @@ from __future__ import annotations
|
||||
import telegram
|
||||
import telegram.ext
|
||||
from telegram import Update
|
||||
from telegram.ext import ApplicationBuilder, ContextTypes, CommandHandler, MessageHandler, filters
|
||||
from telegram.ext import ApplicationBuilder, ContextTypes, MessageHandler, filters
|
||||
|
||||
import typing
|
||||
import asyncio
|
||||
import traceback
|
||||
import time
|
||||
import re
|
||||
import base64
|
||||
import uuid
|
||||
import json
|
||||
import datetime
|
||||
import hashlib
|
||||
import base64
|
||||
import aiohttp
|
||||
from Crypto.Cipher import AES
|
||||
|
||||
from flask import jsonify
|
||||
from lark_oapi.api.im.v1 import *
|
||||
from lark_oapi.api.verification.v1 import GetVerificationRequest
|
||||
|
||||
from .. import adapter
|
||||
from ...pipeline.longtext.strategies import forward
|
||||
from ...core import app
|
||||
from ..types import message as platform_message
|
||||
from ..types import events as platform_events
|
||||
from ..types import entities as platform_entities
|
||||
from ...utils import image
|
||||
|
||||
|
||||
class TelegramMessageConverter(adapter.MessageConverter):
|
||||
@staticmethod
|
||||
async def yiri2target(message_chain: platform_message.MessageChain, bot: telegram.Bot) -> list[dict]:
|
||||
async def yiri2target(
|
||||
message_chain: platform_message.MessageChain, bot: telegram.Bot
|
||||
) -> list[dict]:
|
||||
components = []
|
||||
|
||||
for component in message_chain:
|
||||
if isinstance(component, platform_message.Plain):
|
||||
components.append({
|
||||
"type": "text",
|
||||
"text": component.text
|
||||
})
|
||||
components.append({'type': 'text', 'text': component.text})
|
||||
elif isinstance(component, platform_message.Image):
|
||||
|
||||
photo_bytes = None
|
||||
|
||||
if component.base64:
|
||||
@@ -54,24 +39,25 @@ class TelegramMessageConverter(adapter.MessageConverter):
|
||||
async with session.get(component.url) as response:
|
||||
photo_bytes = await response.read()
|
||||
elif component.path:
|
||||
with open(component.path, "rb") as f:
|
||||
with open(component.path, 'rb') as f:
|
||||
photo_bytes = f.read()
|
||||
|
||||
components.append({
|
||||
"type": "photo",
|
||||
"photo": photo_bytes
|
||||
})
|
||||
|
||||
components.append({'type': 'photo', 'photo': photo_bytes})
|
||||
elif isinstance(component, platform_message.Forward):
|
||||
for node in component.node_list:
|
||||
components.extend(await TelegramMessageConverter.yiri2target(node.message_chain, bot))
|
||||
components.extend(
|
||||
await TelegramMessageConverter.yiri2target(
|
||||
node.message_chain, bot
|
||||
)
|
||||
)
|
||||
|
||||
return components
|
||||
|
||||
@staticmethod
|
||||
async def target2yiri(message: telegram.Message, bot: telegram.Bot, bot_account_id: str):
|
||||
|
||||
message_components = []
|
||||
|
||||
@staticmethod
|
||||
async def target2yiri(
|
||||
message: telegram.Message, bot: telegram.Bot, bot_account_id: str
|
||||
):
|
||||
message_components = []
|
||||
|
||||
def parse_message_text(text: str) -> list[platform_message.MessageComponent]:
|
||||
msg_components = []
|
||||
@@ -86,7 +72,7 @@ class TelegramMessageConverter(adapter.MessageConverter):
|
||||
if message.text:
|
||||
message_text = message.text
|
||||
message_components.extend(parse_message_text(message_text))
|
||||
|
||||
|
||||
if message.photo:
|
||||
message_components.extend(parse_message_text(message.caption))
|
||||
|
||||
@@ -100,21 +86,26 @@ class TelegramMessageConverter(adapter.MessageConverter):
|
||||
file_bytes = await response.read()
|
||||
file_format = 'image/jpeg'
|
||||
|
||||
message_components.append(platform_message.Image(base64=f"data:{file_format};base64,{base64.b64encode(file_bytes).decode('utf-8')}"))
|
||||
|
||||
message_components.append(
|
||||
platform_message.Image(
|
||||
base64=f'data:{file_format};base64,{base64.b64encode(file_bytes).decode("utf-8")}'
|
||||
)
|
||||
)
|
||||
|
||||
return platform_message.MessageChain(message_components)
|
||||
|
||||
|
||||
|
||||
class TelegramEventConverter(adapter.EventConverter):
|
||||
@staticmethod
|
||||
async def yiri2target(event: platform_events.MessageEvent, bot: telegram.Bot):
|
||||
return event.source_platform_object
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def target2yiri(event: Update, bot: telegram.Bot, bot_account_id: str):
|
||||
lb_message = await TelegramMessageConverter.target2yiri(
|
||||
event.message, bot, bot_account_id
|
||||
)
|
||||
|
||||
lb_message = await TelegramMessageConverter.target2yiri(event.message, bot, bot_account_id)
|
||||
|
||||
if event.effective_chat.type == 'private':
|
||||
return platform_events.FriendMessage(
|
||||
sender=platform_entities.Friend(
|
||||
@@ -124,7 +115,7 @@ class TelegramEventConverter(adapter.EventConverter):
|
||||
),
|
||||
message_chain=lb_message,
|
||||
time=event.message.date.timestamp(),
|
||||
source_platform_object=event
|
||||
source_platform_object=event,
|
||||
)
|
||||
elif event.effective_chat.type == 'group':
|
||||
return platform_events.GroupMessage(
|
||||
@@ -137,19 +128,18 @@ class TelegramEventConverter(adapter.EventConverter):
|
||||
name=event.effective_chat.title,
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title="",
|
||||
special_title='',
|
||||
join_timestamp=0,
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0,
|
||||
),
|
||||
message_chain=lb_message,
|
||||
time=event.message.date.timestamp(),
|
||||
source_platform_object=event
|
||||
source_platform_object=event,
|
||||
)
|
||||
|
||||
|
||||
|
||||
class TelegramAdapter(adapter.MessagePlatformAdapter):
|
||||
|
||||
bot: telegram.Bot
|
||||
application: telegram.ext.Application
|
||||
|
||||
@@ -165,26 +155,31 @@ class TelegramAdapter(adapter.MessagePlatformAdapter):
|
||||
typing.Type[platform_events.Event],
|
||||
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
] = {}
|
||||
|
||||
|
||||
def __init__(self, config: dict, ap: app.Application):
|
||||
self.config = config
|
||||
self.ap = ap
|
||||
|
||||
async def telegram_callback(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
|
||||
async def telegram_callback(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
if update.message.from_user.is_bot:
|
||||
return
|
||||
|
||||
try:
|
||||
lb_event = await self.event_converter.target2yiri(update, self.bot, self.bot_account_id)
|
||||
lb_event = await self.event_converter.target2yiri(
|
||||
update, self.bot, self.bot_account_id
|
||||
)
|
||||
await self.listeners[type(lb_event)](lb_event, self)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
print(traceback.format_exc())
|
||||
|
||||
|
||||
self.application = ApplicationBuilder().token(self.config['token']).build()
|
||||
self.bot = self.application.bot
|
||||
self.application.add_handler(MessageHandler(filters.TEXT | (filters.COMMAND) | filters.PHOTO , telegram_callback))
|
||||
|
||||
self.application.add_handler(
|
||||
MessageHandler(
|
||||
filters.TEXT | (filters.COMMAND) | filters.PHOTO, telegram_callback
|
||||
)
|
||||
)
|
||||
|
||||
async def send_message(
|
||||
self, target_type: str, target_id: str, message: platform_message.MessageChain
|
||||
):
|
||||
@@ -198,45 +193,48 @@ class TelegramAdapter(adapter.MessagePlatformAdapter):
|
||||
):
|
||||
assert isinstance(message_source.source_platform_object, Update)
|
||||
components = await TelegramMessageConverter.yiri2target(message, self.bot)
|
||||
|
||||
|
||||
for component in components:
|
||||
if component['type'] == 'text':
|
||||
|
||||
args = {
|
||||
"chat_id": message_source.source_platform_object.effective_chat.id,
|
||||
"text": component['text'],
|
||||
'chat_id': message_source.source_platform_object.effective_chat.id,
|
||||
'text': component['text'],
|
||||
}
|
||||
|
||||
if quote_origin:
|
||||
args['reply_to_message_id'] = message_source.source_platform_object.message.id
|
||||
args['reply_to_message_id'] = (
|
||||
message_source.source_platform_object.message.id
|
||||
)
|
||||
|
||||
await self.bot.send_message(**args)
|
||||
|
||||
|
||||
async def is_muted(self, group_id: int) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, adapter.MessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
self.listeners[event_type] = callback
|
||||
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, adapter.MessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
self.listeners.pop(event_type)
|
||||
|
||||
|
||||
async def run_async(self):
|
||||
await self.application.initialize()
|
||||
self.bot_account_id = (await self.bot.get_me()).username
|
||||
await self.application.updater.start_polling(
|
||||
allowed_updates=Update.ALL_TYPES
|
||||
)
|
||||
await self.application.updater.start_polling(allowed_updates=Update.ALL_TYPES)
|
||||
await self.application.start()
|
||||
|
||||
|
||||
async def kill(self) -> bool:
|
||||
await self.application.stop()
|
||||
return True
|
||||
return True
|
||||
|
||||
@@ -9,17 +9,14 @@ from libs.wecom_api.api import WecomClient
|
||||
from pkg.platform.adapter import MessagePlatformAdapter
|
||||
from pkg.platform.types import events as platform_events, message as platform_message
|
||||
from libs.wecom_api.wecomevent import WecomEvent
|
||||
from pkg.core import app
|
||||
from .. import adapter
|
||||
from ...core import app
|
||||
from ..types import message as platform_message
|
||||
from ..types import events as platform_events
|
||||
from ..types import entities as platform_entities
|
||||
from ...command.errors import ParamNotEnoughError
|
||||
from ...utils import image
|
||||
|
||||
class WecomMessageConverter(adapter.MessageConverter):
|
||||
|
||||
class WecomMessageConverter(adapter.MessageConverter):
|
||||
@staticmethod
|
||||
async def yiri2target(
|
||||
message_chain: platform_message.MessageChain, bot: WecomClient
|
||||
@@ -28,23 +25,35 @@ class WecomMessageConverter(adapter.MessageConverter):
|
||||
|
||||
for msg in message_chain:
|
||||
if type(msg) is platform_message.Plain:
|
||||
content_list.append({
|
||||
"type": "text",
|
||||
"content": msg.text,
|
||||
})
|
||||
content_list.append(
|
||||
{
|
||||
'type': 'text',
|
||||
'content': msg.text,
|
||||
}
|
||||
)
|
||||
elif type(msg) is platform_message.Image:
|
||||
content_list.append({
|
||||
"type": "image",
|
||||
"media_id": await bot.get_media_id(msg),
|
||||
})
|
||||
content_list.append(
|
||||
{
|
||||
'type': 'image',
|
||||
'media_id': await bot.get_media_id(msg),
|
||||
}
|
||||
)
|
||||
elif type(msg) is platform_message.Forward:
|
||||
for node in msg.node_list:
|
||||
content_list.extend((await WecomMessageConverter.yiri2target(node.message_chain, bot)))
|
||||
content_list.extend(
|
||||
(
|
||||
await WecomMessageConverter.yiri2target(
|
||||
node.message_chain, bot
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
content_list.append({
|
||||
"type": "text",
|
||||
"content": str(msg),
|
||||
})
|
||||
content_list.append(
|
||||
{
|
||||
'type': 'text',
|
||||
'content': str(msg),
|
||||
}
|
||||
)
|
||||
|
||||
return content_list
|
||||
|
||||
@@ -67,14 +76,17 @@ class WecomMessageConverter(adapter.MessageConverter):
|
||||
platform_message.Source(id=message_id, time=datetime.datetime.now())
|
||||
)
|
||||
image_base64, image_format = await image.get_wecom_image_base64(pic_url=picurl)
|
||||
yiri_msg_list.append(platform_message.Image(base64=f"data:image/{image_format};base64,{image_base64}"))
|
||||
yiri_msg_list.append(
|
||||
platform_message.Image(
|
||||
base64=f'data:image/{image_format};base64,{image_base64}'
|
||||
)
|
||||
)
|
||||
chain = platform_message.MessageChain(yiri_msg_list)
|
||||
|
||||
|
||||
return chain
|
||||
|
||||
|
||||
class WecomEventConverter:
|
||||
|
||||
@staticmethod
|
||||
async def yiri2target(
|
||||
event: platform_events.Event, bot_account_id: int, bot: WecomClient
|
||||
@@ -85,18 +97,17 @@ class WecomEventConverter:
|
||||
pass
|
||||
|
||||
if type(event) is platform_events.FriendMessage:
|
||||
|
||||
payload = {
|
||||
"MsgType": "text",
|
||||
"Content": '',
|
||||
"FromUserName": event.sender.id,
|
||||
"ToUserName": bot_account_id,
|
||||
"CreateTime": int(datetime.datetime.now().timestamp()),
|
||||
"AgentID": event.sender.nickname,
|
||||
'MsgType': 'text',
|
||||
'Content': '',
|
||||
'FromUserName': event.sender.id,
|
||||
'ToUserName': bot_account_id,
|
||||
'CreateTime': int(datetime.datetime.now().timestamp()),
|
||||
'AgentID': event.sender.nickname,
|
||||
}
|
||||
wecom_event = WecomEvent.from_payload(payload=payload)
|
||||
if not wecom_event:
|
||||
raise ValueError("无法从 message_data 构造 WecomEvent 对象")
|
||||
raise ValueError('无法从 message_data 构造 WecomEvent 对象')
|
||||
|
||||
return wecom_event
|
||||
|
||||
@@ -112,24 +123,24 @@ class WecomEventConverter:
|
||||
platform_events.FriendMessage: 转换后的 FriendMessage 对象。
|
||||
"""
|
||||
# 转换消息链
|
||||
if event.type == "text":
|
||||
if event.type == 'text':
|
||||
yiri_chain = await WecomMessageConverter.target2yiri(
|
||||
event.message, event.message_id
|
||||
)
|
||||
friend = platform_entities.Friend(
|
||||
id=f"u{event.user_id}",
|
||||
id=f'u{event.user_id}',
|
||||
nickname=str(event.agent_id),
|
||||
remark="",
|
||||
remark='',
|
||||
)
|
||||
|
||||
return platform_events.FriendMessage(
|
||||
sender=friend, message_chain=yiri_chain, time=event.timestamp
|
||||
)
|
||||
elif event.type == "image":
|
||||
elif event.type == 'image':
|
||||
friend = platform_entities.Friend(
|
||||
id=f"u{event.user_id}",
|
||||
id=f'u{event.user_id}',
|
||||
nickname=str(event.agent_id),
|
||||
remark="",
|
||||
remark='',
|
||||
)
|
||||
|
||||
yiri_chain = await WecomMessageConverter.target2yiri_image(
|
||||
@@ -142,7 +153,6 @@ class WecomEventConverter:
|
||||
|
||||
|
||||
class WecomAdapter(adapter.MessagePlatformAdapter):
|
||||
|
||||
bot: WecomClient
|
||||
ap: app.Application
|
||||
bot_account_id: str
|
||||
@@ -156,22 +166,22 @@ class WecomAdapter(adapter.MessagePlatformAdapter):
|
||||
self.ap = ap
|
||||
|
||||
required_keys = [
|
||||
"corpid",
|
||||
"secret",
|
||||
"token",
|
||||
"EncodingAESKey",
|
||||
"contacts_secret",
|
||||
'corpid',
|
||||
'secret',
|
||||
'token',
|
||||
'EncodingAESKey',
|
||||
'contacts_secret',
|
||||
]
|
||||
missing_keys = [key for key in required_keys if key not in config]
|
||||
if missing_keys:
|
||||
raise ParamNotEnoughError("企业微信缺少相关配置项,请查看文档或联系管理员")
|
||||
raise ParamNotEnoughError('企业微信缺少相关配置项,请查看文档或联系管理员')
|
||||
|
||||
self.bot = WecomClient(
|
||||
corpid=config["corpid"],
|
||||
secret=config["secret"],
|
||||
token=config["token"],
|
||||
EncodingAESKey=config["EncodingAESKey"],
|
||||
contacts_secret=config["contacts_secret"],
|
||||
corpid=config['corpid'],
|
||||
secret=config['secret'],
|
||||
token=config['token'],
|
||||
EncodingAESKey=config['EncodingAESKey'],
|
||||
contacts_secret=config['contacts_secret'],
|
||||
)
|
||||
|
||||
async def reply_message(
|
||||
@@ -180,7 +190,6 @@ class WecomAdapter(adapter.MessagePlatformAdapter):
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
):
|
||||
|
||||
Wecom_event = await WecomEventConverter.yiri2target(
|
||||
message_source, self.bot_account_id, self.bot
|
||||
)
|
||||
@@ -189,11 +198,15 @@ class WecomAdapter(adapter.MessagePlatformAdapter):
|
||||
# 删掉开头的u
|
||||
fixed_user_id = fixed_user_id[1:]
|
||||
for content in content_list:
|
||||
if content["type"] == "text":
|
||||
await self.bot.send_private_msg(fixed_user_id, Wecom_event.agent_id, content["content"])
|
||||
elif content["type"] == "image":
|
||||
await self.bot.send_image(fixed_user_id, Wecom_event.agent_id, content["media_id"])
|
||||
|
||||
if content['type'] == 'text':
|
||||
await self.bot.send_private_msg(
|
||||
fixed_user_id, Wecom_event.agent_id, content['content']
|
||||
)
|
||||
elif content['type'] == 'image':
|
||||
await self.bot.send_image(
|
||||
fixed_user_id, Wecom_event.agent_id, content['media_id']
|
||||
)
|
||||
|
||||
async def send_message(
|
||||
self, target_type: str, target_id: str, message: platform_message.MessageChain
|
||||
):
|
||||
@@ -201,15 +214,17 @@ class WecomAdapter(adapter.MessagePlatformAdapter):
|
||||
构造target_id的方式为前半部分为账户id,后半部分为agent_id,中间使用“|”符号隔开。
|
||||
"""
|
||||
content_list = await WecomMessageConverter.yiri2target(message, self.bot)
|
||||
parts = target_id.split("|")
|
||||
parts = target_id.split('|')
|
||||
user_id = parts[0]
|
||||
agent_id = int(parts[1])
|
||||
if target_type == 'person':
|
||||
for content in content_list:
|
||||
if content["type"] == "text":
|
||||
await self.bot.send_private_msg(user_id,agent_id,content["content"])
|
||||
if content["type"] == "image":
|
||||
await self.bot.send_image(user_id,agent_id,content["media"])
|
||||
if content['type'] == 'text':
|
||||
await self.bot.send_private_msg(
|
||||
user_id, agent_id, content['content']
|
||||
)
|
||||
if content['type'] == 'image':
|
||||
await self.bot.send_image(user_id, agent_id, content['media'])
|
||||
|
||||
def register_listener(
|
||||
self,
|
||||
@@ -224,12 +239,12 @@ class WecomAdapter(adapter.MessagePlatformAdapter):
|
||||
return await callback(
|
||||
await self.event_converter.target2yiri(event), self
|
||||
)
|
||||
except:
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
if event_type == platform_events.FriendMessage:
|
||||
self.bot.on_message("text")(on_message)
|
||||
self.bot.on_message("image")(on_message)
|
||||
self.bot.on_message('text')(on_message)
|
||||
self.bot.on_message('image')(on_message)
|
||||
elif event_type == platform_events.GroupMessage:
|
||||
pass
|
||||
|
||||
@@ -239,8 +254,8 @@ class WecomAdapter(adapter.MessagePlatformAdapter):
|
||||
await asyncio.sleep(1)
|
||||
|
||||
await self.bot.run_task(
|
||||
host=self.config["host"],
|
||||
port=self.config["port"],
|
||||
host=self.config['host'],
|
||||
port=self.config['port'],
|
||||
shutdown_trigger=shutdown_trigger_placeholder,
|
||||
)
|
||||
|
||||
@@ -250,6 +265,8 @@ class WecomAdapter(adapter.MessagePlatformAdapter):
|
||||
async def unregister_listener(
|
||||
self,
|
||||
event_type: type,
|
||||
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, MessagePlatformAdapter], None
|
||||
],
|
||||
):
|
||||
return super().unregister_listener(event_type, callback)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
from typing import Dict, List, Type
|
||||
|
||||
import pydantic.v1.main as pdm
|
||||
@@ -25,14 +24,18 @@ class PlatformBaseModel(BaseModel, metaclass=PlatformMetaclass):
|
||||
2. 允许通过别名访问字段。
|
||||
3. 自动生成小驼峰风格的别名。
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
""""""
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.__class__.__name__ + '(' + ', '.join(
|
||||
(f'{k}={repr(v)}' for k, v in self.__dict__.items() if v)
|
||||
) + ')'
|
||||
return (
|
||||
self.__class__.__name__
|
||||
+ '('
|
||||
+ ', '.join((f'{k}={repr(v)}' for k, v in self.__dict__.items() if v))
|
||||
+ ')'
|
||||
)
|
||||
|
||||
class Config:
|
||||
extra = 'allow'
|
||||
@@ -42,6 +45,7 @@ class PlatformBaseModel(BaseModel, metaclass=PlatformMetaclass):
|
||||
|
||||
class PlatformIndexedMetaclass(PlatformMetaclass):
|
||||
"""可以通过子类名获取子类的类的元类。"""
|
||||
|
||||
__indexedbases__: List[Type['PlatformIndexedModel']] = []
|
||||
__indexedmodel__ = None
|
||||
|
||||
@@ -69,6 +73,7 @@ class PlatformIndexedMetaclass(PlatformMetaclass):
|
||||
|
||||
class PlatformIndexedModel(PlatformBaseModel, metaclass=PlatformIndexedMetaclass):
|
||||
"""可以通过子类名获取子类的类。"""
|
||||
|
||||
__indexes__: Dict[str, Type['PlatformIndexedModel']]
|
||||
|
||||
@classmethod
|
||||
@@ -86,7 +91,7 @@ class PlatformIndexedModel(PlatformBaseModel, metaclass=PlatformIndexedMetaclass
|
||||
if not (type_ and issubclass(type_, cls)):
|
||||
raise ValueError(f'`{name}` 不是 `{cls.__name__}` 的子类!')
|
||||
return type_
|
||||
except AttributeError as e:
|
||||
except AttributeError:
|
||||
raise ValueError(f'`{name}` 不是 `{cls.__name__}` 的子类!') from None
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
"""
|
||||
此模块提供实体和配置项模型。
|
||||
"""
|
||||
|
||||
import abc
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
@@ -12,8 +13,10 @@ import pydantic.v1 as pydantic
|
||||
|
||||
class Entity(pydantic.BaseModel):
|
||||
"""实体,表示一个用户或群。"""
|
||||
|
||||
id: int
|
||||
"""ID。"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_name(self) -> str:
|
||||
"""名称。"""
|
||||
@@ -21,31 +24,35 @@ class Entity(pydantic.BaseModel):
|
||||
|
||||
class Friend(Entity):
|
||||
"""私聊对象。"""
|
||||
|
||||
id: typing.Union[int, str]
|
||||
"""ID。"""
|
||||
nickname: typing.Optional[str]
|
||||
"""昵称。"""
|
||||
remark: typing.Optional[str]
|
||||
"""备注。"""
|
||||
|
||||
def get_name(self) -> str:
|
||||
return self.nickname or self.remark or ''
|
||||
|
||||
|
||||
|
||||
class Permission(str, Enum):
|
||||
"""群成员身份权限。"""
|
||||
Member = "MEMBER"
|
||||
|
||||
Member = 'MEMBER'
|
||||
"""成员。"""
|
||||
Administrator = "ADMINISTRATOR"
|
||||
Administrator = 'ADMINISTRATOR'
|
||||
"""管理员。"""
|
||||
Owner = "OWNER"
|
||||
Owner = 'OWNER'
|
||||
"""群主。"""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return repr(self.value)
|
||||
|
||||
|
||||
class Group(Entity):
|
||||
"""群。"""
|
||||
|
||||
id: typing.Union[int, str]
|
||||
"""群号。"""
|
||||
name: str
|
||||
@@ -59,6 +66,7 @@ class Group(Entity):
|
||||
|
||||
class GroupMember(Entity):
|
||||
"""群成员。"""
|
||||
|
||||
id: typing.Union[int, str]
|
||||
"""群员 ID。"""
|
||||
member_name: str
|
||||
|
||||
@@ -2,8 +2,7 @@
|
||||
"""
|
||||
此模块提供事件模型。
|
||||
"""
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
import typing
|
||||
|
||||
import pydantic.v1 as pydantic
|
||||
@@ -18,15 +17,23 @@ class Event(pydantic.BaseModel):
|
||||
Args:
|
||||
type: 事件名。
|
||||
"""
|
||||
|
||||
type: str
|
||||
"""事件名。"""
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '(' + ', '.join(
|
||||
(
|
||||
f'{k}={repr(v)}'
|
||||
for k, v in self.__dict__.items() if k != 'type' and v
|
||||
return (
|
||||
self.__class__.__name__
|
||||
+ '('
|
||||
+ ', '.join(
|
||||
(
|
||||
f'{k}={repr(v)}'
|
||||
for k, v in self.__dict__.items()
|
||||
if k != 'type' and v
|
||||
)
|
||||
)
|
||||
) + ')'
|
||||
+ ')'
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def parse_subtype(cls, obj: dict) -> 'Event':
|
||||
@@ -52,6 +59,7 @@ class MessageEvent(Event):
|
||||
type: 事件名。
|
||||
message_chain: 消息内容。
|
||||
"""
|
||||
|
||||
type: str
|
||||
"""事件名。"""
|
||||
message_chain: platform_message.MessageChain
|
||||
@@ -74,6 +82,7 @@ class FriendMessage(MessageEvent):
|
||||
sender: 发送消息的好友。
|
||||
message_chain: 消息内容。
|
||||
"""
|
||||
|
||||
type: str = 'FriendMessage'
|
||||
"""事件名。"""
|
||||
sender: platform_entities.Friend
|
||||
@@ -90,12 +99,14 @@ class GroupMessage(MessageEvent):
|
||||
sender: 发送消息的群成员。
|
||||
message_chain: 消息内容。
|
||||
"""
|
||||
|
||||
type: str = 'GroupMessage'
|
||||
"""事件名。"""
|
||||
sender: platform_entities.GroupMember
|
||||
"""发送消息的群成员。"""
|
||||
message_chain: platform_message.MessageChain
|
||||
"""消息内容。"""
|
||||
|
||||
@property
|
||||
def group(self) -> platform_entities.Group:
|
||||
return self.sender.group
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import itertools
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
import typing
|
||||
|
||||
@@ -16,6 +15,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class MessageComponentMetaclass(PlatformIndexedMetaclass):
|
||||
"""消息组件元类。"""
|
||||
|
||||
__message_component__ = None
|
||||
|
||||
def __new__(cls, name, bases, attrs, **kwargs):
|
||||
@@ -41,18 +41,26 @@ class MessageComponentMetaclass(PlatformIndexedMetaclass):
|
||||
|
||||
class MessageComponent(PlatformIndexedModel, metaclass=MessageComponentMetaclass):
|
||||
"""消息组件。"""
|
||||
|
||||
type: str
|
||||
"""消息组件类型。"""
|
||||
|
||||
def __str__(self):
|
||||
return ''
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '(' + ', '.join(
|
||||
(
|
||||
f'{k}={repr(v)}'
|
||||
for k, v in self.__dict__.items() if k != 'type' and v
|
||||
return (
|
||||
self.__class__.__name__
|
||||
+ '('
|
||||
+ ', '.join(
|
||||
(
|
||||
f'{k}={repr(v)}'
|
||||
for k, v in self.__dict__.items()
|
||||
if k != 'type' and v
|
||||
)
|
||||
)
|
||||
) + ')'
|
||||
+ ')'
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# 解析参数列表,将位置参数转化为具名参数
|
||||
@@ -63,7 +71,9 @@ class MessageComponent(PlatformIndexedModel, metaclass=MessageComponentMetaclass
|
||||
)
|
||||
for name, value in zip(parameter_names, args):
|
||||
if name in kwargs:
|
||||
raise TypeError(f'在 `{self.type}` 中,具名参数 `{name}` 与位置参数重复。')
|
||||
raise TypeError(
|
||||
f'在 `{self.type}` 中,具名参数 `{name}` 与位置参数重复。'
|
||||
)
|
||||
kwargs[name] = value
|
||||
|
||||
super().__init__(**kwargs)
|
||||
@@ -117,6 +127,7 @@ class MessageChain(PlatformBaseModel):
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
__root__: typing.List[MessageComponent]
|
||||
|
||||
@staticmethod
|
||||
@@ -131,10 +142,10 @@ class MessageChain(PlatformBaseModel):
|
||||
result.append(Plain(msg))
|
||||
else:
|
||||
raise TypeError(
|
||||
f"消息链中元素需为 dict 或 str 或 MessageComponent,当前类型:{type(msg)}"
|
||||
f'消息链中元素需为 dict 或 str 或 MessageComponent,当前类型:{type(msg)}'
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@pydantic.validator('__root__', always=True, pre=True)
|
||||
def _parse_component(cls, msg_chain):
|
||||
if isinstance(msg_chain, (str, MessageComponent)):
|
||||
@@ -157,7 +168,7 @@ class MessageChain(PlatformBaseModel):
|
||||
super().__init__(__root__=__root__)
|
||||
|
||||
def __str__(self):
|
||||
return "".join(str(component) for component in self.__root__)
|
||||
return ''.join(str(component) for component in self.__root__)
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({self.__root__!r})'
|
||||
@@ -165,8 +176,9 @@ class MessageChain(PlatformBaseModel):
|
||||
def __iter__(self):
|
||||
yield from self.__root__
|
||||
|
||||
def get_first(self,
|
||||
t: typing.Type[TMessageComponent]) -> typing.Optional[TMessageComponent]:
|
||||
def get_first(
|
||||
self, t: typing.Type[TMessageComponent]
|
||||
) -> typing.Optional[TMessageComponent]:
|
||||
"""获取消息链中第一个符合类型的消息组件。"""
|
||||
for component in self:
|
||||
if isinstance(component, t):
|
||||
@@ -174,35 +186,40 @@ class MessageChain(PlatformBaseModel):
|
||||
return None
|
||||
|
||||
@typing.overload
|
||||
def __getitem__(self, index: int) -> MessageComponent:
|
||||
...
|
||||
def __getitem__(self, index: int) -> MessageComponent: ...
|
||||
|
||||
@typing.overload
|
||||
def __getitem__(self, index: slice) -> typing.List[MessageComponent]:
|
||||
...
|
||||
def __getitem__(self, index: slice) -> typing.List[MessageComponent]: ...
|
||||
|
||||
@typing.overload
|
||||
def __getitem__(self,
|
||||
index: typing.Type[TMessageComponent]) -> typing.List[TMessageComponent]:
|
||||
...
|
||||
def __getitem__(
|
||||
self, index: typing.Type[TMessageComponent]
|
||||
) -> typing.List[TMessageComponent]: ...
|
||||
|
||||
@typing.overload
|
||||
def __getitem__(
|
||||
self, index: typing.Tuple[typing.Type[TMessageComponent], int]
|
||||
) -> typing.List[TMessageComponent]:
|
||||
...
|
||||
) -> typing.List[TMessageComponent]: ...
|
||||
|
||||
def __getitem__(
|
||||
self, index: typing.Union[int, slice, typing.Type[TMessageComponent],
|
||||
typing.Tuple[typing.Type[TMessageComponent], int]]
|
||||
) -> typing.Union[MessageComponent, typing.List[MessageComponent],
|
||||
typing.List[TMessageComponent]]:
|
||||
self,
|
||||
index: typing.Union[
|
||||
int,
|
||||
slice,
|
||||
typing.Type[TMessageComponent],
|
||||
typing.Tuple[typing.Type[TMessageComponent], int],
|
||||
],
|
||||
) -> typing.Union[
|
||||
MessageComponent, typing.List[MessageComponent], typing.List[TMessageComponent]
|
||||
]:
|
||||
return self.get(index)
|
||||
|
||||
def __setitem__(
|
||||
self, key: typing.Union[int, slice],
|
||||
value: typing.Union[MessageComponent, str, typing.Iterable[typing.Union[MessageComponent,
|
||||
str]]]
|
||||
self,
|
||||
key: typing.Union[int, slice],
|
||||
value: typing.Union[
|
||||
MessageComponent, str, typing.Iterable[typing.Union[MessageComponent, str]]
|
||||
],
|
||||
):
|
||||
if isinstance(value, str):
|
||||
value = Plain(value)
|
||||
@@ -217,8 +234,10 @@ class MessageChain(PlatformBaseModel):
|
||||
return reversed(self.__root__)
|
||||
|
||||
def has(
|
||||
self, sub: typing.Union[MessageComponent, typing.Type[MessageComponent],
|
||||
'MessageChain', str]
|
||||
self,
|
||||
sub: typing.Union[
|
||||
MessageComponent, typing.Type[MessageComponent], 'MessageChain', str
|
||||
],
|
||||
) -> bool:
|
||||
"""判断消息链中:
|
||||
1. 是否有某个消息组件。
|
||||
@@ -242,7 +261,7 @@ class MessageChain(PlatformBaseModel):
|
||||
if i == sub:
|
||||
return True
|
||||
return False
|
||||
raise TypeError(f"类型不匹配,当前类型:{type(sub)}")
|
||||
raise TypeError(f'类型不匹配,当前类型:{type(sub)}')
|
||||
|
||||
def __contains__(self, sub) -> bool:
|
||||
return self.has(sub)
|
||||
@@ -293,7 +312,7 @@ class MessageChain(PlatformBaseModel):
|
||||
self,
|
||||
x: typing.Union[MessageComponent, typing.Type[MessageComponent]],
|
||||
i: int = 0,
|
||||
j: int = -1
|
||||
j: int = -1,
|
||||
) -> int:
|
||||
"""返回 x 在消息链中首次出现项的索引号(索引号在 i 或其后且在 j 之前)。
|
||||
|
||||
@@ -323,12 +342,14 @@ class MessageChain(PlatformBaseModel):
|
||||
for index in range(i, j):
|
||||
if type(self[index]) is x:
|
||||
return index
|
||||
raise ValueError("消息链中不存在该类型的组件。")
|
||||
raise ValueError('消息链中不存在该类型的组件。')
|
||||
if isinstance(x, MessageComponent):
|
||||
return self.__root__.index(x, i, j)
|
||||
raise TypeError(f"类型不匹配,当前类型:{type(x)}")
|
||||
raise TypeError(f'类型不匹配,当前类型:{type(x)}')
|
||||
|
||||
def count(self, x: typing.Union[MessageComponent, typing.Type[MessageComponent]]) -> int:
|
||||
def count(
|
||||
self, x: typing.Union[MessageComponent, typing.Type[MessageComponent]]
|
||||
) -> int:
|
||||
"""返回消息链中 x 出现的次数。
|
||||
|
||||
Args:
|
||||
@@ -342,7 +363,7 @@ class MessageChain(PlatformBaseModel):
|
||||
return sum(1 for i in self if type(i) is x)
|
||||
if isinstance(x, MessageComponent):
|
||||
return self.__root__.count(x)
|
||||
raise TypeError(f"类型不匹配,当前类型:{type(x)}")
|
||||
raise TypeError(f'类型不匹配,当前类型:{type(x)}')
|
||||
|
||||
def extend(self, x: typing.Iterable[typing.Union[MessageComponent, str]]):
|
||||
"""将另一个消息链中的元素添加到消息链末尾。
|
||||
@@ -394,7 +415,7 @@ class MessageChain(PlatformBaseModel):
|
||||
def exclude(
|
||||
self,
|
||||
x: typing.Union[MessageComponent, typing.Type[MessageComponent]],
|
||||
count: int = -1
|
||||
count: int = -1,
|
||||
) -> 'MessageChain':
|
||||
"""返回移除指定元素或指定类型的元素后剩余的消息链。
|
||||
|
||||
@@ -405,6 +426,7 @@ class MessageChain(PlatformBaseModel):
|
||||
Returns:
|
||||
MessageChain: 剩余的消息链。
|
||||
"""
|
||||
|
||||
def _exclude():
|
||||
nonlocal count
|
||||
x_is_type = isinstance(x, type)
|
||||
@@ -423,8 +445,7 @@ class MessageChain(PlatformBaseModel):
|
||||
@classmethod
|
||||
def join(cls, *args: typing.Iterable[typing.Union[str, MessageComponent]]):
|
||||
return cls(
|
||||
Plain(c) if isinstance(c, str) else c
|
||||
for c in itertools.chain(*args)
|
||||
Plain(c) if isinstance(c, str) else c for c in itertools.chain(*args)
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -439,14 +460,19 @@ class MessageChain(PlatformBaseModel):
|
||||
return source.id if source else -1
|
||||
|
||||
|
||||
TMessage = typing.Union[MessageChain, typing.Iterable[typing.Union[MessageComponent, str]],
|
||||
MessageComponent, str]
|
||||
TMessage = typing.Union[
|
||||
MessageChain,
|
||||
typing.Iterable[typing.Union[MessageComponent, str]],
|
||||
MessageComponent,
|
||||
str,
|
||||
]
|
||||
"""可以转化为 MessageChain 的类型。"""
|
||||
|
||||
|
||||
class Source(MessageComponent):
|
||||
"""源。包含消息的基本信息。"""
|
||||
type: str = "Source"
|
||||
|
||||
type: str = 'Source'
|
||||
"""消息组件类型。"""
|
||||
id: typing.Union[int, str]
|
||||
"""消息的识别号,用于引用回复(Source 类型永远为 MessageChain 的第一个元素)。"""
|
||||
@@ -456,10 +482,12 @@ class Source(MessageComponent):
|
||||
|
||||
class Plain(MessageComponent):
|
||||
"""纯文本。"""
|
||||
type: str = "Plain"
|
||||
|
||||
type: str = 'Plain'
|
||||
"""消息组件类型。"""
|
||||
text: str
|
||||
"""文字消息。"""
|
||||
|
||||
def __str__(self):
|
||||
return self.text
|
||||
|
||||
@@ -469,7 +497,8 @@ class Plain(MessageComponent):
|
||||
|
||||
class Quote(MessageComponent):
|
||||
"""引用。"""
|
||||
type: str = "Quote"
|
||||
|
||||
type: str = 'Quote'
|
||||
"""消息组件类型。"""
|
||||
id: typing.Optional[int] = None
|
||||
"""被引用回复的原消息的 message_id。"""
|
||||
@@ -482,37 +511,42 @@ class Quote(MessageComponent):
|
||||
origin: MessageChain
|
||||
"""被引用回复的原消息的消息链对象。"""
|
||||
|
||||
@pydantic.validator("origin", always=True, pre=True)
|
||||
@pydantic.validator('origin', always=True, pre=True)
|
||||
def origin_formater(cls, v):
|
||||
return MessageChain.parse_obj(v)
|
||||
|
||||
|
||||
class At(MessageComponent):
|
||||
"""At某人。"""
|
||||
type: str = "At"
|
||||
|
||||
type: str = 'At'
|
||||
"""消息组件类型。"""
|
||||
target: typing.Union[int, str]
|
||||
"""群员 ID。"""
|
||||
display: typing.Optional[str] = None
|
||||
"""At时显示的文字,发送消息时无效,自动使用群名片。"""
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, At) and self.target == other.target
|
||||
|
||||
def __str__(self):
|
||||
return f"@{self.display or self.target}"
|
||||
return f'@{self.display or self.target}'
|
||||
|
||||
|
||||
class AtAll(MessageComponent):
|
||||
"""At全体。"""
|
||||
type: str = "AtAll"
|
||||
|
||||
type: str = 'AtAll'
|
||||
"""消息组件类型。"""
|
||||
|
||||
def __str__(self):
|
||||
return "@全体成员"
|
||||
return '@全体成员'
|
||||
|
||||
|
||||
class Image(MessageComponent):
|
||||
"""图片。"""
|
||||
type: str = "Image"
|
||||
|
||||
type: str = 'Image'
|
||||
"""消息组件类型。"""
|
||||
image_id: typing.Optional[str] = None
|
||||
"""图片的 image_id,不为空时将忽略 url 属性。"""
|
||||
@@ -522,10 +556,13 @@ class Image(MessageComponent):
|
||||
"""图片的路径,发送本地图片。"""
|
||||
base64: typing.Optional[str] = None
|
||||
"""图片的 Base64 编码。"""
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(
|
||||
other, Image
|
||||
) and self.type == other.type and self.uuid == other.uuid
|
||||
return (
|
||||
isinstance(other, Image)
|
||||
and self.type == other.type
|
||||
and self.uuid == other.uuid
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return '[图片]'
|
||||
@@ -537,7 +574,7 @@ class Image(MessageComponent):
|
||||
try:
|
||||
return str(Path(path).resolve(strict=True))
|
||||
except FileNotFoundError:
|
||||
raise ValueError(f"无效路径:{path}")
|
||||
raise ValueError(f'无效路径:{path}')
|
||||
else:
|
||||
return path
|
||||
|
||||
@@ -554,7 +591,7 @@ class Image(MessageComponent):
|
||||
self,
|
||||
filename: typing.Union[str, Path, None] = None,
|
||||
directory: typing.Union[str, Path, None] = None,
|
||||
determine_type: bool = True
|
||||
determine_type: bool = True,
|
||||
):
|
||||
"""下载图片到本地。
|
||||
|
||||
@@ -568,6 +605,7 @@ class Image(MessageComponent):
|
||||
return
|
||||
|
||||
import httpx
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(self.url)
|
||||
response.raise_for_status()
|
||||
@@ -577,19 +615,20 @@ class Image(MessageComponent):
|
||||
path = Path(filename)
|
||||
if determine_type:
|
||||
import imghdr
|
||||
path = path.with_suffix(
|
||||
'.' + str(imghdr.what(None, content))
|
||||
)
|
||||
|
||||
path = path.with_suffix('.' + str(imghdr.what(None, content)))
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
elif directory:
|
||||
import imghdr
|
||||
|
||||
path = Path(directory)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
path = path / f'{self.uuid}.{imghdr.what(None, content)}'
|
||||
else:
|
||||
raise ValueError("请指定文件路径或文件夹路径!")
|
||||
raise ValueError('请指定文件路径或文件夹路径!')
|
||||
|
||||
import aiofiles
|
||||
|
||||
async with aiofiles.open(path, 'wb') as f:
|
||||
await f.write(content)
|
||||
|
||||
@@ -600,7 +639,7 @@ class Image(MessageComponent):
|
||||
cls,
|
||||
filename: typing.Union[str, Path, None] = None,
|
||||
content: typing.Optional[bytes] = None,
|
||||
) -> "Image":
|
||||
) -> 'Image':
|
||||
"""从本地文件路径加载图片,以 base64 的形式传递。
|
||||
|
||||
Args:
|
||||
@@ -615,16 +654,18 @@ class Image(MessageComponent):
|
||||
elif filename:
|
||||
path = Path(filename)
|
||||
import aiofiles
|
||||
|
||||
async with aiofiles.open(path, 'rb') as f:
|
||||
content = await f.read()
|
||||
else:
|
||||
raise ValueError("请指定图片路径或图片内容!")
|
||||
raise ValueError('请指定图片路径或图片内容!')
|
||||
import base64
|
||||
|
||||
img = cls(base64=base64.b64encode(content).decode())
|
||||
return img
|
||||
|
||||
@classmethod
|
||||
def from_unsafe_path(cls, path: typing.Union[str, Path]) -> "Image":
|
||||
def from_unsafe_path(cls, path: typing.Union[str, Path]) -> 'Image':
|
||||
"""从不安全的路径加载图片。
|
||||
|
||||
Args:
|
||||
@@ -638,7 +679,8 @@ class Image(MessageComponent):
|
||||
|
||||
class Unknown(MessageComponent):
|
||||
"""未知。"""
|
||||
type: str = "Unknown"
|
||||
|
||||
type: str = 'Unknown'
|
||||
"""消息组件类型。"""
|
||||
text: str
|
||||
"""文本。"""
|
||||
@@ -646,7 +688,8 @@ class Unknown(MessageComponent):
|
||||
|
||||
class Voice(MessageComponent):
|
||||
"""语音。"""
|
||||
type: str = "Voice"
|
||||
|
||||
type: str = 'Voice'
|
||||
"""消息组件类型。"""
|
||||
voice_id: typing.Optional[str] = None
|
||||
"""语音的 voice_id,不为空时将忽略 url 属性。"""
|
||||
@@ -658,6 +701,7 @@ class Voice(MessageComponent):
|
||||
"""语音的 Base64 编码。"""
|
||||
length: typing.Optional[int] = None
|
||||
"""语音的长度,单位为秒。"""
|
||||
|
||||
@pydantic.validator('path')
|
||||
def validate_path(cls, path: typing.Optional[str]):
|
||||
"""修复 path 参数的行为,使之相对于 LangBot 的启动路径。"""
|
||||
@@ -665,7 +709,7 @@ class Voice(MessageComponent):
|
||||
try:
|
||||
return str(Path(path).resolve(strict=True))
|
||||
except FileNotFoundError:
|
||||
raise ValueError(f"无效路径:{path}")
|
||||
raise ValueError(f'无效路径:{path}')
|
||||
else:
|
||||
return path
|
||||
|
||||
@@ -675,7 +719,7 @@ class Voice(MessageComponent):
|
||||
async def download(
|
||||
self,
|
||||
filename: typing.Union[str, Path, None] = None,
|
||||
directory: typing.Union[str, Path, None] = None
|
||||
directory: typing.Union[str, Path, None] = None,
|
||||
):
|
||||
"""下载语音到本地。
|
||||
|
||||
@@ -688,6 +732,7 @@ class Voice(MessageComponent):
|
||||
return
|
||||
|
||||
import httpx
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(self.url)
|
||||
response.raise_for_status()
|
||||
@@ -701,9 +746,10 @@ class Voice(MessageComponent):
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
path = path / f'{self.voice_id}.silk'
|
||||
else:
|
||||
raise ValueError("请指定文件路径或文件夹路径!")
|
||||
raise ValueError('请指定文件路径或文件夹路径!')
|
||||
|
||||
import aiofiles
|
||||
|
||||
async with aiofiles.open(path, 'wb') as f:
|
||||
await f.write(content)
|
||||
|
||||
@@ -712,7 +758,7 @@ class Voice(MessageComponent):
|
||||
cls,
|
||||
filename: typing.Union[str, Path, None] = None,
|
||||
content: typing.Optional[bytes] = None,
|
||||
) -> "Voice":
|
||||
) -> 'Voice':
|
||||
"""从本地文件路径加载语音,以 base64 的形式传递。
|
||||
|
||||
Args:
|
||||
@@ -724,17 +770,20 @@ class Voice(MessageComponent):
|
||||
if filename:
|
||||
path = Path(filename)
|
||||
import aiofiles
|
||||
|
||||
async with aiofiles.open(path, 'rb') as f:
|
||||
content = await f.read()
|
||||
else:
|
||||
raise ValueError("请指定语音路径或语音内容!")
|
||||
raise ValueError('请指定语音路径或语音内容!')
|
||||
import base64
|
||||
|
||||
img = cls(base64=base64.b64encode(content).decode())
|
||||
return img
|
||||
|
||||
|
||||
class ForwardMessageNode(pydantic.BaseModel):
|
||||
"""合并转发中的一条消息。"""
|
||||
|
||||
sender_id: typing.Optional[typing.Union[int, str]] = None
|
||||
"""发送人ID。"""
|
||||
sender_name: typing.Optional[str] = None
|
||||
@@ -745,6 +794,7 @@ class ForwardMessageNode(pydantic.BaseModel):
|
||||
"""消息的 message_id。"""
|
||||
time: typing.Optional[datetime] = None
|
||||
"""发送时间。"""
|
||||
|
||||
@pydantic.validator('message_chain', check_fields=False)
|
||||
def _validate_message_chain(cls, value: typing.Union[MessageChain, list]):
|
||||
if isinstance(value, list):
|
||||
@@ -753,7 +803,9 @@ class ForwardMessageNode(pydantic.BaseModel):
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls, sender: typing.Union[platform_entities.Friend, platform_entities.GroupMember], message: MessageChain
|
||||
cls,
|
||||
sender: typing.Union[platform_entities.Friend, platform_entities.GroupMember],
|
||||
message: MessageChain,
|
||||
) -> 'ForwardMessageNode':
|
||||
"""从消息链生成转发消息。
|
||||
|
||||
@@ -765,28 +817,28 @@ class ForwardMessageNode(pydantic.BaseModel):
|
||||
ForwardMessageNode: 生成的一条消息。
|
||||
"""
|
||||
return ForwardMessageNode(
|
||||
sender_id=sender.id,
|
||||
sender_name=sender.get_name(),
|
||||
message_chain=message
|
||||
sender_id=sender.id, sender_name=sender.get_name(), message_chain=message
|
||||
)
|
||||
|
||||
|
||||
class ForwardMessageDiaplay(pydantic.BaseModel):
|
||||
title: str = "群聊的聊天记录"
|
||||
brief: str = "[聊天记录]"
|
||||
source: str = "聊天记录"
|
||||
title: str = '群聊的聊天记录'
|
||||
brief: str = '[聊天记录]'
|
||||
source: str = '聊天记录'
|
||||
preview: typing.List[str] = []
|
||||
summary: str = "查看x条转发消息"
|
||||
summary: str = '查看x条转发消息'
|
||||
|
||||
|
||||
class Forward(MessageComponent):
|
||||
"""合并转发。"""
|
||||
type: str = "Forward"
|
||||
|
||||
type: str = 'Forward'
|
||||
"""消息组件类型。"""
|
||||
display: ForwardMessageDiaplay
|
||||
"""显示信息"""
|
||||
node_list: typing.List[ForwardMessageNode]
|
||||
"""转发消息节点列表。"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if len(args) == 1:
|
||||
self.node_list = args[0]
|
||||
@@ -799,7 +851,8 @@ class Forward(MessageComponent):
|
||||
|
||||
class File(MessageComponent):
|
||||
"""文件。"""
|
||||
type: str = "File"
|
||||
|
||||
type: str = 'File'
|
||||
"""消息组件类型。"""
|
||||
id: str
|
||||
"""文件识别 ID。"""
|
||||
@@ -807,6 +860,6 @@ class File(MessageComponent):
|
||||
"""文件名称。"""
|
||||
size: int
|
||||
"""文件大小。"""
|
||||
|
||||
def __str__(self):
|
||||
return f'[文件]{self.name}'
|
||||
|
||||
|
||||
Reference in New Issue
Block a user