Merge pull request #670 from RockChinQ/refactor/asyncio/simplify-qqbot-mgr

Refactor: 简化和调整qqbot包架构
This commit is contained in:
Junyan Qin
2024-01-25 22:39:25 +08:00
committed by GitHub
48 changed files with 1331 additions and 856 deletions

View File

@@ -167,6 +167,8 @@ response_rules = {
# 此设置优先级高于response_rules
# 用以过滤mirai等其他层级的命令
# @see https://github.com/RockChinQ/QChatGPT/issues/165
#
# *需要同时开启下方 income_msg_check 才会生效
ignore_rules = {
"prefix": ["/"],
"regexp": []

View File

@@ -16,7 +16,7 @@ sys.path.append(".")
def check_file():
# 检查是否有banlist.py,如果没有就把banlist-template.py复制一份
if not os.path.exists('banlist.py'):
shutil.copy('res/templates/banlist-template.py', 'banlist.py')
shutil.copy('banlist-template.py', 'banlist.py')
# 检查是否有sensitive.json
if not os.path.exists("sensitive.json"):

View File

@@ -28,6 +28,9 @@ class Application:
def __init__(self):
pass
async def initialize(self):
await self.im_mgr.initialize()
async def run(self):
# TODO make it async
plugin_host.initialize_plugins()

View File

@@ -50,7 +50,10 @@ async def make_app() -> app.Application:
# 生成标识符
identifier.init()
cfg_mgr = await config.load_config()
cfg_mgr = await config.load_python_module_config(
"config.py",
"config-template.py"
)
context.set_config_manager(cfg_mgr)
cfg = cfg_mgr.data
@@ -63,12 +66,10 @@ async def make_app() -> app.Application:
if overrided:
qcg_logger.info("以下配置项已使用 override.json 覆盖:" + ",".join(overrided))
tips_mgr = await config.load_tips()
# 初始化文字转图片
from pkg.utils import text2img
# TODO make it async
text2img.initialize()
tips_mgr = await config.load_python_module_config(
"tips.py",
"tips-custom-template.py"
)
# 检查管理员QQ号
if cfg_mgr.data['admin_qq'] == 0:
@@ -121,4 +122,5 @@ async def make_app() -> app.Application:
async def main():
app_inst = await make_app()
await app_inst.initialize()
await app_inst.run()

View File

@@ -4,30 +4,8 @@ from ..config import manager as config_mgr
from ..config.impls import pymodule
async def load_config() -> config_mgr.ConfigManager:
"""加载配置文件"""
cfg_inst = pymodule.PythonModuleConfigFile(
"config.py",
"config-template.py"
)
cfg_mgr = config_mgr.ConfigManager(cfg_inst)
await cfg_mgr.load_config()
return cfg_mgr
async def load_tips() -> config_mgr.ConfigManager:
"""加载提示文件"""
tips_inst = pymodule.PythonModuleConfigFile(
"tips.py",
"tips-custom-template.py"
)
tips_mgr = config_mgr.ConfigManager(tips_inst)
await tips_mgr.load_config()
return tips_mgr
load_python_module_config = config_mgr.load_python_module_config
load_json_config = config_mgr.load_json_config
async def override_config_manager(cfg_mgr: config_mgr.ConfigManager) -> list[str]:
@@ -39,5 +17,5 @@ async def override_config_manager(cfg_mgr: config_mgr.ConfigManager) -> list[str
if key in config:
config[key] = override_json[key]
overrided.append(key)
return overrided

View File

@@ -7,7 +7,7 @@ import sys
required_files = {
"config.py": "config-template.py",
"banlist.py": "res/templates/banlist-template.py",
"banlist.py": "banlist-template.py",
"tips.py": "tips-custom-template.py",
"sensitive.json": "res/templates/sensitive-template.json",
"scenario/default.json": "scenario/default-template.json",

44
pkg/config/impls/json.py Normal file
View File

@@ -0,0 +1,44 @@
import os
import shutil
import json
from .. import model as file_model
class JSONConfigFile(file_model.ConfigFile):
"""JSON配置文件"""
config_file_name: str = None
"""配置文件名"""
template_file_name: str = None
"""模板文件名"""
def __init__(self, config_file_name: str, template_file_name: str) -> None:
self.config_file_name = config_file_name
self.template_file_name = template_file_name
def exists(self) -> bool:
return os.path.exists(self.config_file_name)
async def create(self):
shutil.copyfile(self.template_file_name, self.config_file_name)
async def load(self) -> dict:
with open(self.config_file_name, 'r', encoding='utf-8') as f:
cfg = json.load(f)
# 从模板文件中进行补全
with open(self.template_file_name, 'r', encoding='utf-8') as f:
template_cfg = json.load(f)
for key in template_cfg:
if key not in cfg:
cfg[key] = template_cfg[key]
return cfg
async def save(self, cfg: dict):
with open(self.config_file_name, 'w', encoding='utf-8') as f:
json.dump(cfg, f, indent=4, ensure_ascii=False)

View File

@@ -1,5 +1,6 @@
from . import model as file_model
from ..utils import context
from .impls import pymodule, json as json_file
class ConfigManager:
@@ -20,3 +21,29 @@ class ConfigManager:
async def dump_config(self):
await self.file.save(self.data)
async def load_python_module_config(config_name: str, template_name: str) -> ConfigManager:
"""加载Python模块配置文件"""
cfg_inst = pymodule.PythonModuleConfigFile(
config_name,
template_name
)
cfg_mgr = ConfigManager(cfg_inst)
await cfg_mgr.load_config()
return cfg_mgr
async def load_json_config(config_name: str, template_name: str) -> ConfigManager:
"""加载JSON配置文件"""
cfg_inst = json_file.JSONConfigFile(
config_name,
template_name
)
cfg_mgr = ConfigManager(cfg_inst)
await cfg_mgr.load_config()
return cfg_mgr

View File

@@ -1,50 +0,0 @@
from ..utils import context
def is_banned(launcher_type: str, launcher_id: int, sender_id: int) -> bool:
if not context.get_qqbot_manager().enable_banlist:
return False
result = False
if launcher_type == 'group':
# 检查是否显式声明发起人QQ要被person忽略
if sender_id in context.get_qqbot_manager().ban_person:
result = True
else:
for group_rule in context.get_qqbot_manager().ban_group:
if type(group_rule) == int:
if group_rule == launcher_id: # 此群群号被禁用
result = True
elif type(group_rule) == str:
if group_rule.startswith('!'):
# 截取!后面的字符串作为表达式,判断是否匹配
reg_str = group_rule[1:]
import re
if re.match(reg_str, str(launcher_id)): # 被豁免,最高级别
result = False
break
else:
# 判断是否匹配regexp
import re
if re.match(group_rule, str(launcher_id)): # 此群群号被禁用
result = True
else:
# ban_person, 与群规则相同
for person_rule in context.get_qqbot_manager().ban_person:
if type(person_rule) == int:
if person_rule == launcher_id:
result = True
elif type(person_rule) == str:
if person_rule.startswith('!'):
reg_str = person_rule[1:]
import re
if re.match(reg_str, str(launcher_id)):
result = False
break
else:
import re
if re.match(person_rule, str(launcher_id)):
result = True
return result

View File

View File

@@ -0,0 +1,70 @@
# 处理对会话的禁用配置
# 过去的 banlist
from __future__ import annotations
import re
from ...boot import app
from ...config import manager as cfg_mgr
class SessionBanManager:
ap: app.Application = None
banlist_mgr: cfg_mgr.ConfigManager
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
self.banlist_mgr = await cfg_mgr.load_python_module_config(
"banlist.py",
"res/templates/banlist-template.py"
)
async def is_banned(
self, launcher_type: str, launcher_id: int, sender_id: int
) -> bool:
if not self.banlist_mgr.data['enable']:
return False
result = False
if launcher_type == 'group':
if not self.banlist_mgr.data['enable_group']: # 未启用群聊响应
result = True
# 检查是否显式声明发起人QQ要被person忽略
elif sender_id in self.banlist_mgr.data['person']:
result = True
else:
for group_rule in self.banlist_mgr.data['group']:
if type(group_rule) == int:
if group_rule == launcher_id:
result = True
elif type(group_rule) == str:
if group_rule.startswith('!'):
reg_str = group_rule[1:]
if re.match(reg_str, str(launcher_id)):
result = False
break
else:
if re.match(group_rule, str(launcher_id)):
result = True
elif launcher_type == 'person':
if not self.banlist_mgr.data['enable_private']:
result = True
else:
for person_rule in self.banlist_mgr.data['person']:
if type(person_rule) == int:
if person_rule == launcher_id:
result = True
elif type(person_rule) == str:
if person_rule.startswith('!'):
reg_str = person_rule[1:]
if re.match(reg_str, str(launcher_id)):
result = False
break
else:
if re.match(person_rule, str(launcher_id)):
result = True
return result

View File

@@ -1,100 +0,0 @@
# 长消息处理相关
import os
import time
import base64
import typing
from mirai.models.message import MessageComponent, MessageChain, Image
from mirai.models.message import ForwardMessageNode
from mirai.models.base import MiraiBaseModel
from ..utils import text2img
from ..utils import context
class ForwardMessageDiaplay(MiraiBaseModel):
title: str = "群聊的聊天记录"
brief: str = "[聊天记录]"
source: str = "聊天记录"
preview: typing.List[str] = []
summary: str = "查看x条转发消息"
class Forward(MessageComponent):
"""合并转发。"""
type: str = "Forward"
"""消息组件类型。"""
display: ForwardMessageDiaplay
"""显示信息"""
node_list: typing.List[ForwardMessageNode]
"""转发消息节点列表。"""
def __init__(self, *args, **kwargs):
if len(args) == 1:
self.node_list = args[0]
super().__init__(**kwargs)
super().__init__(*args, **kwargs)
def __str__(self):
return '[聊天记录]'
def text_to_image(text: str) -> MessageComponent:
"""将文本转换成图片"""
# 检查temp文件夹是否存在
if not os.path.exists('temp'):
os.mkdir('temp')
img_path = text2img.text_to_image(text_str=text, save_as='temp/{}.png'.format(int(time.time())))
compressed_path, size = text2img.compress_image(img_path, outfile="temp/{}_compressed.png".format(int(time.time())))
# 读取图片转换成base64
with open(compressed_path, 'rb') as f:
img = f.read()
b64 = base64.b64encode(img)
# 删除图片
os.remove(img_path)
# 判断compressed_path是否存在
if os.path.exists(compressed_path):
os.remove(compressed_path)
# 返回图片
return Image(base64=b64.decode('utf-8'))
def check_text(text: str) -> list:
"""检查文本是否为长消息,并转换成该使用的消息链组件"""
config = context.get_config_manager().data
if len(text) > config['blob_message_threshold']:
# logging.info("长消息: {}".format(text))
if config['blob_message_strategy'] == 'image':
# 转换成图片
return [text_to_image(text)]
elif config['blob_message_strategy'] == 'forward':
# 包装转发消息
display = ForwardMessageDiaplay(
title='群聊的聊天记录',
brief='[聊天记录]',
source='聊天记录',
preview=["bot: "+text],
summary="查看1条转发消息"
)
node = ForwardMessageNode(
sender_id=config['mirai_http_api_config']['qq'],
sender_name='bot',
message_chain=MessageChain([text])
)
forward = Forward(
display=display,
node_list=[node]
)
return [forward]
else:
return [text]

View File

View File

@@ -0,0 +1,93 @@
from __future__ import annotations
from ...boot import app
from . import entities
from . import filter
from .filters import cntignore, banwords, baiduexamine
class ContentFilterManager:
ao: app.Application
filter_chain: list[filter.ContentFilter]
def __init__(self, ap: app.Application) -> None:
self.ap = ap
self.filter_chain = []
async def initialize(self):
self.filter_chain.append(cntignore.ContentIgnore(self.ap))
if self.ap.cfg_mgr.data['sensitive_word_filter']:
self.filter_chain.append(banwords.BanWordFilter(self.ap))
if self.ap.cfg_mgr.data['baidu_check']:
self.filter_chain.append(baiduexamine.BaiduCloudExamine(self.ap))
for filter in self.filter_chain:
await filter.initialize()
async def pre_process(self, message: str) -> entities.FilterManagerResult:
"""请求llm前处理消息
只要有一个不通过就不放行,只放行 PASS 的消息
"""
if not self.ap.cfg_mgr.data['income_msg_check']: # 不检查收到的消息,直接放行
return entities.FilterManagerResult(
level=entities.ManagerResultLevel.CONTINUE,
replacement=message,
user_notice='',
console_notice=''
)
else:
for filter in self.filter_chain:
if entities.EnableStage.PRE in filter.enable_stages:
result = await filter.process(message)
if result.level in [
entities.ResultLevel.BLOCK,
entities.ResultLevel.MASKED
]:
return entities.FilterManagerResult(
level=entities.ManagerResultLevel.INTERRUPT,
replacement=result.replacement,
user_notice=result.user_notice,
console_notice=result.console_notice
)
elif result.level == entities.ResultLevel.PASS:
message = result.replacement
return entities.FilterManagerResult(
level=entities.ManagerResultLevel.CONTINUE,
replacement=message,
user_notice='',
console_notice=''
)
async def post_process(self, message: str) -> entities.FilterManagerResult:
"""请求llm后处理响应
只要是 PASS 或者 MASKED 的就通过此 filter将其 replacement 设置为message进入下一个 filter
"""
for filter in self.filter_chain:
if entities.EnableStage.POST in filter.enable_stages:
result = await filter.process(message)
if result.level == entities.ResultLevel.BLOCK:
return entities.FilterManagerResult(
level=entities.ManagerResultLevel.INTERRUPT,
replacement=result.replacement,
user_notice=result.user_notice,
console_notice=result.console_notice
)
elif result.level in [
entities.ResultLevel.PASS,
entities.ResultLevel.MASKED
]:
message = result.replacement
return entities.FilterManagerResult(
level=entities.ManagerResultLevel.CONTINUE,
replacement=message,
user_notice='',
console_notice=''
)

View File

@@ -0,0 +1,64 @@
import typing
import enum
import pydantic
class ResultLevel(enum.Enum):
"""结果等级"""
PASS = enum.auto()
"""通过"""
WARN = enum.auto()
"""警告"""
MASKED = enum.auto()
"""已掩去"""
BLOCK = enum.auto()
"""阻止"""
class EnableStage(enum.Enum):
"""启用阶段"""
PRE = enum.auto()
"""预处理"""
POST = enum.auto()
"""后处理"""
class FilterResult(pydantic.BaseModel):
level: ResultLevel
replacement: str
"""替换后的消息"""
user_notice: str
"""不通过时,用户提示消息"""
console_notice: str
"""不通过时,控制台提示消息"""
class ManagerResultLevel(enum.Enum):
"""处理器结果等级"""
CONTINUE = enum.auto()
"""继续"""
INTERRUPT = enum.auto()
"""中断"""
class FilterManagerResult(pydantic.BaseModel):
level: ManagerResultLevel
replacement: str
"""替换后的消息"""
user_notice: str
"""用户提示消息"""
console_notice: str
"""控制台提示消息"""

View File

@@ -0,0 +1,34 @@
# 内容过滤器的抽象类
from __future__ import annotations
import abc
from ...boot import app
from . import entities
class ContentFilter(metaclass=abc.ABCMeta):
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
@property
def enable_stages(self):
"""启用的阶段
"""
return [
entities.EnableStage.PRE,
entities.EnableStage.POST
]
async def initialize(self):
"""初始化过滤器
"""
pass
@abc.abstractmethod
async def process(self, message: str) -> entities.FilterResult:
"""处理消息
"""
raise NotImplementedError

View File

View File

@@ -0,0 +1,61 @@
from __future__ import annotations
import aiohttp
from .. import entities
from .. import filter as filter_model
BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}"
BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
class BaiduCloudExamine(filter_model.ContentFilter):
"""百度云内容审核"""
async def _get_token(self) -> str:
async with aiohttp.ClientSession() as session:
async with session.post(
BAIDU_EXAMINE_TOKEN_URL,
params={
"grant_type": "client_credentials",
"client_id": self.ap.cfg_mgr.data['baidu_api_key'],
"client_secret": self.ap.cfg_mgr.data['baidu_secret_key']
}
) as resp:
return (await resp.json())['access_token']
async def process(self, message: str) -> entities.FilterResult:
async with aiohttp.ClientSession() as session:
async with session.post(
BAIDU_EXAMINE_URL.format(await self._get_token()),
headers={'Content-Type': 'application/x-www-form-urlencoded', 'Accept': 'application/json'},
data=f"text={message}".encode('utf-8')
) as resp:
result = await resp.json()
if "error_code" in result:
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,
replacement=message,
user_notice='',
console_notice=f"百度云判定出错,错误信息:{result['error_msg']}"
)
else:
conclusion = result["conclusion"]
if conclusion in ("合规"):
return entities.FilterResult(
level=entities.ResultLevel.PASS,
replacement=message,
user_notice='',
console_notice=f"百度云判定结果:{conclusion}"
)
else:
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,
replacement=message,
user_notice=self.ap.cfg_mgr.data['inappropriate_message_tips'],
console_notice=f"百度云判定结果:{conclusion}"
)

View File

@@ -0,0 +1,44 @@
from __future__ import annotations
import re
from .. import filter as filter_model
from .. import entities
from ....config import manager as cfg_mgr
class BanWordFilter(filter_model.ContentFilter):
"""根据内容禁言"""
sensitive: cfg_mgr.ConfigManager
async def initialize(self):
self.sensitive = await cfg_mgr.load_json_config(
"sensitive.json",
"res/templates/sensitive-template.json"
)
async def process(self, message: str) -> entities.FilterResult:
found = False
for word in self.sensitive.data['words']:
match = re.findall(word, message)
if len(match) > 0:
found = True
for i in range(len(match)):
if self.sensitive.data['mask_word'] == "":
message = message.replace(
match[i], self.sensitive.data['mask'] * len(match[i])
)
else:
message = message.replace(
match[i], self.sensitive.data['mask_word']
)
return entities.FilterResult(
level=entities.ResultLevel.MASKED if found else entities.ResultLevel.PASS,
replacement=message,
user_notice='[bot] 消息中存在不合适的内容, 请更换措辞' if found else '',
console_notice=''
)

View File

@@ -0,0 +1,43 @@
from __future__ import annotations
import re
from .. import entities
from .. import filter as filter_model
class ContentIgnore(filter_model.ContentFilter):
"""根据内容忽略消息"""
@property
def enable_stages(self):
return [
entities.EnableStage.PRE,
]
async def process(self, message: str) -> entities.FilterResult:
if 'prefix' in self.ap.cfg_mgr.data['ignore_rules']:
for rule in self.ap.cfg_mgr.data['ignore_rules']['prefix']:
if message.startswith(rule):
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,
replacement='',
user_notice='',
console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息'
)
if 'regexp' in self.ap.cfg_mgr.data['ignore_rules']:
for rule in self.ap.cfg_mgr.data['ignore_rules']['regexp']:
if re.search(rule, message):
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,
replacement='',
user_notice='',
console_notice='根据 ignore_rules 中的 regexp 规则,忽略消息'
)
return entities.FilterResult(
level=entities.ResultLevel.PASS,
replacement=message,
user_notice='',
console_notice=''
)

View File

@@ -1,87 +0,0 @@
# 敏感词过滤模块
import re
import requests
import json
import logging
from ..utils import context
class ReplyFilter:
sensitive_words = []
mask = "*"
mask_word = ""
# 默认值( 兼容性考虑 )
baidu_check = False
baidu_api_key = ""
baidu_secret_key = ""
inappropriate_message_tips = "[百度云]请珍惜机器人,当前返回内容不合规"
def __init__(self, sensitive_words: list, mask: str = "*", mask_word: str = ""):
self.sensitive_words = sensitive_words
self.mask = mask
self.mask_word = mask_word
config = context.get_config_manager().data
self.baidu_check = config['baidu_check']
self.baidu_api_key = config['baidu_api_key']
self.baidu_secret_key = config['baidu_secret_key']
self.inappropriate_message_tips = config['inappropriate_message_tips']
def is_illegal(self, message: str) -> bool:
processed = self.process(message)
if processed != message:
return True
return False
def process(self, message: str) -> str:
# 本地关键词屏蔽
for word in self.sensitive_words:
match = re.findall(word, message)
if len(match) > 0:
for i in range(len(match)):
if self.mask_word == "":
message = message.replace(match[i], self.mask * len(match[i]))
else:
message = message.replace(match[i], self.mask_word)
# 百度云审核
if self.baidu_check:
# 百度云审核URL
baidu_url = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token=" + \
str(requests.post("https://aip.baidubce.com/oauth/2.0/token",
params={"grant_type": "client_credentials",
"client_id": self.baidu_api_key,
"client_secret": self.baidu_secret_key}).json().get("access_token"))
# 百度云审核
payload = "text=" + message
logging.info("向百度云发送:" + payload)
headers = {'Content-Type': 'application/x-www-form-urlencoded', 'Accept': 'application/json'}
if isinstance(payload, str):
payload = payload.encode('utf-8')
response = requests.request("POST", baidu_url, headers=headers, data=payload)
response_dict = json.loads(response.text)
if "error_code" in response_dict:
error_msg = response_dict.get("error_msg")
logging.warning(f"百度云判定出错,错误信息:{error_msg}")
conclusion = f"百度云判定出错,错误信息:{error_msg}\n以下是原消息:{message}"
else:
conclusion = response_dict["conclusion"]
if conclusion in ("合规"):
logging.info(f"百度云判定结果:{conclusion}")
return message
else:
logging.warning(f"百度云判定结果:{conclusion}")
conclusion = self.inappropriate_message_tips
# 返回百度云审核结果
return conclusion
return message

View File

@@ -1,18 +0,0 @@
import re
from ..utils import context
def ignore(msg: str) -> bool:
"""检查消息是否应该被忽略"""
config = context.get_config_manager().data
if 'prefix' in config['ignore_rules']:
for rule in config['ignore_rules']['prefix']:
if msg.startswith(rule):
return True
if 'regexp' in config['ignore_rules']:
for rule in config['ignore_rules']['regexp']:
if re.search(rule, msg):
return True

View File

View File

@@ -0,0 +1,56 @@
from __future__ import annotations
import os
import traceback
from PIL import Image, ImageDraw, ImageFont
from mirai.models.message import MessageComponent, Plain
from ...boot import app
from . import strategy
from .strategies import image, forward
class LongTextProcessor:
ap: app.Application
strategy_impl: strategy.LongTextStrategy
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
config = self.ap.cfg_mgr.data
if self.ap.cfg_mgr.data['blob_message_strategy'] == 'image':
use_font = config['font_path']
try:
# 检查是否存在
if not os.path.exists(use_font):
# 若是windows系统使用微软雅黑
if os.name == "nt":
use_font = "C:/Windows/Fonts/msyh.ttc"
if not os.path.exists(use_font):
self.ap.logger.warn("未找到字体文件且无法使用Windows自带字体更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。")
config['blob_message_strategy'] = "forward"
else:
self.ap.logger.info("使用Windows自带字体" + use_font)
self.ap.cfg_mgr.data['font_path'] = use_font
else:
self.ap.logger.warn("未找到字体文件且无法使用系统自带字体更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。")
self.ap.cfg_mgr.data['blob_message_strategy'] = "forward"
except:
traceback.print_exc()
self.ap.logger.error("加载字体文件失败({})更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。".format(use_font))
self.ap.cfg_mgr.data['blob_message_strategy'] = "forward"
if self.ap.cfg_mgr.data['blob_message_strategy'] == 'image':
self.strategy_impl = image.Text2ImageStrategy(self.ap)
elif self.ap.cfg_mgr.data['blob_message_strategy'] == 'forward':
self.strategy_impl = forward.ForwardComponentStrategy(self.ap)
await self.strategy_impl.initialize()
async def check_and_process(self, message: str) -> list[MessageComponent]:
if len(message) > self.ap.cfg_mgr.data['blob_message_threshold']:
return await self.strategy_impl.process(message)
else:
return [Plain(message)]

View File

@@ -0,0 +1,62 @@
# 转发消息组件
from __future__ import annotations
import typing
from mirai.models import MessageChain
from mirai.models.message import MessageComponent, ForwardMessageNode
from mirai.models.base import MiraiBaseModel
from .. import strategy as strategy_model
class ForwardMessageDiaplay(MiraiBaseModel):
title: str = "群聊的聊天记录"
brief: str = "[聊天记录]"
source: str = "聊天记录"
preview: typing.List[str] = []
summary: str = "查看x条转发消息"
class Forward(MessageComponent):
"""合并转发。"""
type: str = "Forward"
"""消息组件类型。"""
display: ForwardMessageDiaplay
"""显示信息"""
node_list: typing.List[ForwardMessageNode]
"""转发消息节点列表。"""
def __init__(self, *args, **kwargs):
if len(args) == 1:
self.node_list = args[0]
super().__init__(**kwargs)
super().__init__(*args, **kwargs)
def __str__(self):
return '[聊天记录]'
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
async def process(self, message: str) -> list[MessageComponent]:
display = ForwardMessageDiaplay(
title="群聊的聊天记录",
brief="[聊天记录]",
source="聊天记录",
preview=["QQ用户: "+message],
summary="查看1条转发消息"
)
node_list = [
ForwardMessageNode(
sender_id=self.ap.im_mgr.bot_account_id,
sender_name='QQ用户',
message_chain=MessageChain([message])
)
]
forward = Forward(
display=display,
node_list=node_list
)
return [forward]

View File

@@ -0,0 +1,197 @@
from __future__ import annotations
import typing
import os
import base64
import time
import re
from PIL import Image, ImageDraw, ImageFont
from mirai.models import MessageChain, Image as ImageComponent
from mirai.models.message import MessageComponent
from .. import strategy as strategy_model
class Text2ImageStrategy(strategy_model.LongTextStrategy):
text_render_font: ImageFont.FreeTypeFont
async def initialize(self):
self.text_render_font = ImageFont.truetype(self.ap.cfg_mgr.data['font_path'], 32, encoding="utf-8")
async def process(self, message: str) -> list[MessageComponent]:
img_path = self.text_to_image(
text_str=message,
save_as='temp/{}.png'.format(int(time.time()))
)
compressed_path, size = self.compress_image(
img_path,
outfile="temp/{}_compressed.png".format(int(time.time()))
)
with open(compressed_path, 'rb') as f:
img = f.read()
b64 = base64.b64encode(img)
# 删除图片
os.remove(img_path)
if os.path.exists(compressed_path):
os.remove(compressed_path)
return [
ImageComponent(
base64=b64.decode('utf-8'),
)
]
def indexNumber(self, path=''):
"""
查找字符串中数字所在串中的位置
:param path:目标字符串
:return:<class 'list'>: <class 'list'>: [['1', 16], ['2', 35], ['1', 51]]
"""
kv = []
nums = []
beforeDatas = re.findall('[\d]+', path)
for num in beforeDatas:
indexV = []
times = path.count(num)
if times > 1:
if num not in nums:
indexs = re.finditer(num, path)
for index in indexs:
iV = []
i = index.span()[0]
iV.append(num)
iV.append(i)
kv.append(iV)
nums.append(num)
else:
index = path.find(num)
indexV.append(num)
indexV.append(index)
kv.append(indexV)
# 根据数字位置排序
indexSort = []
resultIndex = []
for vi in kv:
indexSort.append(vi[1])
indexSort.sort()
for i in indexSort:
for v in kv:
if i == v[1]:
resultIndex.append(v)
return resultIndex
def get_size(self, file):
# 获取文件大小:KB
size = os.path.getsize(file)
return size / 1024
def get_outfile(self, infile, outfile):
if outfile:
return outfile
dir, suffix = os.path.splitext(infile)
outfile = '{}-out{}'.format(dir, suffix)
return outfile
def compress_image(self, infile, outfile='', kb=100, step=20, quality=90):
"""不改变图片尺寸压缩到指定大小
:param infile: 压缩源文件
:param outfile: 压缩文件保存地址
:param mb: 压缩目标,KB
:param step: 每次调整的压缩比率
:param quality: 初始压缩比率
:return: 压缩文件地址,压缩文件大小
"""
o_size = self.get_size(infile)
if o_size <= kb:
return infile, o_size
outfile = self.get_outfile(infile, outfile)
while o_size > kb:
im = Image.open(infile)
im.save(outfile, quality=quality)
if quality - step < 0:
break
quality -= step
o_size = self.get_size(outfile)
return outfile, self.get_size(outfile)
def text_to_image(self, text_str: str, save_as="temp.png", width=800):
text_str = text_str.replace("\t", " ")
# 分行
lines = text_str.split('\n')
# 计算并分割
final_lines = []
text_width = width-80
self.ap.logger.debug("lines: {}, text_width: {}".format(lines, text_width))
for line in lines:
# 如果长了就分割
line_width = self.text_render_font.getlength(line)
self.ap.logger.debug("line_width: {}".format(line_width))
if line_width < text_width:
final_lines.append(line)
continue
else:
rest_text = line
while True:
# 分割最前面的一行
point = int(len(rest_text) * (text_width / line_width))
# 检查断点是否在数字中间
numbers = self.indexNumber(rest_text)
for number in numbers:
if number[1] < point < number[1] + len(number[0]) and number[1] != 0:
point = number[1]
break
final_lines.append(rest_text[:point])
rest_text = rest_text[point:]
line_width = self.text_render_font.getlength(rest_text)
if line_width < text_width:
final_lines.append(rest_text)
break
else:
continue
# 准备画布
img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255))
draw = ImageDraw.Draw(img, mode='RGBA')
self.ap.logger.debug("正在绘制图片...")
# 绘制正文
line_number = 0
offset_x = 20
offset_y = 30
for final_line in final_lines:
draw.text((offset_x, offset_y + 35 * line_number), final_line, fill=(0, 0, 0), font=self.text_render_font)
# 遍历此行,检查是否有emoji
idx_in_line = 0
for ch in final_line:
# 检查字符占位宽
char_code = ord(ch)
if char_code >= 127:
idx_in_line += 1
else:
idx_in_line += 0.5
line_number += 1
self.ap.logger.debug("正在保存图片...")
img.save(save_as)
return save_as

View File

@@ -0,0 +1,22 @@
from __future__ import annotations
import abc
import typing
import mirai
from mirai.models.message import MessageComponent
from ...boot import app
class LongTextStrategy(metaclass=abc.ABCMeta):
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def process(self, message: str) -> list[MessageComponent]:
return []

View File

@@ -7,87 +7,26 @@ import asyncio
from mirai import At, GroupMessage, MessageEvent, StrangerMessage, \
FriendMessage, Image, MessageChain, Plain
import mirai
import func_timeout
from ..openai import session as openai_session
from ..qqbot import filter as qqbot_filter
from ..qqbot import process as processor
from ..utils import context
from ..plugin import host as plugin_host
from ..plugin import models as plugin_models
import tips as tips_custom
from ..qqbot import adapter as msadapter
from .resprule import resprule
from .bansess import bansess
from .cntfilter import cntfilter
from .longtext import longtext
from .ratelim import ratelim
from ..boot import app
# 检查消息是否符合泛响应匹配机制
def check_response_rule(group_id:int, text: str):
config = context.get_config_manager().data
rules = config['response_rules']
# 检查是否有特定规则
if 'prefix' not in config['response_rules']:
if str(group_id) in config['response_rules']:
rules = config['response_rules'][str(group_id)]
else:
rules = config['response_rules']['default']
# 检查前缀匹配
if 'prefix' in rules:
for rule in rules['prefix']:
if text.startswith(rule):
return True, text.replace(rule, "", 1)
# 检查正则表达式匹配
if 'regexp' in rules:
for rule in rules['regexp']:
import re
match = re.match(rule, text)
if match:
return True, text
return False, ""
def response_at(group_id: int):
config = context.get_config_manager().data
use_response_rule = config['response_rules']
# 检查是否有特定规则
if 'prefix' not in config['response_rules']:
if str(group_id) in config['response_rules']:
use_response_rule = config['response_rules'][str(group_id)]
else:
use_response_rule = config['response_rules']['default']
if 'at' not in use_response_rule:
return True
return use_response_rule['at']
def random_responding(group_id):
config = context.get_config_manager().data
use_response_rule = config['response_rules']
# 检查是否有特定规则
if 'prefix' not in config['response_rules']:
if str(group_id) in config['response_rules']:
use_response_rule = config['response_rules'][str(group_id)]
else:
use_response_rule = config['response_rules']['default']
if 'random_rate' in use_response_rule:
import random
return random.random() < use_response_rule['random_rate']
return False
# 控制QQ消息输入输出的类
class QQBotManager:
retry = 3
@@ -96,40 +35,51 @@ class QQBotManager:
bot_account_id: int = 0
reply_filter = None
enable_banlist = False
enable_private = True
enable_group = True
ban_person = []
ban_group = []
# modern
ap: app.Application = None
bansess_mgr: bansess.SessionBanManager = None
cntfilter_mgr: cntfilter.ContentFilterManager = None
longtext_pcs: longtext.LongTextProcessor = None
resprule_chkr: resprule.GroupRespondRuleChecker = None
ratelimiter: ratelim.RateLimiter = None
def __init__(self, first_time_init=True, ap: app.Application = None):
config = context.get_config_manager().data
self.ap = ap
self.bansess_mgr = bansess.SessionBanManager(ap)
self.cntfilter_mgr = cntfilter.ContentFilterManager(ap)
self.longtext_pcs = longtext.LongTextProcessor(ap)
self.resprule_chkr = resprule.GroupRespondRuleChecker(ap)
self.ratelimiter = ratelim.RateLimiter(ap)
self.timeout = config['process_message_timeout']
self.retry = config['retry_times']
async def initialize(self):
await self.bansess_mgr.initialize()
await self.cntfilter_mgr.initialize()
await self.longtext_pcs.initialize()
await self.resprule_chkr.initialize()
await self.ratelimiter.initialize()
# 由于YiriMirai的bot对象是单例的且shutdown方法暂时无法使用
# 故只在第一次初始化时创建bot对象重载之后使用原bot对象
# 因此bot的配置不支持热重载
if first_time_init:
logging.debug("Use adapter:" + config['msg_source_adapter'])
if config['msg_source_adapter'] == 'yirimirai':
from pkg.qqbot.sources.yirimirai import YiriMiraiAdapter
config = context.get_config_manager().data
mirai_http_api_config = config['mirai_http_api_config']
self.bot_account_id = config['mirai_http_api_config']['qq']
self.adapter = YiriMiraiAdapter(mirai_http_api_config)
elif config['msg_source_adapter'] == 'nakuru':
from pkg.qqbot.sources.nakuru import NakuruProjectAdapter
self.adapter = NakuruProjectAdapter(config['nakuru_config'])
self.bot_account_id = self.adapter.bot_account_id
else:
self.adapter = context.get_qqbot_manager().adapter
self.bot_account_id = context.get_qqbot_manager().bot_account_id
logging.debug("Use adapter:" + config['msg_source_adapter'])
if config['msg_source_adapter'] == 'yirimirai':
from pkg.qqbot.sources.yirimirai import YiriMiraiAdapter
mirai_http_api_config = config['mirai_http_api_config']
self.bot_account_id = config['mirai_http_api_config']['qq']
self.adapter = YiriMiraiAdapter(mirai_http_api_config)
elif config['msg_source_adapter'] == 'nakuru':
from pkg.qqbot.sources.nakuru import NakuruProjectAdapter
self.adapter = NakuruProjectAdapter(config['nakuru_config'])
self.bot_account_id = self.adapter.bot_account_id
# 保存 account_id 到审计模块
from ..utils.center import apigroup
@@ -205,6 +155,7 @@ class QQBotManager:
await self.on_group_message(event)
asyncio.create_task(group_message_handler(event))
self.adapter.register_listener(
GroupMessage,
on_group_message
@@ -231,33 +182,6 @@ class QQBotManager:
self.unsubscribe_all = unsubscribe_all
# 加载禁用列表
if os.path.exists("banlist.py"):
import banlist
self.enable_banlist = banlist.enable
self.ban_person = banlist.person
self.ban_group = banlist.group
logging.info("加载禁用列表: person: {}, group: {}".format(self.ban_person, self.ban_group))
if hasattr(banlist, "enable_private"):
self.enable_private = banlist.enable_private
if hasattr(banlist, "enable_group"):
self.enable_group = banlist.enable_group
config = context.get_config_manager().data
if os.path.exists("sensitive.json") \
and config['sensitive_word_filter'] is not None \
and config['sensitive_word_filter']:
with open("sensitive.json", "r", encoding="utf-8") as f:
sensitive_json = json.load(f)
self.reply_filter = qqbot_filter.ReplyFilter(
sensitive_words=sensitive_json['words'],
mask=sensitive_json['mask'] if 'mask' in sensitive_json else '*',
mask_word=sensitive_json['mask_word'] if 'mask_word' in sensitive_json else ''
)
else:
self.reply_filter = qqbot_filter.ReplyFilter([])
async def send(self, event, msg, check_quote=True, check_at_sender=True):
config = context.get_config_manager().data
@@ -282,46 +206,60 @@ class QQBotManager:
quote_origin=True if config['quote_origin'] and check_quote else False
)
async def common_process(
self,
launcher_type: str,
launcher_id: int,
text_message: str,
message_chain: MessageChain,
sender_id: int
) -> mirai.MessageChain:
"""
私聊群聊通用消息处理方法
"""
# 检查bansess
if await self.bansess_mgr.is_banned(launcher_type, launcher_id, sender_id):
self.ap.logger.info("根据禁用列表忽略{}_{}的消息".format(launcher_type, launcher_id))
return []
if mirai.Image in message_chain:
return []
elif sender_id == self.bot_account_id:
return []
else:
# 超时则重试,重试超过次数则放弃
failed = 0
for i in range(self.retry):
try:
reply = await processor.process_message(launcher_type, launcher_id, text_message, message_chain,
sender_id)
return reply
# TODO openai 超时处理
except func_timeout.FunctionTimedOut:
logging.warning("{}_{}: 超时,重试中({})".format(launcher_type, launcher_id, i))
openai_session.get_session("{}_{}".format(launcher_type, launcher_id)).release_response_lock()
if "{}_{}".format(launcher_type, launcher_id) in processor.processing:
processor.processing.remove("{}_{}".format(launcher_type, launcher_id))
failed += 1
continue
if failed == self.retry:
openai_session.get_session("{}_{}".format(launcher_type, launcher_id)).release_response_lock()
await self.notify_admin("{} 请求超时".format("{}_{}".format(launcher_type, launcher_id)))
reply = [tips_custom.reply_message]
# 私聊消息处理
async def on_person_message(self, event: MessageEvent):
reply = ''
config = context.get_config_manager().data
if not self.enable_private:
logging.debug("已在banlist.py中禁用所有私聊")
elif event.sender.id == self.bot_account_id:
pass
else:
if Image in event.message_chain:
pass
else:
# 超时则重试,重试超过次数则放弃
failed = 0
for i in range(self.retry):
try:
# @func_timeout.func_set_timeout(config['process_message_timeout'])
async def time_ctrl_wrapper():
reply = await processor.process_message('person', event.sender.id, str(event.message_chain),
event.message_chain,
event.sender.id)
return reply
reply = await time_ctrl_wrapper()
break
except func_timeout.FunctionTimedOut:
logging.warning("person_{}: 超时,重试中({})".format(event.sender.id, i))
openai_session.get_session('person_{}'.format(event.sender.id)).release_response_lock()
if "person_{}".format(event.sender.id) in processor.processing:
processor.processing.remove('person_{}'.format(event.sender.id))
failed += 1
continue
if failed == self.retry:
openai_session.get_session('person_{}'.format(event.sender.id)).release_response_lock()
self.notify_admin("{} 请求超时".format("person_{}".format(event.sender.id)))
reply = [tips_custom.reply_message]
reply = await self.common_process(
launcher_type="person",
launcher_id=event.sender.id,
text_message=str(event.message_chain),
message_chain=event.message_chain,
sender_id=event.sender.id
)
if reply:
await self.send(event, reply, check_quote=False, check_at_sender=False)
@@ -330,99 +268,48 @@ class QQBotManager:
async def on_group_message(self, event: GroupMessage):
reply = ''
config = context.get_config_manager().data
text = str(event.message_chain).strip()
async def process(text=None) -> str:
replys = ""
if At(self.bot_account_id) in event.message_chain:
event.message_chain.remove(At(self.bot_account_id))
rule_check_res = await self.resprule_chkr.check(
text,
event.message_chain,
event.group.id,
event.sender.id
)
# 超时则重试,重试超过次数则放弃
failed = 0
for i in range(self.retry):
try:
# @func_timeout.func_set_timeout(config['process_message_timeout'])
async def time_ctrl_wrapper():
replys = await processor.process_message('group', event.group.id,
str(event.message_chain).strip() if text is None else text,
event.message_chain,
event.sender.id)
return replys
replys = await time_ctrl_wrapper()
break
except func_timeout.FunctionTimedOut:
logging.warning("group_{}: 超时,重试中({})".format(event.group.id, i))
openai_session.get_session('group_{}'.format(event.group.id)).release_response_lock()
if "group_{}".format(event.group.id) in processor.processing:
processor.processing.remove('group_{}'.format(event.group.id))
failed += 1
continue
if failed == self.retry:
openai_session.get_session('group_{}'.format(event.group.id)).release_response_lock()
self.notify_admin("{} 请求超时".format("group_{}".format(event.group.id)))
replys = [tips_custom.replys_message]
return replys
if not self.enable_group:
logging.debug("已在banlist.py中禁用所有群聊")
elif Image in event.message_chain:
pass
else:
if At(self.bot_account_id) in event.message_chain and response_at(event.group.id):
# 直接调用
reply = await process()
else:
check, result = check_response_rule(event.group.id, str(event.message_chain).strip())
if check:
reply = await process(result.strip())
# 检查是否随机响应
elif random_responding(event.group.id):
logging.info("随机响应group_{}消息".format(event.group.id))
reply = await process()
if rule_check_res.matching:
text = str(rule_check_res.replacement).strip()
reply = await self.common_process(
launcher_type="group",
launcher_id=event.group.id,
text_message=text,
message_chain=rule_check_res.replacement,
sender_id=event.sender.id
)
if reply:
await self.send(event, reply)
# 通知系统管理员
async def notify_admin(self, message: str):
config = context.get_config_manager().data
if config['admin_qq'] != 0 and config['admin_qq'] != []:
logging.info("通知管理员:{}".format(message))
if type(config['admin_qq']) == int:
self.adapter.send_message(
"person",
config['admin_qq'],
MessageChain([Plain("[bot]{}".format(message))])
)
else:
for adm in config['admin_qq']:
self.adapter.send_message(
"person",
adm,
MessageChain([Plain("[bot]{}".format(message))])
)
await self.notify_admin_message_chain(MessageChain([Plain("[bot]{}".format(message))]))
async def notify_admin_message_chain(self, message):
async def notify_admin_message_chain(self, message: mirai.MessageChain):
config = context.get_config_manager().data
if config['admin_qq'] != 0 and config['admin_qq'] != []:
logging.info("通知管理员:{}".format(message))
admin_list = []
if type(config['admin_qq']) == int:
admin_list.append(config['admin_qq'])
for adm in admin_list:
self.adapter.send_message(
"person",
config['admin_qq'],
adm,
message
)
else:
for adm in config['admin_qq']:
self.adapter.send_message(
"person",
adm,
message
)
async def run(self):
await self.adapter.run_async()
await self.adapter.run_async()

View File

@@ -1,4 +1,5 @@
# 此模块提供了消息处理的具体逻辑的接口
from __future__ import annotations
import asyncio
import time
import traceback
@@ -6,23 +7,15 @@ import traceback
import mirai
import logging
# 这里不使用动态引入config
# 因为在这里动态引入会卡死程序
# 而此模块静态引用config与动态引入的表现一致
# 已弃用,由于超时时间现已动态使用
# import config as config_init_import
from ..qqbot import ratelimit
from ..qqbot import command, message
from ..openai import session as openai_session
from ..utils import context
from ..plugin import host as plugin_host
from ..plugin import models as plugin_models
from ..qqbot import ignore
from ..qqbot import banlist
from ..qqbot import blob
import tips as tips_custom
from ..boot import app
from .cntfilter import entities
processing = []
@@ -37,7 +30,7 @@ def is_admin(qq: int) -> bool:
async def process_message(launcher_type: str, launcher_id: int, text_message: str, message_chain: mirai.MessageChain,
sender_id: int) -> mirai.MessageChain:
sender_id: int) -> list:
global processing
mgr = context.get_qqbot_manager()
@@ -45,19 +38,10 @@ async def process_message(launcher_type: str, launcher_id: int, text_message: st
reply = []
session_name = "{}_{}".format(launcher_type, launcher_id)
# 检查发送方是否被禁用
if banlist.is_banned(launcher_type, launcher_id, sender_id):
logging.info("根据禁用列表忽略{}_{}的消息".format(launcher_type, launcher_id))
return []
if ignore.ignore(text_message):
logging.info("根据忽略规则忽略消息: {}".format(text_message))
return []
config = context.get_config_manager().data
if not config['wait_last_done'] and session_name in processing:
return mirai.MessageChain([mirai.Plain(tips_custom.message_drop_tip)])
return [mirai.Plain(tips_custom.message_drop_tip)]
# 检查是否被禁言
if launcher_type == 'group':
@@ -66,9 +50,14 @@ async def process_message(launcher_type: str, launcher_id: int, text_message: st
logging.info("机器人被禁言,跳过消息处理(group_{})".format(launcher_id))
return reply
if config['income_msg_check']:
if mgr.reply_filter.is_illegal(text_message):
return mirai.MessageChain(mirai.Plain("[bot] 消息中存在不合适的内容, 请更换措辞"))
cntfilter_res = await mgr.cntfilter_mgr.pre_process(text_message)
if cntfilter_res.level == entities.ManagerResultLevel.INTERRUPT:
if cntfilter_res.console_notice:
mgr.ap.logger.info(cntfilter_res.console_notice)
if cntfilter_res.user_notice:
return [mirai.Plain(cntfilter_res.user_notice)]
else:
return []
openai_session.get_session(session_name).acquire_response_lock()
@@ -113,12 +102,10 @@ async def process_message(launcher_type: str, launcher_id: int, text_message: st
else: # 消息
msg_type = 'message'
# 限速丢弃检查
# print(ratelimit.__crt_minute_usage__[session_name])
if config['rate_limit_strategy'] == "drop":
if ratelimit.is_reach_limit(session_name):
logging.info("根据限速策略丢弃[{}]消息: {}".format(session_name, text_message))
if not await mgr.ratelimiter.require(launcher_type, launcher_id):
logging.info("根据限速策略丢弃[{}]消息: {}".format(session_name, text_message))
return mirai.MessageChain(["[bot]"+tips_custom.rate_limit_drop_tip]) if tips_custom.rate_limit_drop_tip != "" else []
return mirai.MessageChain(["[bot]"+tips_custom.rate_limit_drop_tip]) if tips_custom.rate_limit_drop_tip != "" else []
before = time.time()
# 触发插件事件
@@ -143,12 +130,6 @@ async def process_message(launcher_type: str, launcher_id: int, text_message: st
reply = message.process_normal_message(text_message,
mgr, config, launcher_type, launcher_id, sender_id)
# 限速等待时间
if config['rate_limit_strategy'] == "wait":
time.sleep(ratelimit.get_rest_wait_time(session_name, time.time() - before))
ratelimit.add_usage(session_name)
if reply is not None and len(reply) > 0 and (type(reply[0]) == str or type(reply[0]) == mirai.Plain):
if type(reply[0]) == mirai.Plain:
reply[0] = reply[0].text
@@ -157,9 +138,18 @@ async def process_message(launcher_type: str, launcher_id: int, text_message: st
reply[0][:min(100, len(reply[0]))] + (
"..." if len(reply[0]) > 100 else "")))
if msg_type == 'message':
reply = [mgr.reply_filter.process(reply[0])]
reply = blob.check_text(reply[0])
cntfilter_res = await mgr.cntfilter_mgr.post_process(reply[0])
if cntfilter_res.level == entities.ManagerResultLevel.INTERRUPT:
if cntfilter_res.console_notice:
mgr.ap.logger.info(cntfilter_res.console_notice)
if cntfilter_res.user_notice:
return [mirai.Plain(cntfilter_res.user_notice)]
else:
return []
else:
reply = [cntfilter_res.replacement]
reply = await mgr.longtext_pcs.check_and_process(reply[0])
else:
logging.info("回复[{}]消息".format(session_name))

View File

24
pkg/qqbot/ratelim/algo.py Normal file
View File

@@ -0,0 +1,24 @@
from __future__ import annotations
import abc
from ...boot import app
class ReteLimitAlgo(metaclass=abc.ABCMeta):
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def require_access(self, launcher_type: str, launcher_id: int) -> bool:
raise NotImplementedError
@abc.abstractmethod
async def release_access(self, launcher_type: str, launcher_id: int):
raise NotImplementedError

View File

View File

@@ -0,0 +1,85 @@
# 固定窗口算法
from __future__ import annotations
import asyncio
import time
from .. import algo
class SessionContainer:
wait_lock: asyncio.Lock
records: dict[int, int]
"""访问记录key为每分钟的起始时间戳value为访问次数"""
def __init__(self):
self.wait_lock = asyncio.Lock()
self.records = {}
class FixedWindowAlgo(algo.ReteLimitAlgo):
containers_lock: asyncio.Lock
"""访问记录容器锁"""
containers: dict[str, SessionContainer]
"""访问记录容器key为launcher_type launcher_id"""
async def initialize(self):
self.containers_lock = asyncio.Lock()
self.containers = {}
async def require_access(self, launcher_type: str, launcher_id: int) -> bool:
# 加锁,找容器
container: SessionContainer = None
session_name = f'{launcher_type}_{launcher_id}'
async with self.containers_lock:
container = self.containers.get(session_name)
if container is None:
container = SessionContainer()
self.containers[session_name] = container
# 等待锁
async with container.wait_lock:
# 获取当前时间戳
now = int(time.time())
# 获取当前分钟的起始时间戳
now = now - now % 60
# 获取当前分钟的访问次数
count = container.records.get(now, 0)
limitation = self.ap.cfg_mgr.data['rate_limitation']['default']
if session_name in self.ap.cfg_mgr.data['rate_limitation']:
limitation = self.ap.cfg_mgr.data['rate_limitation'][session_name]
# 如果访问次数超过了限制
if count >= limitation:
if self.ap.cfg_mgr.data['rate_limit_strategy'] == 'drop':
return False
elif self.ap.cfg_mgr.data['rate_limit_strategy'] == 'wait':
# 等待下一分钟
await asyncio.sleep(60 - time.time() % 60)
now = int(time.time())
now = now - now % 60
if now not in container.records:
container.records = {}
container.records[now] = 1
else:
# 访问次数加一
container.records[now] = count + 1
# 返回True
return True
async def release_access(self, launcher_type: str, launcher_id: int):
pass

View File

@@ -0,0 +1,31 @@
from __future__ import annotations
from . import algo
from .algos import fixedwin
from ...boot import app
class RateLimiter:
"""限速器
"""
ap: app.Application
algo: algo.ReteLimitAlgo
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
self.algo = fixedwin.FixedWindowAlgo(self.ap)
await self.algo.initialize()
async def require(self, launcher_type: str, launcher_id: int) -> bool:
"""请求访问
"""
return await self.algo.require_access(launcher_type, launcher_id)
async def release(self, launcher_type: str, launcher_id: int):
"""释放访问
"""
return await self.algo.release_access(launcher_type, launcher_id)

View File

@@ -1,89 +0,0 @@
# 限速相关模块
import time
import logging
import threading
from ..utils import context
__crt_minute_usage__ = {}
"""当前分钟每个会话的对话次数"""
__timer_thr__: threading.Thread = None
def get_limitation(session_name: str) -> int:
"""获取会话的限制次数"""
config = context.get_config_manager().data
if session_name in config['rate_limitation']:
return config['rate_limitation'][session_name]
else:
return config['rate_limitation']["default"]
def add_usage(session_name: str):
"""增加会话的对话次数"""
global __crt_minute_usage__
if session_name in __crt_minute_usage__:
__crt_minute_usage__[session_name] += 1
else:
__crt_minute_usage__[session_name] = 1
def start_timer():
"""启动定时器"""
global __timer_thr__
__timer_thr__ = threading.Thread(target=run_timer, daemon=True)
__timer_thr__.start()
def run_timer():
"""启动定时器,每分钟清空一次对话次数"""
global __crt_minute_usage__
global __timer_thr__
# 等待直到整分钟
time.sleep(60 - time.time() % 60)
while True:
if __timer_thr__ != threading.current_thread():
break
logging.debug("清空当前分钟的对话次数")
__crt_minute_usage__ = {}
time.sleep(60)
def get_usage(session_name: str) -> int:
"""获取会话的对话次数"""
global __crt_minute_usage__
if session_name in __crt_minute_usage__:
return __crt_minute_usage__[session_name]
else:
return 0
def get_rest_wait_time(session_name: str, spent: float) -> float:
"""获取会话此回合的剩余等待时间"""
global __crt_minute_usage__
min_seconds_per_round = 60.0 / get_limitation(session_name)
if session_name in __crt_minute_usage__:
return max(0, min_seconds_per_round - spent)
else:
return 0
def is_reach_limit(session_name: str) -> bool:
"""判断会话是否超过限制"""
global __crt_minute_usage__
if session_name in __crt_minute_usage__:
return __crt_minute_usage__[session_name] >= get_limitation(session_name)
else:
return False
start_timer()

View File

View File

@@ -0,0 +1,9 @@
import pydantic
import mirai
class RuleJudgeResult(pydantic.BaseModel):
matching: bool = False
replacement: mirai.MessageChain = None

View File

@@ -0,0 +1,58 @@
from __future__ import annotations
import mirai
from ...boot import app
from . import entities, rule
from .rules import atbot, prefix, regexp, random
class GroupRespondRuleChecker:
"""群组响应规则检查器
"""
ap: app.Application
rule_matchers: list[rule.GroupRespondRule]
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
"""初始化检查器
"""
self.rule_matchers = [
atbot.AtBotRule(self.ap),
prefix.PrefixRule(self.ap),
regexp.RegExpRule(self.ap),
random.RandomRespRule(self.ap),
]
for rule_matcher in self.rule_matchers:
await rule_matcher.initialize()
async def check(
self,
message_text: str,
message_chain: mirai.MessageChain,
launcher_id: int,
sender_id: int,
) -> entities.RuleJudgeResult:
"""检查消息是否匹配规则
"""
rules = self.ap.cfg_mgr.data['response_rules']
use_rule = rules['default']
if str(launcher_id) in use_rule:
use_rule = use_rule[str(launcher_id)]
for rule_matcher in self.rule_matchers:
res = await rule_matcher.match(message_text, message_chain, use_rule)
if res.matching:
return res
return entities.RuleJudgeResult(
matching=False,
replacement=message_chain
)

View File

@@ -0,0 +1,31 @@
from __future__ import annotations
import abc
import mirai
from ...boot import app
from . import entities
class GroupRespondRule(metaclass=abc.ABCMeta):
"""群组响应规则的抽象类
"""
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
rule_dict: dict
) -> entities.RuleJudgeResult:
"""判断消息是否匹配规则
"""
raise NotImplementedError

View File

View File

@@ -0,0 +1,28 @@
from __future__ import annotations
import mirai
from .. import rule as rule_model
from .. import entities
class AtBotRule(rule_model.GroupRespondRule):
async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
rule_dict: dict
) -> entities.RuleJudgeResult:
if message_chain.has(mirai.At(self.ap.im_mgr.bot_account_id)) and rule_dict['at']:
message_chain.remove(mirai.At(self.ap.im_mgr.bot_account_id))
return entities.RuleJudgeResult(
matching=True,
replacement=message_chain,
)
return entities.RuleJudgeResult(
matching=False,
replacement = message_chain
)

View File

@@ -0,0 +1,29 @@
import mirai
from .. import rule as rule_model
from .. import entities
class PrefixRule(rule_model.GroupRespondRule):
async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
rule_dict: dict
) -> entities.RuleJudgeResult:
prefixes = rule_dict['prefix']
for prefix in prefixes:
if message_text.startswith(prefix):
return entities.RuleJudgeResult(
matching=True,
replacement=mirai.MessageChain([
mirai.Plain(message_text[len(prefix):])
]),
)
return entities.RuleJudgeResult(
matching=False,
replacement=message_chain
)

View File

@@ -0,0 +1,22 @@
import random
import mirai
from .. import rule as rule_model
from .. import entities
class RandomRespRule(rule_model.GroupRespondRule):
async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
rule_dict: dict
) -> entities.RuleJudgeResult:
random_rate = rule_dict['random_rate']
return entities.RuleJudgeResult(
matching=random.random() < random_rate,
replacement=message_chain
)

View File

@@ -0,0 +1,31 @@
import re
import mirai
from .. import rule as rule_model
from .. import entities
class RegExpRule(rule_model.GroupRespondRule):
async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
rule_dict: dict
) -> entities.RuleJudgeResult:
regexps = rule_dict['regexp']
for regexp in regexps:
match = re.match(regexp, message_text)
if match:
return entities.RuleJudgeResult(
matching=True,
replacement=message_chain,
)
return entities.RuleJudgeResult(
matching=False,
replacement=message_chain
)

View File

@@ -1,208 +0,0 @@
import logging
import re
import os
import traceback
from PIL import Image, ImageDraw, ImageFont
from ..utils import context
text_render_font: ImageFont = None
def initialize():
global text_render_font
logging.debug("初始化文字转图片模块...")
config = context.get_config_manager().data
if config['blob_message_strategy'] == "image": # 仅在启用了image时才加载字体
use_font = config['font_path']
try:
# 检查是否存在
if not os.path.exists(use_font):
# 若是windows系统使用微软雅黑
if os.name == "nt":
use_font = "C:/Windows/Fonts/msyh.ttc"
if not os.path.exists(use_font):
logging.warn("未找到字体文件且无法使用Windows自带字体更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。")
config['blob_message_strategy'] = "forward"
else:
logging.info("使用Windows自带字体" + use_font)
text_render_font = ImageFont.truetype(use_font, 32, encoding="utf-8")
else:
logging.warn("未找到字体文件且无法使用Windows自带字体更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。")
config['blob_message_strategy'] = "forward"
else:
text_render_font = ImageFont.truetype(use_font, 32, encoding="utf-8")
except:
traceback.print_exc()
logging.error("加载字体文件失败({})更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。".format(use_font))
config['blob_message_strategy'] = "forward"
logging.debug("字体文件加载完成。")
def indexNumber(path=''):
"""
查找字符串中数字所在串中的位置
:param path:目标字符串
:return:<class 'list'>: <class 'list'>: [['1', 16], ['2', 35], ['1', 51]]
"""
kv = []
nums = []
beforeDatas = re.findall('[\d]+', path)
for num in beforeDatas:
indexV = []
times = path.count(num)
if times > 1:
if num not in nums:
indexs = re.finditer(num, path)
for index in indexs:
iV = []
i = index.span()[0]
iV.append(num)
iV.append(i)
kv.append(iV)
nums.append(num)
else:
index = path.find(num)
indexV.append(num)
indexV.append(index)
kv.append(indexV)
# 根据数字位置排序
indexSort = []
resultIndex = []
for vi in kv:
indexSort.append(vi[1])
indexSort.sort()
for i in indexSort:
for v in kv:
if i == v[1]:
resultIndex.append(v)
return resultIndex
def get_size(file):
# 获取文件大小:KB
size = os.path.getsize(file)
return size / 1024
def get_outfile(infile, outfile):
if outfile:
return outfile
dir, suffix = os.path.splitext(infile)
outfile = '{}-out{}'.format(dir, suffix)
return outfile
def compress_image(infile, outfile='', kb=100, step=20, quality=90):
"""不改变图片尺寸压缩到指定大小
:param infile: 压缩源文件
:param outfile: 压缩文件保存地址
:param mb: 压缩目标,KB
:param step: 每次调整的压缩比率
:param quality: 初始压缩比率
:return: 压缩文件地址,压缩文件大小
"""
o_size = get_size(infile)
if o_size <= kb:
return infile, o_size
outfile = get_outfile(infile, outfile)
while o_size > kb:
im = Image.open(infile)
im.save(outfile, quality=quality)
if quality - step < 0:
break
quality -= step
o_size = get_size(outfile)
return outfile, get_size(outfile)
def text_to_image(text_str: str, save_as="temp.png", width=800):
global text_render_font
logging.debug("正在将文本转换为图片...")
text_str = text_str.replace("\t", " ")
# 分行
lines = text_str.split('\n')
# 计算并分割
final_lines = []
text_width = width-80
logging.debug("lines: {}, text_width: {}".format(lines, text_width))
for line in lines:
logging.debug(type(text_render_font))
# 如果长了就分割
line_width = text_render_font.getlength(line)
logging.debug("line_width: {}".format(line_width))
if line_width < text_width:
final_lines.append(line)
continue
else:
rest_text = line
while True:
# 分割最前面的一行
point = int(len(rest_text) * (text_width / line_width))
# 检查断点是否在数字中间
numbers = indexNumber(rest_text)
for number in numbers:
if number[1] < point < number[1] + len(number[0]) and number[1] != 0:
point = number[1]
break
final_lines.append(rest_text[:point])
rest_text = rest_text[point:]
line_width = text_render_font.getlength(rest_text)
if line_width < text_width:
final_lines.append(rest_text)
break
else:
continue
# 准备画布
img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255))
draw = ImageDraw.Draw(img, mode='RGBA')
logging.debug("正在绘制图片...")
# 绘制正文
line_number = 0
offset_x = 20
offset_y = 30
for final_line in final_lines:
draw.text((offset_x, offset_y + 35 * line_number), final_line, fill=(0, 0, 0), font=text_render_font)
# 遍历此行,检查是否有emoji
idx_in_line = 0
for ch in final_line:
# if self.is_emoji(ch):
# emoji_img_valid = ensure_emoji(hex(ord(ch))[2:])
# if emoji_img_valid: # emoji图像可用,绘制到指定位置
# emoji_image = Image.open("emojis/{}.png".format(hex(ord(ch))[2:]), mode='r').convert('RGBA')
# emoji_image = emoji_image.resize((32, 32))
# x, y = emoji_image.size
# final_emoji_img = Image.new('RGBA', emoji_image.size, (255, 255, 255))
# final_emoji_img.paste(emoji_image, (0, 0, x, y), emoji_image)
# img.paste(final_emoji_img, box=(int(offset_x + idx_in_line * 32), offset_y + 35 * line_number))
# 检查字符占位宽
char_code = ord(ch)
if char_code >= 127:
idx_in_line += 1
else:
idx_in_line += 0.5
line_number += 1
logging.debug("正在保存图片...")
img.save(save_as)
return save_as

View File

@@ -33,7 +33,7 @@ QChatGPT 主程序需要连接`QQ登录框架`以与QQ通信您可以选择 [
### 📄`banlist.py`
复制`res/templates/banlist-template.py`所有内容,创建`banlist.py`,这是黑名单配置文件,根据需要修改。
复制`banlist-template.py`所有内容,创建`banlist.py`,这是黑名单配置文件,根据需要修改。
### 📄`cmdpriv.json`