mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-09 15:26:03 +00:00
chore: 修改包名
This commit is contained in:
0
pkg/platform/__init__.py
Normal file
0
pkg/platform/__init__.py
Normal file
138
pkg/platform/adapter.py
Normal file
138
pkg/platform/adapter.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# MessageSource的适配器
|
||||
import typing
|
||||
import abc
|
||||
|
||||
import mirai
|
||||
|
||||
|
||||
class MessageSourceAdapter(metaclass=abc.ABCMeta):
|
||||
bot_account_id: int
|
||||
def __init__(self, config: dict):
|
||||
pass
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
message: mirai.MessageChain
|
||||
):
|
||||
"""发送消息
|
||||
|
||||
Args:
|
||||
target_type (str): 目标类型,`person`或`group`
|
||||
target_id (str): 目标ID
|
||||
message (mirai.MessageChain): YiriMirai库的消息链
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
message_source: mirai.MessageEvent,
|
||||
message: mirai.MessageChain,
|
||||
quote_origin: bool = False
|
||||
):
|
||||
"""回复消息
|
||||
|
||||
Args:
|
||||
message_source (mirai.MessageEvent): YiriMirai消息源事件
|
||||
message (mirai.MessageChain): YiriMirai库的消息链
|
||||
quote_origin (bool, optional): 是否引用原消息. Defaults to False.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def is_muted(self, group_id: int) -> bool:
|
||||
"""获取账号是否在指定群被禁言"""
|
||||
raise NotImplementedError
|
||||
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
callback: typing.Callable[[mirai.Event], None]
|
||||
):
|
||||
"""注册事件监听器
|
||||
|
||||
Args:
|
||||
event_type (typing.Type[mirai.Event]): YiriMirai事件类型
|
||||
callback (typing.Callable[[mirai.Event], None]): 回调函数,接收一个参数,为YiriMirai事件
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
callback: typing.Callable[[mirai.Event], None]
|
||||
):
|
||||
"""注销事件监听器
|
||||
|
||||
Args:
|
||||
event_type (typing.Type[mirai.Event]): YiriMirai事件类型
|
||||
callback (typing.Callable[[mirai.Event], None]): 回调函数,接收一个参数,为YiriMirai事件
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def run_async(self):
|
||||
"""异步运行"""
|
||||
raise NotImplementedError
|
||||
|
||||
def kill(self) -> bool:
|
||||
"""关闭适配器
|
||||
|
||||
Returns:
|
||||
bool: 是否成功关闭,热重载时若此函数返回False则不会重载MessageSource底层
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MessageConverter:
|
||||
"""消息链转换器基类"""
|
||||
@staticmethod
|
||||
def yiri2target(message_chain: mirai.MessageChain):
|
||||
"""将YiriMirai消息链转换为目标消息链
|
||||
|
||||
Args:
|
||||
message_chain (mirai.MessageChain): YiriMirai消息链
|
||||
|
||||
Returns:
|
||||
typing.Any: 目标消息链
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def target2yiri(message_chain: typing.Any) -> mirai.MessageChain:
|
||||
"""将目标消息链转换为YiriMirai消息链
|
||||
|
||||
Args:
|
||||
message_chain (typing.Any): 目标消息链
|
||||
|
||||
Returns:
|
||||
mirai.MessageChain: YiriMirai消息链
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class EventConverter:
|
||||
"""事件转换器基类"""
|
||||
|
||||
@staticmethod
|
||||
def yiri2target(event: typing.Type[mirai.Event]):
|
||||
"""将YiriMirai事件转换为目标事件
|
||||
|
||||
Args:
|
||||
event (typing.Type[mirai.Event]): YiriMirai事件
|
||||
|
||||
Returns:
|
||||
typing.Any: 目标事件
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def target2yiri(event: typing.Any) -> mirai.Event:
|
||||
"""将目标事件的调用参数转换为YiriMirai的事件参数对象
|
||||
|
||||
Args:
|
||||
event (typing.Any): 目标事件
|
||||
|
||||
Returns:
|
||||
typing.Type[mirai.Event]: YiriMirai事件
|
||||
"""
|
||||
raise NotImplementedError
|
||||
154
pkg/platform/manager.py
Normal file
154
pkg/platform/manager.py
Normal file
@@ -0,0 +1,154 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
from mirai import At, GroupMessage, MessageEvent, StrangerMessage, \
|
||||
FriendMessage, Image, MessageChain, Plain
|
||||
import mirai
|
||||
import func_timeout
|
||||
|
||||
from ..provider import session as openai_session
|
||||
|
||||
from ..utils import context
|
||||
import tips as tips_custom
|
||||
from ..platform import adapter as msadapter
|
||||
from .ratelim import ratelim
|
||||
|
||||
from ..core import app, entities as core_entities
|
||||
|
||||
|
||||
# 控制QQ消息输入输出的类
|
||||
class QQBotManager:
|
||||
|
||||
adapter: msadapter.MessageSourceAdapter = None
|
||||
|
||||
bot_account_id: int = 0
|
||||
|
||||
# modern
|
||||
ap: app.Application = None
|
||||
|
||||
ratelimiter: ratelim.RateLimiter = None
|
||||
|
||||
def __init__(self, ap: app.Application = None):
|
||||
|
||||
self.ap = ap
|
||||
self.ratelimiter = ratelim.RateLimiter(ap)
|
||||
|
||||
async def initialize(self):
|
||||
await self.ratelimiter.initialize()
|
||||
|
||||
config = context.get_config_manager().data
|
||||
|
||||
logging.debug("Use adapter:" + config['msg_source_adapter'])
|
||||
if config['msg_source_adapter'] == 'yirimirai':
|
||||
from pkg.platform.sources.yirimirai import YiriMiraiAdapter
|
||||
|
||||
mirai_http_api_config = config['mirai_http_api_config']
|
||||
self.bot_account_id = config['mirai_http_api_config']['qq']
|
||||
self.adapter = YiriMiraiAdapter(mirai_http_api_config)
|
||||
elif config['msg_source_adapter'] == 'nakuru':
|
||||
from pkg.platform.sources.nakuru import NakuruProjectAdapter
|
||||
self.adapter = NakuruProjectAdapter(config['nakuru_config'])
|
||||
self.bot_account_id = self.adapter.bot_account_id
|
||||
|
||||
# 保存 account_id 到审计模块
|
||||
from ..utils.center import apigroup
|
||||
apigroup.APIGroup._runtime_info['account_id'] = "{}".format(self.bot_account_id)
|
||||
|
||||
async def on_friend_message(event: FriendMessage):
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
launcher_type=core_entities.LauncherTypes.PERSON,
|
||||
launcher_id=event.sender.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain
|
||||
)
|
||||
|
||||
self.adapter.register_listener(
|
||||
FriendMessage,
|
||||
on_friend_message
|
||||
)
|
||||
|
||||
async def on_stranger_message(event: StrangerMessage):
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
launcher_type=core_entities.LauncherTypes.PERSON,
|
||||
launcher_id=event.sender.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain
|
||||
)
|
||||
|
||||
# nakuru不区分好友和陌生人,故仅为yirimirai注册陌生人事件
|
||||
if config['msg_source_adapter'] == 'yirimirai':
|
||||
self.adapter.register_listener(
|
||||
StrangerMessage,
|
||||
on_stranger_message
|
||||
)
|
||||
|
||||
async def on_group_message(event: GroupMessage):
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
launcher_type=core_entities.LauncherTypes.GROUP,
|
||||
launcher_id=event.group.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain
|
||||
)
|
||||
|
||||
self.adapter.register_listener(
|
||||
GroupMessage,
|
||||
on_group_message
|
||||
)
|
||||
|
||||
async def send(self, event, msg, check_quote=True, check_at_sender=True):
|
||||
config = context.get_config_manager().data
|
||||
|
||||
if check_at_sender and config['at_sender']:
|
||||
msg.insert(
|
||||
0,
|
||||
Plain(" \n")
|
||||
)
|
||||
|
||||
# 当回复的正文中包含换行时,quote可能会自带at,此时就不再单独添加at,只添加换行
|
||||
if "\n" not in str(msg[1]) or config['msg_source_adapter'] == 'nakuru':
|
||||
msg.insert(
|
||||
0,
|
||||
At(
|
||||
event.sender.id
|
||||
)
|
||||
)
|
||||
|
||||
await self.adapter.reply_message(
|
||||
event,
|
||||
msg,
|
||||
quote_origin=True if config['quote_origin'] and check_quote else False
|
||||
)
|
||||
|
||||
# 通知系统管理员
|
||||
async def notify_admin(self, message: str):
|
||||
await self.notify_admin_message_chain(MessageChain([Plain("[bot]{}".format(message))]))
|
||||
|
||||
async def notify_admin_message_chain(self, message: mirai.MessageChain):
|
||||
config = context.get_config_manager().data
|
||||
if config['admin_qq'] != 0 and config['admin_qq'] != []:
|
||||
logging.info("通知管理员:{}".format(message))
|
||||
|
||||
admin_list = []
|
||||
|
||||
if type(config['admin_qq']) == int:
|
||||
admin_list.append(config['admin_qq'])
|
||||
|
||||
for adm in admin_list:
|
||||
self.adapter.send_message(
|
||||
"person",
|
||||
adm,
|
||||
message
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
await self.adapter.run_async()
|
||||
0
pkg/platform/ratelim/__init__.py
Normal file
0
pkg/platform/ratelim/__init__.py
Normal file
24
pkg/platform/ratelim/algo.py
Normal file
24
pkg/platform/ratelim/algo.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
|
||||
from ...core import app
|
||||
|
||||
|
||||
class ReteLimitAlgo(metaclass=abc.ABCMeta):
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def require_access(self, launcher_type: str, launcher_id: int) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def release_access(self, launcher_type: str, launcher_id: int):
|
||||
raise NotImplementedError
|
||||
|
||||
0
pkg/platform/ratelim/algos/__init__.py
Normal file
0
pkg/platform/ratelim/algos/__init__.py
Normal file
85
pkg/platform/ratelim/algos/fixedwin.py
Normal file
85
pkg/platform/ratelim/algos/fixedwin.py
Normal file
@@ -0,0 +1,85 @@
|
||||
# 固定窗口算法
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from .. import algo
|
||||
|
||||
|
||||
class SessionContainer:
|
||||
|
||||
wait_lock: asyncio.Lock
|
||||
|
||||
records: dict[int, int]
|
||||
"""访问记录,key为每分钟的起始时间戳,value为访问次数"""
|
||||
|
||||
def __init__(self):
|
||||
self.wait_lock = asyncio.Lock()
|
||||
self.records = {}
|
||||
|
||||
|
||||
class FixedWindowAlgo(algo.ReteLimitAlgo):
|
||||
|
||||
containers_lock: asyncio.Lock
|
||||
"""访问记录容器锁"""
|
||||
|
||||
containers: dict[str, SessionContainer]
|
||||
"""访问记录容器,key为launcher_type launcher_id"""
|
||||
|
||||
async def initialize(self):
|
||||
self.containers_lock = asyncio.Lock()
|
||||
self.containers = {}
|
||||
|
||||
async def require_access(self, launcher_type: str, launcher_id: int) -> bool:
|
||||
# 加锁,找容器
|
||||
container: SessionContainer = None
|
||||
|
||||
session_name = f'{launcher_type}_{launcher_id}'
|
||||
|
||||
async with self.containers_lock:
|
||||
container = self.containers.get(session_name)
|
||||
|
||||
if container is None:
|
||||
container = SessionContainer()
|
||||
self.containers[session_name] = container
|
||||
|
||||
# 等待锁
|
||||
async with container.wait_lock:
|
||||
# 获取当前时间戳
|
||||
now = int(time.time())
|
||||
|
||||
# 获取当前分钟的起始时间戳
|
||||
now = now - now % 60
|
||||
|
||||
# 获取当前分钟的访问次数
|
||||
count = container.records.get(now, 0)
|
||||
|
||||
limitation = self.ap.cfg_mgr.data['rate_limitation']['default']
|
||||
|
||||
if session_name in self.ap.cfg_mgr.data['rate_limitation']:
|
||||
limitation = self.ap.cfg_mgr.data['rate_limitation'][session_name]
|
||||
|
||||
# 如果访问次数超过了限制
|
||||
if count >= limitation:
|
||||
if self.ap.cfg_mgr.data['rate_limit_strategy'] == 'drop':
|
||||
return False
|
||||
elif self.ap.cfg_mgr.data['rate_limit_strategy'] == 'wait':
|
||||
# 等待下一分钟
|
||||
await asyncio.sleep(60 - time.time() % 60)
|
||||
|
||||
now = int(time.time())
|
||||
now = now - now % 60
|
||||
|
||||
if now not in container.records:
|
||||
container.records = {}
|
||||
container.records[now] = 1
|
||||
else:
|
||||
# 访问次数加一
|
||||
container.records[now] = count + 1
|
||||
|
||||
# 返回True
|
||||
return True
|
||||
|
||||
async def release_access(self, launcher_type: str, launcher_id: int):
|
||||
pass
|
||||
31
pkg/platform/ratelim/ratelim.py
Normal file
31
pkg/platform/ratelim/ratelim.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from . import algo
|
||||
from .algos import fixedwin
|
||||
from ...core import app
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""限速器
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
algo: algo.ReteLimitAlgo
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
self.algo = fixedwin.FixedWindowAlgo(self.ap)
|
||||
await self.algo.initialize()
|
||||
|
||||
async def require(self, launcher_type: str, launcher_id: int) -> bool:
|
||||
"""请求访问
|
||||
"""
|
||||
return await self.algo.require_access(launcher_type, launcher_id)
|
||||
|
||||
async def release(self, launcher_type: str, launcher_id: int):
|
||||
"""释放访问
|
||||
"""
|
||||
return await self.algo.release_access(launcher_type, launcher_id)
|
||||
0
pkg/platform/sources/__init__.py
Normal file
0
pkg/platform/sources/__init__.py
Normal file
331
pkg/platform/sources/nakuru.py
Normal file
331
pkg/platform/sources/nakuru.py
Normal file
@@ -0,0 +1,331 @@
|
||||
import asyncio
|
||||
import typing
|
||||
import traceback
|
||||
import logging
|
||||
|
||||
import mirai
|
||||
|
||||
import nakuru
|
||||
import nakuru.entities.components as nkc
|
||||
|
||||
from .. import adapter as adapter_model
|
||||
from ...platform import blob
|
||||
from ...utils import context
|
||||
|
||||
|
||||
class NakuruProjectMessageConverter(adapter_model.MessageConverter):
|
||||
"""消息转换器"""
|
||||
@staticmethod
|
||||
def yiri2target(message_chain: mirai.MessageChain) -> list:
|
||||
msg_list = []
|
||||
if type(message_chain) is mirai.MessageChain:
|
||||
msg_list = message_chain.__root__
|
||||
elif type(message_chain) is list:
|
||||
msg_list = message_chain
|
||||
else:
|
||||
raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain)))
|
||||
|
||||
nakuru_msg_list = []
|
||||
|
||||
# 遍历并转换
|
||||
for component in msg_list:
|
||||
if type(component) is mirai.Plain:
|
||||
nakuru_msg_list.append(nkc.Plain(component.text, False))
|
||||
elif type(component) is mirai.Image:
|
||||
if component.url is not None:
|
||||
nakuru_msg_list.append(nkc.Image.fromURL(component.url))
|
||||
elif component.base64 is not None:
|
||||
nakuru_msg_list.append(nkc.Image.fromBase64(component.base64))
|
||||
elif component.path is not None:
|
||||
nakuru_msg_list.append(nkc.Image.fromFileSystem(component.path))
|
||||
elif type(component) is mirai.Face:
|
||||
nakuru_msg_list.append(nkc.Face(id=component.face_id))
|
||||
elif type(component) is mirai.At:
|
||||
nakuru_msg_list.append(nkc.At(qq=component.target))
|
||||
elif type(component) is mirai.AtAll:
|
||||
nakuru_msg_list.append(nkc.AtAll())
|
||||
elif type(component) is mirai.Voice:
|
||||
if component.url is not None:
|
||||
nakuru_msg_list.append(nkc.Record.fromURL(component.url))
|
||||
elif component.path is not None:
|
||||
nakuru_msg_list.append(nkc.Record.fromFileSystem(component.path))
|
||||
elif type(component) is blob.Forward:
|
||||
# 转发消息
|
||||
yiri_forward_node_list = component.node_list
|
||||
nakuru_forward_node_list = []
|
||||
|
||||
# 遍历并转换
|
||||
for yiri_forward_node in yiri_forward_node_list:
|
||||
try:
|
||||
content_list = NakuruProjectMessageConverter.yiri2target(yiri_forward_node.message_chain)
|
||||
nakuru_forward_node = nkc.Node(
|
||||
name=yiri_forward_node.sender_name,
|
||||
uin=yiri_forward_node.sender_id,
|
||||
time=int(yiri_forward_node.time.timestamp()) if yiri_forward_node.time is not None else None,
|
||||
content=content_list
|
||||
)
|
||||
nakuru_forward_node_list.append(nakuru_forward_node)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
nakuru_msg_list.append(nakuru_forward_node_list)
|
||||
else:
|
||||
nakuru_msg_list.append(nkc.Plain(str(component)))
|
||||
|
||||
return nakuru_msg_list
|
||||
|
||||
@staticmethod
|
||||
def target2yiri(message_chain: typing.Any, message_id: int = -1) -> mirai.MessageChain:
|
||||
"""将Yiri的消息链转换为YiriMirai的消息链"""
|
||||
assert type(message_chain) is list
|
||||
|
||||
yiri_msg_list = []
|
||||
import datetime
|
||||
# 添加Source组件以标记message_id等信息
|
||||
yiri_msg_list.append(mirai.models.message.Source(id=message_id, time=datetime.datetime.now()))
|
||||
for component in message_chain:
|
||||
if type(component) is nkc.Plain:
|
||||
yiri_msg_list.append(mirai.Plain(text=component.text))
|
||||
elif type(component) is nkc.Image:
|
||||
yiri_msg_list.append(mirai.Image(url=component.url))
|
||||
elif type(component) is nkc.Face:
|
||||
yiri_msg_list.append(mirai.Face(face_id=component.id))
|
||||
elif type(component) is nkc.At:
|
||||
yiri_msg_list.append(mirai.At(target=component.qq))
|
||||
elif type(component) is nkc.AtAll:
|
||||
yiri_msg_list.append(mirai.AtAll())
|
||||
else:
|
||||
pass
|
||||
logging.debug("转换后的消息链: " + str(yiri_msg_list))
|
||||
chain = mirai.MessageChain(yiri_msg_list)
|
||||
return chain
|
||||
|
||||
|
||||
class NakuruProjectEventConverter(adapter_model.EventConverter):
|
||||
"""事件转换器"""
|
||||
@staticmethod
|
||||
def yiri2target(event: typing.Type[mirai.Event]):
|
||||
if event is mirai.GroupMessage:
|
||||
return nakuru.GroupMessage
|
||||
elif event is mirai.FriendMessage:
|
||||
return nakuru.FriendMessage
|
||||
else:
|
||||
raise Exception("未支持转换的事件类型: " + str(event))
|
||||
|
||||
@staticmethod
|
||||
def target2yiri(event: typing.Any) -> mirai.Event:
|
||||
yiri_chain = NakuruProjectMessageConverter.target2yiri(event.message, event.message_id)
|
||||
if type(event) is nakuru.FriendMessage: # 私聊消息事件
|
||||
return mirai.FriendMessage(
|
||||
sender=mirai.models.entities.Friend(
|
||||
id=event.sender.user_id,
|
||||
nickname=event.sender.nickname,
|
||||
remark=event.sender.nickname
|
||||
),
|
||||
message_chain=yiri_chain,
|
||||
time=event.time
|
||||
)
|
||||
elif type(event) is nakuru.GroupMessage: # 群聊消息事件
|
||||
permission = "MEMBER"
|
||||
|
||||
if event.sender.role == "admin":
|
||||
permission = "ADMINISTRATOR"
|
||||
elif event.sender.role == "owner":
|
||||
permission = "OWNER"
|
||||
|
||||
import mirai.models.entities as entities
|
||||
return mirai.GroupMessage(
|
||||
sender=mirai.models.entities.GroupMember(
|
||||
id=event.sender.user_id,
|
||||
member_name=event.sender.nickname,
|
||||
permission=permission,
|
||||
group=mirai.models.entities.Group(
|
||||
id=event.group_id,
|
||||
name=event.sender.nickname,
|
||||
permission=entities.Permission.Member
|
||||
),
|
||||
special_title=event.sender.title,
|
||||
join_timestamp=0,
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0,
|
||||
),
|
||||
message_chain=yiri_chain,
|
||||
time=event.time
|
||||
)
|
||||
else:
|
||||
raise Exception("未支持转换的事件类型: " + str(event))
|
||||
|
||||
|
||||
class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
|
||||
"""nakuru-project适配器"""
|
||||
bot: nakuru.CQHTTP
|
||||
bot_account_id: int
|
||||
|
||||
message_converter: NakuruProjectMessageConverter = NakuruProjectMessageConverter()
|
||||
event_converter: NakuruProjectEventConverter = NakuruProjectEventConverter()
|
||||
|
||||
listener_list: list[dict]
|
||||
|
||||
def __init__(self, cfg: dict):
|
||||
"""初始化nakuru-project的对象"""
|
||||
self.bot = nakuru.CQHTTP(**cfg)
|
||||
self.listener_list = []
|
||||
# nakuru库有bug,这个接口没法带access_token,会失败
|
||||
# 所以目前自行发请求
|
||||
|
||||
config = context.get_config_manager().data
|
||||
|
||||
import requests
|
||||
resp = requests.get(
|
||||
url="http://{}:{}/get_login_info".format(config['nakuru_config']['host'], config['nakuru_config']['http_port']),
|
||||
headers={
|
||||
'Authorization': "Bearer " + config['nakuru_config']['token'] if 'token' in config['nakuru_config']else ""
|
||||
},
|
||||
timeout=5,
|
||||
proxies=None
|
||||
)
|
||||
if resp.status_code == 403:
|
||||
logging.error("go-cqhttp拒绝访问,请检查config.py中nakuru_config的token是否与go-cqhttp设置的access-token匹配")
|
||||
raise Exception("go-cqhttp拒绝访问,请检查config.py中nakuru_config的token是否与go-cqhttp设置的access-token匹配")
|
||||
try:
|
||||
self.bot_account_id = int(resp.json()['data']['user_id'])
|
||||
except Exception as e:
|
||||
logging.error("获取go-cqhttp账号信息失败: {}, 请检查是否已启动go-cqhttp并配置正确".format(e))
|
||||
raise Exception("获取go-cqhttp账号信息失败: {}, 请检查是否已启动go-cqhttp并配置正确".format(e))
|
||||
|
||||
def send_message(
|
||||
self,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
message: typing.Union[mirai.MessageChain, list],
|
||||
converted: bool = False
|
||||
):
|
||||
task = None
|
||||
|
||||
converted_msg = self.message_converter.yiri2target(message) if not converted else message
|
||||
|
||||
# 检查是否有转发消息
|
||||
has_forward = False
|
||||
for msg in converted_msg:
|
||||
if type(msg) is list: # 转发消息,仅回复此消息组件
|
||||
has_forward = True
|
||||
converted_msg = msg
|
||||
break
|
||||
if has_forward:
|
||||
if target_type == "group":
|
||||
task = self.bot.sendGroupForwardMessage(int(target_id), converted_msg)
|
||||
elif target_type == "person":
|
||||
task = self.bot.sendPrivateForwardMessage(int(target_id), converted_msg)
|
||||
else:
|
||||
raise Exception("Unknown target type: " + target_type)
|
||||
else:
|
||||
if target_type == "group":
|
||||
task = self.bot.sendGroupMessage(int(target_id), converted_msg)
|
||||
elif target_type == "person":
|
||||
task = self.bot.sendFriendMessage(int(target_id), converted_msg)
|
||||
else:
|
||||
raise Exception("Unknown target type: " + target_type)
|
||||
|
||||
asyncio.run(task)
|
||||
|
||||
def reply_message(
|
||||
self,
|
||||
message_source: mirai.MessageEvent,
|
||||
message: mirai.MessageChain,
|
||||
quote_origin: bool = False
|
||||
):
|
||||
message = self.message_converter.yiri2target(message)
|
||||
if quote_origin:
|
||||
# 在前方添加引用组件
|
||||
message.insert(0, nkc.Reply(
|
||||
id=message_source.message_chain.message_id,
|
||||
)
|
||||
)
|
||||
if type(message_source) is mirai.GroupMessage:
|
||||
self.send_message(
|
||||
"group",
|
||||
message_source.sender.group.id,
|
||||
message,
|
||||
converted=True
|
||||
)
|
||||
elif type(message_source) is mirai.FriendMessage:
|
||||
self.send_message(
|
||||
"person",
|
||||
message_source.sender.id,
|
||||
message,
|
||||
converted=True
|
||||
)
|
||||
else:
|
||||
raise Exception("Unknown message source type: " + str(type(message_source)))
|
||||
|
||||
def is_muted(self, group_id: int) -> bool:
|
||||
import time
|
||||
# 检查是否被禁言
|
||||
group_member_info = asyncio.run(self.bot.getGroupMemberInfo(group_id, self.bot_account_id))
|
||||
return group_member_info.shut_up_timestamp > int(time.time())
|
||||
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
callback: typing.Callable[[mirai.Event], None]
|
||||
):
|
||||
try:
|
||||
logging.debug("注册监听器: " + str(event_type) + " -> " + str(callback))
|
||||
|
||||
# 包装函数
|
||||
async def listener_wrapper(app: nakuru.CQHTTP, source: NakuruProjectAdapter.event_converter.yiri2target(event_type)):
|
||||
callback(self.event_converter.target2yiri(source))
|
||||
|
||||
# 将包装函数和原函数的对应关系存入列表
|
||||
self.listener_list.append(
|
||||
{
|
||||
"event_type": event_type,
|
||||
"callable": callback,
|
||||
"wrapper": listener_wrapper,
|
||||
}
|
||||
)
|
||||
|
||||
# 注册监听器
|
||||
self.bot.receiver(self.event_converter.yiri2target(event_type).__name__)(listener_wrapper)
|
||||
logging.debug("注册完成")
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
callback: typing.Callable[[mirai.Event], None]
|
||||
):
|
||||
nakuru_event_name = self.event_converter.yiri2target(event_type).__name__
|
||||
|
||||
new_event_list = []
|
||||
|
||||
# 从本对象的监听器列表中查找并删除
|
||||
target_wrapper = None
|
||||
for listener in self.listener_list:
|
||||
if listener["event_type"] == event_type and listener["callable"] == callback:
|
||||
target_wrapper = listener["wrapper"]
|
||||
self.listener_list.remove(listener)
|
||||
break
|
||||
|
||||
if target_wrapper is None:
|
||||
raise Exception("未找到对应的监听器")
|
||||
|
||||
for func in self.bot.event[nakuru_event_name]:
|
||||
if func.callable != target_wrapper:
|
||||
new_event_list.append(func)
|
||||
|
||||
self.bot.event[nakuru_event_name] = new_event_list
|
||||
|
||||
def run_sync(self):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
self.bot.run()
|
||||
|
||||
async def run_async(self):
|
||||
return await self.bot._run()
|
||||
|
||||
def kill(self) -> bool:
|
||||
return False
|
||||
117
pkg/platform/sources/yirimirai.py
Normal file
117
pkg/platform/sources/yirimirai.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import asyncio
|
||||
import typing
|
||||
|
||||
import mirai
|
||||
import mirai.models.bus
|
||||
from mirai.bot import MiraiRunner
|
||||
|
||||
from .. import adapter as adapter_model
|
||||
|
||||
|
||||
class YiriMiraiAdapter(adapter_model.MessageSourceAdapter):
|
||||
"""YiriMirai适配器"""
|
||||
bot: mirai.Mirai
|
||||
|
||||
def __init__(self, config: dict):
|
||||
"""初始化YiriMirai的对象"""
|
||||
if 'adapter' not in config or \
|
||||
config['adapter'] == 'WebSocketAdapter':
|
||||
self.bot = mirai.Mirai(
|
||||
qq=config['qq'],
|
||||
adapter=mirai.WebSocketAdapter(
|
||||
host=config['host'],
|
||||
port=config['port'],
|
||||
verify_key=config['verifyKey']
|
||||
)
|
||||
)
|
||||
elif config['adapter'] == 'HTTPAdapter':
|
||||
self.bot = mirai.Mirai(
|
||||
qq=config['qq'],
|
||||
adapter=mirai.HTTPAdapter(
|
||||
host=config['host'],
|
||||
port=config['port'],
|
||||
verify_key=config['verifyKey']
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise Exception('Unknown adapter for YiriMirai: ' + config['adapter'])
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
message: mirai.MessageChain
|
||||
):
|
||||
"""发送消息
|
||||
|
||||
Args:
|
||||
target_type (str): 目标类型,`person`或`group`
|
||||
target_id (str): 目标ID
|
||||
message (mirai.MessageChain): YiriMirai库的消息链
|
||||
"""
|
||||
task = None
|
||||
if target_type == 'person':
|
||||
task = self.bot.send_friend_message(int(target_id), message)
|
||||
elif target_type == 'group':
|
||||
task = self.bot.send_group_message(int(target_id), message)
|
||||
else:
|
||||
raise Exception('Unknown target type: ' + target_type)
|
||||
|
||||
await task
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
message_source: mirai.MessageEvent,
|
||||
message: mirai.MessageChain,
|
||||
quote_origin: bool = False
|
||||
):
|
||||
"""回复消息
|
||||
|
||||
Args:
|
||||
message_source (mirai.MessageEvent): YiriMirai消息源事件
|
||||
message (mirai.MessageChain): YiriMirai库的消息链
|
||||
quote_origin (bool, optional): 是否引用原消息. Defaults to False.
|
||||
"""
|
||||
await self.bot.send(message_source, message, quote_origin)
|
||||
|
||||
async def is_muted(self, group_id: int) -> bool:
|
||||
result = await self.bot.member_info(target=group_id, member_id=self.bot.qq).get()
|
||||
if result.mute_time_remaining > 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
callback: typing.Callable[[mirai.Event], None]
|
||||
):
|
||||
"""注册事件监听器
|
||||
|
||||
Args:
|
||||
event_type (typing.Type[mirai.Event]): YiriMirai事件类型
|
||||
callback (typing.Callable[[mirai.Event], None]): 回调函数,接收一个参数,为YiriMirai事件
|
||||
"""
|
||||
self.bot.on(event_type)(callback)
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
callback: typing.Callable[[mirai.Event], None]
|
||||
):
|
||||
"""注销事件监听器
|
||||
|
||||
Args:
|
||||
event_type (typing.Type[mirai.Event]): YiriMirai事件类型
|
||||
callback (typing.Callable[[mirai.Event], None]): 回调函数,接收一个参数,为YiriMirai事件
|
||||
"""
|
||||
assert isinstance(self.bot, mirai.Mirai)
|
||||
bus = self.bot.bus
|
||||
assert isinstance(bus, mirai.models.bus.ModelEventBus)
|
||||
|
||||
bus.unsubscribe(event_type, callback)
|
||||
|
||||
async def run_async(self):
|
||||
return await MiraiRunner(self.bot)._run()
|
||||
|
||||
def kill(self) -> bool:
|
||||
return False
|
||||
Reference in New Issue
Block a user