chore: 修改包名

This commit is contained in:
RockChinQ
2024-01-28 19:20:10 +08:00
parent 698782c537
commit b730f17eb6
45 changed files with 27 additions and 27 deletions

0
pkg/platform/__init__.py Normal file
View File

138
pkg/platform/adapter.py Normal file
View File

@@ -0,0 +1,138 @@
# MessageSource的适配器
import typing
import abc
import mirai
class MessageSourceAdapter(metaclass=abc.ABCMeta):
bot_account_id: int
def __init__(self, config: dict):
pass
async def send_message(
self,
target_type: str,
target_id: str,
message: mirai.MessageChain
):
"""发送消息
Args:
target_type (str): 目标类型,`person`或`group`
target_id (str): 目标ID
message (mirai.MessageChain): YiriMirai库的消息链
"""
raise NotImplementedError
async def reply_message(
self,
message_source: mirai.MessageEvent,
message: mirai.MessageChain,
quote_origin: bool = False
):
"""回复消息
Args:
message_source (mirai.MessageEvent): YiriMirai消息源事件
message (mirai.MessageChain): YiriMirai库的消息链
quote_origin (bool, optional): 是否引用原消息. Defaults to False.
"""
raise NotImplementedError
async def is_muted(self, group_id: int) -> bool:
"""获取账号是否在指定群被禁言"""
raise NotImplementedError
def register_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event], None]
):
"""注册事件监听器
Args:
event_type (typing.Type[mirai.Event]): YiriMirai事件类型
callback (typing.Callable[[mirai.Event], None]): 回调函数接收一个参数为YiriMirai事件
"""
raise NotImplementedError
def unregister_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event], None]
):
"""注销事件监听器
Args:
event_type (typing.Type[mirai.Event]): YiriMirai事件类型
callback (typing.Callable[[mirai.Event], None]): 回调函数接收一个参数为YiriMirai事件
"""
raise NotImplementedError
async def run_async(self):
"""异步运行"""
raise NotImplementedError
def kill(self) -> bool:
"""关闭适配器
Returns:
bool: 是否成功关闭热重载时若此函数返回False则不会重载MessageSource底层
"""
raise NotImplementedError
class MessageConverter:
"""消息链转换器基类"""
@staticmethod
def yiri2target(message_chain: mirai.MessageChain):
"""将YiriMirai消息链转换为目标消息链
Args:
message_chain (mirai.MessageChain): YiriMirai消息链
Returns:
typing.Any: 目标消息链
"""
raise NotImplementedError
@staticmethod
def target2yiri(message_chain: typing.Any) -> mirai.MessageChain:
"""将目标消息链转换为YiriMirai消息链
Args:
message_chain (typing.Any): 目标消息链
Returns:
mirai.MessageChain: YiriMirai消息链
"""
raise NotImplementedError
class EventConverter:
"""事件转换器基类"""
@staticmethod
def yiri2target(event: typing.Type[mirai.Event]):
"""将YiriMirai事件转换为目标事件
Args:
event (typing.Type[mirai.Event]): YiriMirai事件
Returns:
typing.Any: 目标事件
"""
raise NotImplementedError
@staticmethod
def target2yiri(event: typing.Any) -> mirai.Event:
"""将目标事件的调用参数转换为YiriMirai的事件参数对象
Args:
event (typing.Any): 目标事件
Returns:
typing.Type[mirai.Event]: YiriMirai事件
"""
raise NotImplementedError

154
pkg/platform/manager.py Normal file
View File

@@ -0,0 +1,154 @@
from __future__ import annotations
import json
import os
import logging
import asyncio
from mirai import At, GroupMessage, MessageEvent, StrangerMessage, \
FriendMessage, Image, MessageChain, Plain
import mirai
import func_timeout
from ..provider import session as openai_session
from ..utils import context
import tips as tips_custom
from ..platform import adapter as msadapter
from .ratelim import ratelim
from ..core import app, entities as core_entities
# 控制QQ消息输入输出的类
class QQBotManager:
adapter: msadapter.MessageSourceAdapter = None
bot_account_id: int = 0
# modern
ap: app.Application = None
ratelimiter: ratelim.RateLimiter = None
def __init__(self, ap: app.Application = None):
self.ap = ap
self.ratelimiter = ratelim.RateLimiter(ap)
async def initialize(self):
await self.ratelimiter.initialize()
config = context.get_config_manager().data
logging.debug("Use adapter:" + config['msg_source_adapter'])
if config['msg_source_adapter'] == 'yirimirai':
from pkg.platform.sources.yirimirai import YiriMiraiAdapter
mirai_http_api_config = config['mirai_http_api_config']
self.bot_account_id = config['mirai_http_api_config']['qq']
self.adapter = YiriMiraiAdapter(mirai_http_api_config)
elif config['msg_source_adapter'] == 'nakuru':
from pkg.platform.sources.nakuru import NakuruProjectAdapter
self.adapter = NakuruProjectAdapter(config['nakuru_config'])
self.bot_account_id = self.adapter.bot_account_id
# 保存 account_id 到审计模块
from ..utils.center import apigroup
apigroup.APIGroup._runtime_info['account_id'] = "{}".format(self.bot_account_id)
async def on_friend_message(event: FriendMessage):
await self.ap.query_pool.add_query(
launcher_type=core_entities.LauncherTypes.PERSON,
launcher_id=event.sender.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain
)
self.adapter.register_listener(
FriendMessage,
on_friend_message
)
async def on_stranger_message(event: StrangerMessage):
await self.ap.query_pool.add_query(
launcher_type=core_entities.LauncherTypes.PERSON,
launcher_id=event.sender.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain
)
# nakuru不区分好友和陌生人故仅为yirimirai注册陌生人事件
if config['msg_source_adapter'] == 'yirimirai':
self.adapter.register_listener(
StrangerMessage,
on_stranger_message
)
async def on_group_message(event: GroupMessage):
await self.ap.query_pool.add_query(
launcher_type=core_entities.LauncherTypes.GROUP,
launcher_id=event.group.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain
)
self.adapter.register_listener(
GroupMessage,
on_group_message
)
async def send(self, event, msg, check_quote=True, check_at_sender=True):
config = context.get_config_manager().data
if check_at_sender and config['at_sender']:
msg.insert(
0,
Plain(" \n")
)
# 当回复的正文中包含换行时quote可能会自带at此时就不再单独添加at只添加换行
if "\n" not in str(msg[1]) or config['msg_source_adapter'] == 'nakuru':
msg.insert(
0,
At(
event.sender.id
)
)
await self.adapter.reply_message(
event,
msg,
quote_origin=True if config['quote_origin'] and check_quote else False
)
# 通知系统管理员
async def notify_admin(self, message: str):
await self.notify_admin_message_chain(MessageChain([Plain("[bot]{}".format(message))]))
async def notify_admin_message_chain(self, message: mirai.MessageChain):
config = context.get_config_manager().data
if config['admin_qq'] != 0 and config['admin_qq'] != []:
logging.info("通知管理员:{}".format(message))
admin_list = []
if type(config['admin_qq']) == int:
admin_list.append(config['admin_qq'])
for adm in admin_list:
self.adapter.send_message(
"person",
adm,
message
)
async def run(self):
await self.adapter.run_async()

View File

View File

@@ -0,0 +1,24 @@
from __future__ import annotations
import abc
from ...core import app
class ReteLimitAlgo(metaclass=abc.ABCMeta):
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def require_access(self, launcher_type: str, launcher_id: int) -> bool:
raise NotImplementedError
@abc.abstractmethod
async def release_access(self, launcher_type: str, launcher_id: int):
raise NotImplementedError

View File

View File

@@ -0,0 +1,85 @@
# 固定窗口算法
from __future__ import annotations
import asyncio
import time
from .. import algo
class SessionContainer:
wait_lock: asyncio.Lock
records: dict[int, int]
"""访问记录key为每分钟的起始时间戳value为访问次数"""
def __init__(self):
self.wait_lock = asyncio.Lock()
self.records = {}
class FixedWindowAlgo(algo.ReteLimitAlgo):
containers_lock: asyncio.Lock
"""访问记录容器锁"""
containers: dict[str, SessionContainer]
"""访问记录容器key为launcher_type launcher_id"""
async def initialize(self):
self.containers_lock = asyncio.Lock()
self.containers = {}
async def require_access(self, launcher_type: str, launcher_id: int) -> bool:
# 加锁,找容器
container: SessionContainer = None
session_name = f'{launcher_type}_{launcher_id}'
async with self.containers_lock:
container = self.containers.get(session_name)
if container is None:
container = SessionContainer()
self.containers[session_name] = container
# 等待锁
async with container.wait_lock:
# 获取当前时间戳
now = int(time.time())
# 获取当前分钟的起始时间戳
now = now - now % 60
# 获取当前分钟的访问次数
count = container.records.get(now, 0)
limitation = self.ap.cfg_mgr.data['rate_limitation']['default']
if session_name in self.ap.cfg_mgr.data['rate_limitation']:
limitation = self.ap.cfg_mgr.data['rate_limitation'][session_name]
# 如果访问次数超过了限制
if count >= limitation:
if self.ap.cfg_mgr.data['rate_limit_strategy'] == 'drop':
return False
elif self.ap.cfg_mgr.data['rate_limit_strategy'] == 'wait':
# 等待下一分钟
await asyncio.sleep(60 - time.time() % 60)
now = int(time.time())
now = now - now % 60
if now not in container.records:
container.records = {}
container.records[now] = 1
else:
# 访问次数加一
container.records[now] = count + 1
# 返回True
return True
async def release_access(self, launcher_type: str, launcher_id: int):
pass

View File

@@ -0,0 +1,31 @@
from __future__ import annotations
from . import algo
from .algos import fixedwin
from ...core import app
class RateLimiter:
"""限速器
"""
ap: app.Application
algo: algo.ReteLimitAlgo
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
self.algo = fixedwin.FixedWindowAlgo(self.ap)
await self.algo.initialize()
async def require(self, launcher_type: str, launcher_id: int) -> bool:
"""请求访问
"""
return await self.algo.require_access(launcher_type, launcher_id)
async def release(self, launcher_type: str, launcher_id: int):
"""释放访问
"""
return await self.algo.release_access(launcher_type, launcher_id)

View File

View File

@@ -0,0 +1,331 @@
import asyncio
import typing
import traceback
import logging
import mirai
import nakuru
import nakuru.entities.components as nkc
from .. import adapter as adapter_model
from ...platform import blob
from ...utils import context
class NakuruProjectMessageConverter(adapter_model.MessageConverter):
"""消息转换器"""
@staticmethod
def yiri2target(message_chain: mirai.MessageChain) -> list:
msg_list = []
if type(message_chain) is mirai.MessageChain:
msg_list = message_chain.__root__
elif type(message_chain) is list:
msg_list = message_chain
else:
raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain)))
nakuru_msg_list = []
# 遍历并转换
for component in msg_list:
if type(component) is mirai.Plain:
nakuru_msg_list.append(nkc.Plain(component.text, False))
elif type(component) is mirai.Image:
if component.url is not None:
nakuru_msg_list.append(nkc.Image.fromURL(component.url))
elif component.base64 is not None:
nakuru_msg_list.append(nkc.Image.fromBase64(component.base64))
elif component.path is not None:
nakuru_msg_list.append(nkc.Image.fromFileSystem(component.path))
elif type(component) is mirai.Face:
nakuru_msg_list.append(nkc.Face(id=component.face_id))
elif type(component) is mirai.At:
nakuru_msg_list.append(nkc.At(qq=component.target))
elif type(component) is mirai.AtAll:
nakuru_msg_list.append(nkc.AtAll())
elif type(component) is mirai.Voice:
if component.url is not None:
nakuru_msg_list.append(nkc.Record.fromURL(component.url))
elif component.path is not None:
nakuru_msg_list.append(nkc.Record.fromFileSystem(component.path))
elif type(component) is blob.Forward:
# 转发消息
yiri_forward_node_list = component.node_list
nakuru_forward_node_list = []
# 遍历并转换
for yiri_forward_node in yiri_forward_node_list:
try:
content_list = NakuruProjectMessageConverter.yiri2target(yiri_forward_node.message_chain)
nakuru_forward_node = nkc.Node(
name=yiri_forward_node.sender_name,
uin=yiri_forward_node.sender_id,
time=int(yiri_forward_node.time.timestamp()) if yiri_forward_node.time is not None else None,
content=content_list
)
nakuru_forward_node_list.append(nakuru_forward_node)
except Exception as e:
import traceback
traceback.print_exc()
nakuru_msg_list.append(nakuru_forward_node_list)
else:
nakuru_msg_list.append(nkc.Plain(str(component)))
return nakuru_msg_list
@staticmethod
def target2yiri(message_chain: typing.Any, message_id: int = -1) -> mirai.MessageChain:
"""将Yiri的消息链转换为YiriMirai的消息链"""
assert type(message_chain) is list
yiri_msg_list = []
import datetime
# 添加Source组件以标记message_id等信息
yiri_msg_list.append(mirai.models.message.Source(id=message_id, time=datetime.datetime.now()))
for component in message_chain:
if type(component) is nkc.Plain:
yiri_msg_list.append(mirai.Plain(text=component.text))
elif type(component) is nkc.Image:
yiri_msg_list.append(mirai.Image(url=component.url))
elif type(component) is nkc.Face:
yiri_msg_list.append(mirai.Face(face_id=component.id))
elif type(component) is nkc.At:
yiri_msg_list.append(mirai.At(target=component.qq))
elif type(component) is nkc.AtAll:
yiri_msg_list.append(mirai.AtAll())
else:
pass
logging.debug("转换后的消息链: " + str(yiri_msg_list))
chain = mirai.MessageChain(yiri_msg_list)
return chain
class NakuruProjectEventConverter(adapter_model.EventConverter):
"""事件转换器"""
@staticmethod
def yiri2target(event: typing.Type[mirai.Event]):
if event is mirai.GroupMessage:
return nakuru.GroupMessage
elif event is mirai.FriendMessage:
return nakuru.FriendMessage
else:
raise Exception("未支持转换的事件类型: " + str(event))
@staticmethod
def target2yiri(event: typing.Any) -> mirai.Event:
yiri_chain = NakuruProjectMessageConverter.target2yiri(event.message, event.message_id)
if type(event) is nakuru.FriendMessage: # 私聊消息事件
return mirai.FriendMessage(
sender=mirai.models.entities.Friend(
id=event.sender.user_id,
nickname=event.sender.nickname,
remark=event.sender.nickname
),
message_chain=yiri_chain,
time=event.time
)
elif type(event) is nakuru.GroupMessage: # 群聊消息事件
permission = "MEMBER"
if event.sender.role == "admin":
permission = "ADMINISTRATOR"
elif event.sender.role == "owner":
permission = "OWNER"
import mirai.models.entities as entities
return mirai.GroupMessage(
sender=mirai.models.entities.GroupMember(
id=event.sender.user_id,
member_name=event.sender.nickname,
permission=permission,
group=mirai.models.entities.Group(
id=event.group_id,
name=event.sender.nickname,
permission=entities.Permission.Member
),
special_title=event.sender.title,
join_timestamp=0,
last_speak_timestamp=0,
mute_time_remaining=0,
),
message_chain=yiri_chain,
time=event.time
)
else:
raise Exception("未支持转换的事件类型: " + str(event))
class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
"""nakuru-project适配器"""
bot: nakuru.CQHTTP
bot_account_id: int
message_converter: NakuruProjectMessageConverter = NakuruProjectMessageConverter()
event_converter: NakuruProjectEventConverter = NakuruProjectEventConverter()
listener_list: list[dict]
def __init__(self, cfg: dict):
"""初始化nakuru-project的对象"""
self.bot = nakuru.CQHTTP(**cfg)
self.listener_list = []
# nakuru库有bug这个接口没法带access_token会失败
# 所以目前自行发请求
config = context.get_config_manager().data
import requests
resp = requests.get(
url="http://{}:{}/get_login_info".format(config['nakuru_config']['host'], config['nakuru_config']['http_port']),
headers={
'Authorization': "Bearer " + config['nakuru_config']['token'] if 'token' in config['nakuru_config']else ""
},
timeout=5,
proxies=None
)
if resp.status_code == 403:
logging.error("go-cqhttp拒绝访问请检查config.py中nakuru_config的token是否与go-cqhttp设置的access-token匹配")
raise Exception("go-cqhttp拒绝访问请检查config.py中nakuru_config的token是否与go-cqhttp设置的access-token匹配")
try:
self.bot_account_id = int(resp.json()['data']['user_id'])
except Exception as e:
logging.error("获取go-cqhttp账号信息失败: {}, 请检查是否已启动go-cqhttp并配置正确".format(e))
raise Exception("获取go-cqhttp账号信息失败: {}, 请检查是否已启动go-cqhttp并配置正确".format(e))
def send_message(
self,
target_type: str,
target_id: str,
message: typing.Union[mirai.MessageChain, list],
converted: bool = False
):
task = None
converted_msg = self.message_converter.yiri2target(message) if not converted else message
# 检查是否有转发消息
has_forward = False
for msg in converted_msg:
if type(msg) is list: # 转发消息,仅回复此消息组件
has_forward = True
converted_msg = msg
break
if has_forward:
if target_type == "group":
task = self.bot.sendGroupForwardMessage(int(target_id), converted_msg)
elif target_type == "person":
task = self.bot.sendPrivateForwardMessage(int(target_id), converted_msg)
else:
raise Exception("Unknown target type: " + target_type)
else:
if target_type == "group":
task = self.bot.sendGroupMessage(int(target_id), converted_msg)
elif target_type == "person":
task = self.bot.sendFriendMessage(int(target_id), converted_msg)
else:
raise Exception("Unknown target type: " + target_type)
asyncio.run(task)
def reply_message(
self,
message_source: mirai.MessageEvent,
message: mirai.MessageChain,
quote_origin: bool = False
):
message = self.message_converter.yiri2target(message)
if quote_origin:
# 在前方添加引用组件
message.insert(0, nkc.Reply(
id=message_source.message_chain.message_id,
)
)
if type(message_source) is mirai.GroupMessage:
self.send_message(
"group",
message_source.sender.group.id,
message,
converted=True
)
elif type(message_source) is mirai.FriendMessage:
self.send_message(
"person",
message_source.sender.id,
message,
converted=True
)
else:
raise Exception("Unknown message source type: " + str(type(message_source)))
def is_muted(self, group_id: int) -> bool:
import time
# 检查是否被禁言
group_member_info = asyncio.run(self.bot.getGroupMemberInfo(group_id, self.bot_account_id))
return group_member_info.shut_up_timestamp > int(time.time())
def register_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event], None]
):
try:
logging.debug("注册监听器: " + str(event_type) + " -> " + str(callback))
# 包装函数
async def listener_wrapper(app: nakuru.CQHTTP, source: NakuruProjectAdapter.event_converter.yiri2target(event_type)):
callback(self.event_converter.target2yiri(source))
# 将包装函数和原函数的对应关系存入列表
self.listener_list.append(
{
"event_type": event_type,
"callable": callback,
"wrapper": listener_wrapper,
}
)
# 注册监听器
self.bot.receiver(self.event_converter.yiri2target(event_type).__name__)(listener_wrapper)
logging.debug("注册完成")
except Exception as e:
traceback.print_exc()
raise e
def unregister_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event], None]
):
nakuru_event_name = self.event_converter.yiri2target(event_type).__name__
new_event_list = []
# 从本对象的监听器列表中查找并删除
target_wrapper = None
for listener in self.listener_list:
if listener["event_type"] == event_type and listener["callable"] == callback:
target_wrapper = listener["wrapper"]
self.listener_list.remove(listener)
break
if target_wrapper is None:
raise Exception("未找到对应的监听器")
for func in self.bot.event[nakuru_event_name]:
if func.callable != target_wrapper:
new_event_list.append(func)
self.bot.event[nakuru_event_name] = new_event_list
def run_sync(self):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
self.bot.run()
async def run_async(self):
return await self.bot._run()
def kill(self) -> bool:
return False

View File

@@ -0,0 +1,117 @@
import asyncio
import typing
import mirai
import mirai.models.bus
from mirai.bot import MiraiRunner
from .. import adapter as adapter_model
class YiriMiraiAdapter(adapter_model.MessageSourceAdapter):
"""YiriMirai适配器"""
bot: mirai.Mirai
def __init__(self, config: dict):
"""初始化YiriMirai的对象"""
if 'adapter' not in config or \
config['adapter'] == 'WebSocketAdapter':
self.bot = mirai.Mirai(
qq=config['qq'],
adapter=mirai.WebSocketAdapter(
host=config['host'],
port=config['port'],
verify_key=config['verifyKey']
)
)
elif config['adapter'] == 'HTTPAdapter':
self.bot = mirai.Mirai(
qq=config['qq'],
adapter=mirai.HTTPAdapter(
host=config['host'],
port=config['port'],
verify_key=config['verifyKey']
)
)
else:
raise Exception('Unknown adapter for YiriMirai: ' + config['adapter'])
async def send_message(
self,
target_type: str,
target_id: str,
message: mirai.MessageChain
):
"""发送消息
Args:
target_type (str): 目标类型,`person`或`group`
target_id (str): 目标ID
message (mirai.MessageChain): YiriMirai库的消息链
"""
task = None
if target_type == 'person':
task = self.bot.send_friend_message(int(target_id), message)
elif target_type == 'group':
task = self.bot.send_group_message(int(target_id), message)
else:
raise Exception('Unknown target type: ' + target_type)
await task
async def reply_message(
self,
message_source: mirai.MessageEvent,
message: mirai.MessageChain,
quote_origin: bool = False
):
"""回复消息
Args:
message_source (mirai.MessageEvent): YiriMirai消息源事件
message (mirai.MessageChain): YiriMirai库的消息链
quote_origin (bool, optional): 是否引用原消息. Defaults to False.
"""
await self.bot.send(message_source, message, quote_origin)
async def is_muted(self, group_id: int) -> bool:
result = await self.bot.member_info(target=group_id, member_id=self.bot.qq).get()
if result.mute_time_remaining > 0:
return True
return False
def register_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event], None]
):
"""注册事件监听器
Args:
event_type (typing.Type[mirai.Event]): YiriMirai事件类型
callback (typing.Callable[[mirai.Event], None]): 回调函数接收一个参数为YiriMirai事件
"""
self.bot.on(event_type)(callback)
def unregister_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event], None]
):
"""注销事件监听器
Args:
event_type (typing.Type[mirai.Event]): YiriMirai事件类型
callback (typing.Callable[[mirai.Event], None]): 回调函数接收一个参数为YiriMirai事件
"""
assert isinstance(self.bot, mirai.Mirai)
bus = self.bot.bus
assert isinstance(bus, mirai.models.bus.ModelEventBus)
bus.unsubscribe(event_type, callback)
async def run_async(self):
return await MiraiRunner(self.bot)._run()
def kill(self) -> bool:
return False