From a2360897851bdb6b8e8166d468dc8ab36d715ad2 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Wed, 24 Jan 2024 16:11:56 +0800 Subject: [PATCH 01/10] =?UTF-8?q?refactor:=20=E7=8B=AC=E7=AB=8Bresprule?= =?UTF-8?q?=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/qqbot/manager.py | 73 +++---------------------------------------- pkg/qqbot/resprule.py | 67 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 69 deletions(-) create mode 100644 pkg/qqbot/resprule.py diff --git a/pkg/qqbot/manager.py b/pkg/qqbot/manager.py index 26769e43..a0ef8550 100644 --- a/pkg/qqbot/manager.py +++ b/pkg/qqbot/manager.py @@ -18,76 +18,11 @@ from ..plugin import host as plugin_host from ..plugin import models as plugin_models import tips as tips_custom from ..qqbot import adapter as msadapter +from . import resprule from ..boot import app -# 检查消息是否符合泛响应匹配机制 -def check_response_rule(group_id:int, text: str): - config = context.get_config_manager().data - - rules = config['response_rules'] - - # 检查是否有特定规则 - if 'prefix' not in config['response_rules']: - if str(group_id) in config['response_rules']: - rules = config['response_rules'][str(group_id)] - else: - rules = config['response_rules']['default'] - - # 检查前缀匹配 - if 'prefix' in rules: - for rule in rules['prefix']: - if text.startswith(rule): - return True, text.replace(rule, "", 1) - - # 检查正则表达式匹配 - if 'regexp' in rules: - for rule in rules['regexp']: - import re - match = re.match(rule, text) - if match: - return True, text - - return False, "" - - -def response_at(group_id: int): - config = context.get_config_manager().data - - use_response_rule = config['response_rules'] - - # 检查是否有特定规则 - if 'prefix' not in config['response_rules']: - if str(group_id) in config['response_rules']: - use_response_rule = config['response_rules'][str(group_id)] - else: - use_response_rule = config['response_rules']['default'] - - if 'at' not in use_response_rule: - return True - - return use_response_rule['at'] - - -def random_responding(group_id): - config = context.get_config_manager().data - - use_response_rule = config['response_rules'] - - # 检查是否有特定规则 - if 'prefix' not in config['response_rules']: - if str(group_id) in config['response_rules']: - use_response_rule = config['response_rules'][str(group_id)] - else: - use_response_rule = config['response_rules']['default'] - - if 'random_rate' in use_response_rule: - import random - return random.random() < use_response_rule['random_rate'] - return False - - # 控制QQ消息输入输出的类 class QQBotManager: retry = 3 @@ -371,16 +306,16 @@ class QQBotManager: elif Image in event.message_chain: pass else: - if At(self.bot_account_id) in event.message_chain and response_at(event.group.id): + if At(self.bot_account_id) in event.message_chain and resprule.response_at(event.group.id): # 直接调用 reply = await process() else: - check, result = check_response_rule(event.group.id, str(event.message_chain).strip()) + check, result = resprule.check_response_rule(event.group.id, str(event.message_chain).strip()) if check: reply = await process(result.strip()) # 检查是否随机响应 - elif random_responding(event.group.id): + elif resprule.random_responding(event.group.id): logging.info("随机响应group_{}消息".format(event.group.id)) reply = await process() diff --git a/pkg/qqbot/resprule.py b/pkg/qqbot/resprule.py new file mode 100644 index 00000000..5c237024 --- /dev/null +++ b/pkg/qqbot/resprule.py @@ -0,0 +1,67 @@ +from ..utils import context + + +# 检查消息是否符合泛响应匹配机制 +def check_response_rule(group_id:int, text: str): + config = context.get_config_manager().data + + rules = config['response_rules'] + + # 检查是否有特定规则 + if 'prefix' not in config['response_rules']: + if str(group_id) in config['response_rules']: + rules = config['response_rules'][str(group_id)] + else: + rules = config['response_rules']['default'] + + # 检查前缀匹配 + if 'prefix' in rules: + for rule in rules['prefix']: + if text.startswith(rule): + return True, text.replace(rule, "", 1) + + # 检查正则表达式匹配 + if 'regexp' in rules: + for rule in rules['regexp']: + import re + match = re.match(rule, text) + if match: + return True, text + + return False, "" + + +def response_at(group_id: int): + config = context.get_config_manager().data + + use_response_rule = config['response_rules'] + + # 检查是否有特定规则 + if 'prefix' not in config['response_rules']: + if str(group_id) in config['response_rules']: + use_response_rule = config['response_rules'][str(group_id)] + else: + use_response_rule = config['response_rules']['default'] + + if 'at' not in use_response_rule: + return True + + return use_response_rule['at'] + + +def random_responding(group_id): + config = context.get_config_manager().data + + use_response_rule = config['response_rules'] + + # 检查是否有特定规则 + if 'prefix' not in config['response_rules']: + if str(group_id) in config['response_rules']: + use_response_rule = config['response_rules'][str(group_id)] + else: + use_response_rule = config['response_rules']['default'] + + if 'random_rate' in use_response_rule: + import random + return random.random() < use_response_rule['random_rate'] + return False From 3d06a18bcba494bd7d79c58b2a3d464aa80b8d8c Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Wed, 24 Jan 2024 17:00:56 +0800 Subject: [PATCH 02/10] =?UTF-8?q?refactor:=20=E7=AE=80=E5=8C=96=E7=A7=81?= =?UTF-8?q?=E8=81=8A=E7=BE=A4=E8=81=8A=E5=85=B1=E5=90=8C=E5=A4=84=E7=90=86?= =?UTF-8?q?=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/qqbot/manager.py | 175 +++++++++++++++++++------------------------ 1 file changed, 77 insertions(+), 98 deletions(-) diff --git a/pkg/qqbot/manager.py b/pkg/qqbot/manager.py index a0ef8550..3ac877c1 100644 --- a/pkg/qqbot/manager.py +++ b/pkg/qqbot/manager.py @@ -7,6 +7,7 @@ import asyncio from mirai import At, GroupMessage, MessageEvent, StrangerMessage, \ FriendMessage, Image, MessageChain, Plain +import mirai import func_timeout from ..openai import session as openai_session @@ -217,6 +218,44 @@ class QQBotManager: quote_origin=True if config['quote_origin'] and check_quote else False ) + async def common_process( + self, + launcher_type: str, + launcher_id: int, + text_message: str, + message_chain: MessageChain, + sender_id: int + ) -> mirai.MessageChain: + """ + 私聊群聊通用消息处理方法 + """ + if mirai.Image in message_chain: + return [] + elif sender_id == self.bot_account_id: + return [] + else: + # 超时则重试,重试超过次数则放弃 + failed = 0 + for i in range(self.retry): + try: + reply = await processor.process_message(launcher_type, launcher_id, text_message, message_chain, + sender_id) + return reply + + # TODO openai 超时处理 + except func_timeout.FunctionTimedOut: + logging.warning("{}_{}: 超时,重试中({})".format(launcher_type, launcher_id, i)) + openai_session.get_session("{}_{}".format(launcher_type, launcher_id)).release_response_lock() + if "{}_{}".format(launcher_type, launcher_id) in processor.processing: + processor.processing.remove("{}_{}".format(launcher_type, launcher_id)) + failed += 1 + continue + + if failed == self.retry: + openai_session.get_session("{}_{}".format(launcher_type, launcher_id)).release_response_lock() + await self.notify_admin("{} 请求超时".format("{}_{}".format(launcher_type, launcher_id))) + reply = [tips_custom.reply_message] + # 私聊消息处理 async def on_person_message(self, event: MessageEvent): reply = '' @@ -225,38 +264,15 @@ class QQBotManager: if not self.enable_private: logging.debug("已在banlist.py中禁用所有私聊") - elif event.sender.id == self.bot_account_id: - pass - else: - if Image in event.message_chain: - pass - else: - # 超时则重试,重试超过次数则放弃 - failed = 0 - for i in range(self.retry): - try: - - # @func_timeout.func_set_timeout(config['process_message_timeout']) - async def time_ctrl_wrapper(): - reply = await processor.process_message('person', event.sender.id, str(event.message_chain), - event.message_chain, - event.sender.id) - return reply - - reply = await time_ctrl_wrapper() - break - except func_timeout.FunctionTimedOut: - logging.warning("person_{}: 超时,重试中({})".format(event.sender.id, i)) - openai_session.get_session('person_{}'.format(event.sender.id)).release_response_lock() - if "person_{}".format(event.sender.id) in processor.processing: - processor.processing.remove('person_{}'.format(event.sender.id)) - failed += 1 - continue - if failed == self.retry: - openai_session.get_session('person_{}'.format(event.sender.id)).release_response_lock() - self.notify_admin("{} 请求超时".format("person_{}".format(event.sender.id))) - reply = [tips_custom.reply_message] + else: + reply = await self.common_process( + launcher_type="person", + launcher_id=event.sender.id, + text_message=str(event.message_chain), + message_chain=event.message_chain, + sender_id=event.sender.id + ) if reply: await self.send(event, reply, check_quote=False, check_at_sender=False) @@ -264,100 +280,63 @@ class QQBotManager: # 群消息处理 async def on_group_message(self, event: GroupMessage): reply = '' - - config = context.get_config_manager().data - - async def process(text=None) -> str: - replys = "" - if At(self.bot_account_id) in event.message_chain: - event.message_chain.remove(At(self.bot_account_id)) - - # 超时则重试,重试超过次数则放弃 - failed = 0 - for i in range(self.retry): - try: - # @func_timeout.func_set_timeout(config['process_message_timeout']) - async def time_ctrl_wrapper(): - replys = await processor.process_message('group', event.group.id, - str(event.message_chain).strip() if text is None else text, - event.message_chain, - event.sender.id) - return replys - - replys = await time_ctrl_wrapper() - break - except func_timeout.FunctionTimedOut: - logging.warning("group_{}: 超时,重试中({})".format(event.group.id, i)) - openai_session.get_session('group_{}'.format(event.group.id)).release_response_lock() - if "group_{}".format(event.group.id) in processor.processing: - processor.processing.remove('group_{}'.format(event.group.id)) - failed += 1 - continue - - if failed == self.retry: - openai_session.get_session('group_{}'.format(event.group.id)).release_response_lock() - self.notify_admin("{} 请求超时".format("group_{}".format(event.group.id))) - replys = [tips_custom.replys_message] - - return replys if not self.enable_group: logging.debug("已在banlist.py中禁用所有群聊") - elif Image in event.message_chain: - pass + else: + do_req = False + text = str(event.message_chain).strip() if At(self.bot_account_id) in event.message_chain and resprule.response_at(event.group.id): # 直接调用 - reply = await process() + # reply = await process() + event.message_chain.remove(At(self.bot_account_id)) + text = str(event.message_chain).strip() + do_req = True else: check, result = resprule.check_response_rule(event.group.id, str(event.message_chain).strip()) if check: - reply = await process(result.strip()) + do_req = True + text = result.strip() # 检查是否随机响应 elif resprule.random_responding(event.group.id): logging.info("随机响应group_{}消息".format(event.group.id)) - reply = await process() + # reply = await process() + do_req = True + + if do_req: + reply = await self.common_process( + launcher_type="group", + launcher_id=event.group.id, + text_message=text, + message_chain=event.message_chain, + sender_id=event.sender.id + ) if reply: await self.send(event, reply) # 通知系统管理员 async def notify_admin(self, message: str): - config = context.get_config_manager().data - if config['admin_qq'] != 0 and config['admin_qq'] != []: - logging.info("通知管理员:{}".format(message)) - if type(config['admin_qq']) == int: - self.adapter.send_message( - "person", - config['admin_qq'], - MessageChain([Plain("[bot]{}".format(message))]) - ) - else: - for adm in config['admin_qq']: - self.adapter.send_message( - "person", - adm, - MessageChain([Plain("[bot]{}".format(message))]) - ) + await self.notify_admin_message_chain(MessageChain([Plain("[bot]{}".format(message))])) - async def notify_admin_message_chain(self, message): + async def notify_admin_message_chain(self, message: mirai.MessageChain): config = context.get_config_manager().data if config['admin_qq'] != 0 and config['admin_qq'] != []: logging.info("通知管理员:{}".format(message)) + + admin_list = [] + if type(config['admin_qq']) == int: + admin_list.append(config['admin_qq']) + + for adm in admin_list: self.adapter.send_message( "person", - config['admin_qq'], + adm, message ) - else: - for adm in config['admin_qq']: - self.adapter.send_message( - "person", - adm, - message - ) async def run(self): - await self.adapter.run_async() \ No newline at end of file + await self.adapter.run_async() From a975718a647d3b33913528f81944db909e918810 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Wed, 24 Jan 2024 22:29:19 +0800 Subject: [PATCH 03/10] =?UTF-8?q?refactor:=20=E6=9A=82=E6=97=B6=E5=88=A0?= =?UTF-8?q?=E9=99=A4=E5=AF=B9=E7=83=AD=E9=87=8D=E8=BD=BD=E7=9A=84=E6=94=AF?= =?UTF-8?q?=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/qqbot/manager.py | 29 ++++++++++------------------- 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/pkg/qqbot/manager.py b/pkg/qqbot/manager.py index 3ac877c1..e8cf6b6e 100644 --- a/pkg/qqbot/manager.py +++ b/pkg/qqbot/manager.py @@ -48,24 +48,17 @@ class QQBotManager: self.timeout = config['process_message_timeout'] self.retry = config['retry_times'] - # 由于YiriMirai的bot对象是单例的,且shutdown方法暂时无法使用 - # 故只在第一次初始化时创建bot对象,重载之后使用原bot对象 - # 因此,bot的配置不支持热重载 - if first_time_init: - logging.debug("Use adapter:" + config['msg_source_adapter']) - if config['msg_source_adapter'] == 'yirimirai': - from pkg.qqbot.sources.yirimirai import YiriMiraiAdapter + logging.debug("Use adapter:" + config['msg_source_adapter']) + if config['msg_source_adapter'] == 'yirimirai': + from pkg.qqbot.sources.yirimirai import YiriMiraiAdapter - mirai_http_api_config = config['mirai_http_api_config'] - self.bot_account_id = config['mirai_http_api_config']['qq'] - self.adapter = YiriMiraiAdapter(mirai_http_api_config) - elif config['msg_source_adapter'] == 'nakuru': - from pkg.qqbot.sources.nakuru import NakuruProjectAdapter - self.adapter = NakuruProjectAdapter(config['nakuru_config']) - self.bot_account_id = self.adapter.bot_account_id - else: - self.adapter = context.get_qqbot_manager().adapter - self.bot_account_id = context.get_qqbot_manager().bot_account_id + mirai_http_api_config = config['mirai_http_api_config'] + self.bot_account_id = config['mirai_http_api_config']['qq'] + self.adapter = YiriMiraiAdapter(mirai_http_api_config) + elif config['msg_source_adapter'] == 'nakuru': + from pkg.qqbot.sources.nakuru import NakuruProjectAdapter + self.adapter = NakuruProjectAdapter(config['nakuru_config']) + self.bot_account_id = self.adapter.bot_account_id # 保存 account_id 到审计模块 from ..utils.center import apigroup @@ -260,8 +253,6 @@ class QQBotManager: async def on_person_message(self, event: MessageEvent): reply = '' - config = context.get_config_manager().data - if not self.enable_private: logging.debug("已在banlist.py中禁用所有私聊") From b4bd86549ed101e5b286dea17b62080ff9944316 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Wed, 24 Jan 2024 23:33:19 +0800 Subject: [PATCH 04/10] =?UTF-8?q?chore:=20banlist=E6=A8=A1=E7=89=88?= =?UTF-8?q?=E7=A7=BB=E8=87=B3=E6=A0=B9=E7=9B=AE=E5=BD=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- res/templates/banlist-template.py => banlist-template.py | 0 main.py | 2 +- pkg/boot/files.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename res/templates/banlist-template.py => banlist-template.py (100%) diff --git a/res/templates/banlist-template.py b/banlist-template.py similarity index 100% rename from res/templates/banlist-template.py rename to banlist-template.py diff --git a/main.py b/main.py index adcc3d89..53e40853 100644 --- a/main.py +++ b/main.py @@ -16,7 +16,7 @@ sys.path.append(".") def check_file(): # 检查是否有banlist.py,如果没有就把banlist-template.py复制一份 if not os.path.exists('banlist.py'): - shutil.copy('res/templates/banlist-template.py', 'banlist.py') + shutil.copy('banlist-template.py', 'banlist.py') # 检查是否有sensitive.json if not os.path.exists("sensitive.json"): diff --git a/pkg/boot/files.py b/pkg/boot/files.py index eaa4ffb0..25a8ba5b 100644 --- a/pkg/boot/files.py +++ b/pkg/boot/files.py @@ -7,7 +7,7 @@ import sys required_files = { "config.py": "config-template.py", - "banlist.py": "res/templates/banlist-template.py", + "banlist.py": "banlist-template.py", "tips.py": "tips-custom-template.py", "sensitive.json": "res/templates/sensitive-template.json", "scenario/default.json": "scenario/default-template.json", From f3bcff12614c8ee573d28dadbd87bb165608ff3a Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Wed, 24 Jan 2024 23:33:48 +0800 Subject: [PATCH 05/10] =?UTF-8?q?chore:=20banlist=E6=A8=A1=E7=89=88?= =?UTF-8?q?=E7=A7=BB=E8=87=B3=E6=A0=B9=E7=9B=AE=E5=BD=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- res/docs/docker_deployment.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/res/docs/docker_deployment.md b/res/docs/docker_deployment.md index 13420dfa..7172daf8 100644 --- a/res/docs/docker_deployment.md +++ b/res/docs/docker_deployment.md @@ -33,7 +33,7 @@ QChatGPT 主程序需要连接`QQ登录框架`以与QQ通信,您可以选择 [ ### 📄`banlist.py` -复制`res/templates/banlist-template.py`所有内容,创建`banlist.py`,这是黑名单配置文件,根据需要修改。 +复制`banlist-template.py`所有内容,创建`banlist.py`,这是黑名单配置文件,根据需要修改。 ### 📄`cmdpriv.json` From f4ae9df3bf5ddd7465799292184c526fe64edf01 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Wed, 24 Jan 2024 23:38:13 +0800 Subject: [PATCH 06/10] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84=E4=BC=9A?= =?UTF-8?q?=E8=AF=9D=E5=B0=81=E7=A6=81=E5=8A=9F=E8=83=BD=E5=A4=84=E7=90=86?= =?UTF-8?q?=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/boot/app.py | 3 ++ pkg/boot/boot.py | 11 +++++- pkg/boot/config.py | 23 +++--------- pkg/qqbot/banlist.py | 50 ------------------------ pkg/qqbot/bansess/__init__.py | 0 pkg/qqbot/bansess/bansess.py | 71 +++++++++++++++++++++++++++++++++++ pkg/qqbot/manager.py | 32 +++++++++------- pkg/qqbot/process.py | 14 +------ 8 files changed, 109 insertions(+), 95 deletions(-) delete mode 100644 pkg/qqbot/banlist.py create mode 100644 pkg/qqbot/bansess/__init__.py create mode 100644 pkg/qqbot/bansess/bansess.py diff --git a/pkg/boot/app.py b/pkg/boot/app.py index c548b3e9..df9e92b9 100644 --- a/pkg/boot/app.py +++ b/pkg/boot/app.py @@ -28,6 +28,9 @@ class Application: def __init__(self): pass + async def initialize(self): + await self.im_mgr.initialize() + async def run(self): # TODO make it async plugin_host.initialize_plugins() diff --git a/pkg/boot/boot.py b/pkg/boot/boot.py index 2d640904..a2b2d7c7 100644 --- a/pkg/boot/boot.py +++ b/pkg/boot/boot.py @@ -50,7 +50,10 @@ async def make_app() -> app.Application: # 生成标识符 identifier.init() - cfg_mgr = await config.load_config() + cfg_mgr = await config.load_python_module_config( + "config.py", + "config-template.py" + ) context.set_config_manager(cfg_mgr) cfg = cfg_mgr.data @@ -63,7 +66,10 @@ async def make_app() -> app.Application: if overrided: qcg_logger.info("以下配置项已使用 override.json 覆盖:" + ",".join(overrided)) - tips_mgr = await config.load_tips() + tips_mgr = await config.load_python_module_config( + "tips.py", + "tips-custom-template.py" + ) # 初始化文字转图片 from pkg.utils import text2img @@ -121,4 +127,5 @@ async def make_app() -> app.Application: async def main(): app_inst = await make_app() + await app_inst.initialize() await app_inst.run() diff --git a/pkg/boot/config.py b/pkg/boot/config.py index f18ed2c3..1d891da0 100644 --- a/pkg/boot/config.py +++ b/pkg/boot/config.py @@ -4,11 +4,11 @@ from ..config import manager as config_mgr from ..config.impls import pymodule -async def load_config() -> config_mgr.ConfigManager: - """加载配置文件""" +async def load_python_module_config(config_name: str, template_name: str) -> config_mgr.ConfigManager: + """加载Python模块配置文件""" cfg_inst = pymodule.PythonModuleConfigFile( - "config.py", - "config-template.py" + config_name, + template_name ) cfg_mgr = config_mgr.ConfigManager(cfg_inst) @@ -17,19 +17,6 @@ async def load_config() -> config_mgr.ConfigManager: return cfg_mgr -async def load_tips() -> config_mgr.ConfigManager: - """加载提示文件""" - tips_inst = pymodule.PythonModuleConfigFile( - "tips.py", - "tips-custom-template.py" - ) - - tips_mgr = config_mgr.ConfigManager(tips_inst) - await tips_mgr.load_config() - - return tips_mgr - - async def override_config_manager(cfg_mgr: config_mgr.ConfigManager) -> list[str]: override_json = json.load(open("override.json", "r", encoding="utf-8")) overrided = [] @@ -39,5 +26,5 @@ async def override_config_manager(cfg_mgr: config_mgr.ConfigManager) -> list[str if key in config: config[key] = override_json[key] overrided.append(key) - + return overrided diff --git a/pkg/qqbot/banlist.py b/pkg/qqbot/banlist.py deleted file mode 100644 index 949c541b..00000000 --- a/pkg/qqbot/banlist.py +++ /dev/null @@ -1,50 +0,0 @@ -from ..utils import context - - -def is_banned(launcher_type: str, launcher_id: int, sender_id: int) -> bool: - if not context.get_qqbot_manager().enable_banlist: - return False - - result = False - - if launcher_type == 'group': - # 检查是否显式声明发起人QQ要被person忽略 - if sender_id in context.get_qqbot_manager().ban_person: - result = True - else: - for group_rule in context.get_qqbot_manager().ban_group: - if type(group_rule) == int: - if group_rule == launcher_id: # 此群群号被禁用 - result = True - elif type(group_rule) == str: - if group_rule.startswith('!'): - # 截取!后面的字符串作为表达式,判断是否匹配 - reg_str = group_rule[1:] - import re - if re.match(reg_str, str(launcher_id)): # 被豁免,最高级别 - result = False - break - else: - # 判断是否匹配regexp - import re - if re.match(group_rule, str(launcher_id)): # 此群群号被禁用 - result = True - - else: - # ban_person, 与群规则相同 - for person_rule in context.get_qqbot_manager().ban_person: - if type(person_rule) == int: - if person_rule == launcher_id: - result = True - elif type(person_rule) == str: - if person_rule.startswith('!'): - reg_str = person_rule[1:] - import re - if re.match(reg_str, str(launcher_id)): - result = False - break - else: - import re - if re.match(person_rule, str(launcher_id)): - result = True - return result diff --git a/pkg/qqbot/bansess/__init__.py b/pkg/qqbot/bansess/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/qqbot/bansess/bansess.py b/pkg/qqbot/bansess/bansess.py new file mode 100644 index 00000000..e037fde8 --- /dev/null +++ b/pkg/qqbot/bansess/bansess.py @@ -0,0 +1,71 @@ +# 处理对会话的禁用配置 +# 过去的 banlist +from __future__ import annotations +import re + +from ...boot import app +from ...boot import config as config_util +from ...config import manager as cfg_mgr + + +class SessionBanManager: + + ap: app.Application = None + + banlist_mgr: cfg_mgr.ConfigManager + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + self.banlist_mgr = await config_util.load_python_module_config( + "banlist.py", + "res/templates/banlist-template.py" + ) + + async def is_banned( + self, launcher_type: str, launcher_id: int, sender_id: int + ) -> bool: + if not self.banlist_mgr.data['enable']: + return False + + result = False + + if launcher_type == 'group': + if not self.banlist_mgr.data['enable_group']: # 未启用群聊响应 + result = True + # 检查是否显式声明发起人QQ要被person忽略 + elif sender_id in self.banlist_mgr.data['person']: + result = True + else: + for group_rule in self.banlist_mgr.data['group']: + if type(group_rule) == int: + if group_rule == launcher_id: + result = True + elif type(group_rule) == str: + if group_rule.startswith('!'): + reg_str = group_rule[1:] + if re.match(reg_str, str(launcher_id)): + result = False + break + else: + if re.match(group_rule, str(launcher_id)): + result = True + elif launcher_type == 'person': + if not self.banlist_mgr.data['enable_private']: + result = True + else: + for person_rule in self.banlist_mgr.data['person']: + if type(person_rule) == int: + if person_rule == launcher_id: + result = True + elif type(person_rule) == str: + if person_rule.startswith('!'): + reg_str = person_rule[1:] + if re.match(reg_str, str(launcher_id)): + result = False + break + else: + if re.match(person_rule, str(launcher_id)): + result = True + return result diff --git a/pkg/qqbot/manager.py b/pkg/qqbot/manager.py index e8cf6b6e..bfe86b9c 100644 --- a/pkg/qqbot/manager.py +++ b/pkg/qqbot/manager.py @@ -20,6 +20,7 @@ from ..plugin import models as plugin_models import tips as tips_custom from ..qqbot import adapter as msadapter from . import resprule +from .bansess import bansess from ..boot import app @@ -42,11 +43,24 @@ class QQBotManager: ban_person = [] ban_group = [] + # modern + ap: app.Application = None + + bansess_mgr: bansess.SessionBanManager = None + def __init__(self, first_time_init=True, ap: app.Application = None): config = context.get_config_manager().data + self.ap = ap + self.bansess_mgr = bansess.SessionBanManager(ap) + self.timeout = config['process_message_timeout'] self.retry = config['retry_times'] + + async def initialize(self): + await self.bansess_mgr.initialize() + + config = context.get_config_manager().data logging.debug("Use adapter:" + config['msg_source_adapter']) if config['msg_source_adapter'] == 'yirimirai': @@ -160,19 +174,6 @@ class QQBotManager: self.unsubscribe_all = unsubscribe_all - # 加载禁用列表 - if os.path.exists("banlist.py"): - import banlist - self.enable_banlist = banlist.enable - self.ban_person = banlist.person - self.ban_group = banlist.group - logging.info("加载禁用列表: person: {}, group: {}".format(self.ban_person, self.ban_group)) - - if hasattr(banlist, "enable_private"): - self.enable_private = banlist.enable_private - if hasattr(banlist, "enable_group"): - self.enable_group = banlist.enable_group - config = context.get_config_manager().data if os.path.exists("sensitive.json") \ and config['sensitive_word_filter'] is not None \ @@ -222,6 +223,11 @@ class QQBotManager: """ 私聊群聊通用消息处理方法 """ + # 检查bansess + if await self.bansess_mgr.is_banned(launcher_type, launcher_id, sender_id): + self.ap.logger.info("根据禁用列表忽略{}_{}的消息".format(launcher_type, launcher_id)) + return [] + if mirai.Image in message_chain: return [] elif sender_id == self.bot_account_id: diff --git a/pkg/qqbot/process.py b/pkg/qqbot/process.py index 72788581..aa02315f 100644 --- a/pkg/qqbot/process.py +++ b/pkg/qqbot/process.py @@ -1,4 +1,5 @@ # 此模块提供了消息处理的具体逻辑的接口 +from __future__ import annotations import asyncio import time import traceback @@ -6,12 +7,6 @@ import traceback import mirai import logging -# 这里不使用动态引入config -# 因为在这里动态引入会卡死程序 -# 而此模块静态引用config与动态引入的表现一致 -# 已弃用,由于超时时间现已动态使用 -# import config as config_init_import - from ..qqbot import ratelimit from ..qqbot import command, message from ..openai import session as openai_session @@ -20,9 +15,9 @@ from ..utils import context from ..plugin import host as plugin_host from ..plugin import models as plugin_models from ..qqbot import ignore -from ..qqbot import banlist from ..qqbot import blob import tips as tips_custom +from ..boot import app processing = [] @@ -45,11 +40,6 @@ async def process_message(launcher_type: str, launcher_id: int, text_message: st reply = [] session_name = "{}_{}".format(launcher_type, launcher_id) - # 检查发送方是否被禁用 - if banlist.is_banned(launcher_type, launcher_id, sender_id): - logging.info("根据禁用列表忽略{}_{}的消息".format(launcher_type, launcher_id)) - return [] - if ignore.ignore(text_message): logging.info("根据忽略规则忽略消息: {}".format(text_message)) return [] From a9a798b19d6afbe7d272a2dcafd38769ac4130c6 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Thu, 25 Jan 2024 15:28:23 +0800 Subject: [PATCH 07/10] =?UTF-8?q?refactor:=20filter=E5=92=8Cignore?= =?UTF-8?q?=E7=8B=AC=E7=AB=8B=E6=88=90=E6=96=B0=E7=9A=84cntfilter=E5=8C=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config-template.py | 2 + pkg/boot/config.py | 13 +-- pkg/config/impls/json.py | 44 ++++++++++ pkg/config/manager.py | 27 ++++++ pkg/qqbot/bansess/bansess.py | 3 +- pkg/qqbot/cntfilter/__init__.py | 0 pkg/qqbot/cntfilter/cntfilter.py | 93 +++++++++++++++++++++ pkg/qqbot/cntfilter/entities.py | 64 ++++++++++++++ pkg/qqbot/cntfilter/filter.py | 34 ++++++++ pkg/qqbot/cntfilter/filters/__init__.py | 0 pkg/qqbot/cntfilter/filters/baiduexamine.py | 61 ++++++++++++++ pkg/qqbot/cntfilter/filters/banwords.py | 44 ++++++++++ pkg/qqbot/cntfilter/filters/cntignore.py | 43 ++++++++++ pkg/qqbot/filter.py | 87 ------------------- pkg/qqbot/ignore.py | 18 ---- pkg/qqbot/manager.py | 21 +---- pkg/qqbot/process.py | 32 ++++--- 17 files changed, 440 insertions(+), 146 deletions(-) create mode 100644 pkg/config/impls/json.py create mode 100644 pkg/qqbot/cntfilter/__init__.py create mode 100644 pkg/qqbot/cntfilter/cntfilter.py create mode 100644 pkg/qqbot/cntfilter/entities.py create mode 100644 pkg/qqbot/cntfilter/filter.py create mode 100644 pkg/qqbot/cntfilter/filters/__init__.py create mode 100644 pkg/qqbot/cntfilter/filters/baiduexamine.py create mode 100644 pkg/qqbot/cntfilter/filters/banwords.py create mode 100644 pkg/qqbot/cntfilter/filters/cntignore.py delete mode 100644 pkg/qqbot/filter.py delete mode 100644 pkg/qqbot/ignore.py diff --git a/config-template.py b/config-template.py index fb60ff1e..2c1d0e99 100644 --- a/config-template.py +++ b/config-template.py @@ -167,6 +167,8 @@ response_rules = { # 此设置优先级高于response_rules # 用以过滤mirai等其他层级的命令 # @see https://github.com/RockChinQ/QChatGPT/issues/165 +# +# *需要同时开启下方 income_msg_check 才会生效 ignore_rules = { "prefix": ["/"], "regexp": [] diff --git a/pkg/boot/config.py b/pkg/boot/config.py index 1d891da0..3f796214 100644 --- a/pkg/boot/config.py +++ b/pkg/boot/config.py @@ -4,17 +4,8 @@ from ..config import manager as config_mgr from ..config.impls import pymodule -async def load_python_module_config(config_name: str, template_name: str) -> config_mgr.ConfigManager: - """加载Python模块配置文件""" - cfg_inst = pymodule.PythonModuleConfigFile( - config_name, - template_name - ) - - cfg_mgr = config_mgr.ConfigManager(cfg_inst) - await cfg_mgr.load_config() - - return cfg_mgr +load_python_module_config = config_mgr.load_python_module_config +load_json_config = config_mgr.load_json_config async def override_config_manager(cfg_mgr: config_mgr.ConfigManager) -> list[str]: diff --git a/pkg/config/impls/json.py b/pkg/config/impls/json.py new file mode 100644 index 00000000..cfc284cb --- /dev/null +++ b/pkg/config/impls/json.py @@ -0,0 +1,44 @@ +import os +import shutil +import json + +from .. import model as file_model + + +class JSONConfigFile(file_model.ConfigFile): + """JSON配置文件""" + + config_file_name: str = None + """配置文件名""" + + template_file_name: str = None + """模板文件名""" + + def __init__(self, config_file_name: str, template_file_name: str) -> None: + self.config_file_name = config_file_name + self.template_file_name = template_file_name + + def exists(self) -> bool: + return os.path.exists(self.config_file_name) + + async def create(self): + shutil.copyfile(self.template_file_name, self.config_file_name) + + async def load(self) -> dict: + + with open(self.config_file_name, 'r', encoding='utf-8') as f: + cfg = json.load(f) + + # 从模板文件中进行补全 + with open(self.template_file_name, 'r', encoding='utf-8') as f: + template_cfg = json.load(f) + + for key in template_cfg: + if key not in cfg: + cfg[key] = template_cfg[key] + + return cfg + + async def save(self, cfg: dict): + with open(self.config_file_name, 'w', encoding='utf-8') as f: + json.dump(cfg, f, indent=4, ensure_ascii=False) \ No newline at end of file diff --git a/pkg/config/manager.py b/pkg/config/manager.py index 5893ff0b..e343b0c2 100644 --- a/pkg/config/manager.py +++ b/pkg/config/manager.py @@ -1,5 +1,6 @@ from . import model as file_model from ..utils import context +from .impls import pymodule, json as json_file class ConfigManager: @@ -20,3 +21,29 @@ class ConfigManager: async def dump_config(self): await self.file.save(self.data) + + +async def load_python_module_config(config_name: str, template_name: str) -> ConfigManager: + """加载Python模块配置文件""" + cfg_inst = pymodule.PythonModuleConfigFile( + config_name, + template_name + ) + + cfg_mgr = ConfigManager(cfg_inst) + await cfg_mgr.load_config() + + return cfg_mgr + + +async def load_json_config(config_name: str, template_name: str) -> ConfigManager: + """加载JSON配置文件""" + cfg_inst = json_file.JSONConfigFile( + config_name, + template_name + ) + + cfg_mgr = ConfigManager(cfg_inst) + await cfg_mgr.load_config() + + return cfg_mgr \ No newline at end of file diff --git a/pkg/qqbot/bansess/bansess.py b/pkg/qqbot/bansess/bansess.py index e037fde8..d8ef4958 100644 --- a/pkg/qqbot/bansess/bansess.py +++ b/pkg/qqbot/bansess/bansess.py @@ -4,7 +4,6 @@ from __future__ import annotations import re from ...boot import app -from ...boot import config as config_util from ...config import manager as cfg_mgr @@ -18,7 +17,7 @@ class SessionBanManager: self.ap = ap async def initialize(self): - self.banlist_mgr = await config_util.load_python_module_config( + self.banlist_mgr = await cfg_mgr.load_python_module_config( "banlist.py", "res/templates/banlist-template.py" ) diff --git a/pkg/qqbot/cntfilter/__init__.py b/pkg/qqbot/cntfilter/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/qqbot/cntfilter/cntfilter.py b/pkg/qqbot/cntfilter/cntfilter.py new file mode 100644 index 00000000..2d690b57 --- /dev/null +++ b/pkg/qqbot/cntfilter/cntfilter.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +from ...boot import app +from . import entities +from . import filter +from .filters import cntignore, banwords, baiduexamine + + +class ContentFilterManager: + + ao: app.Application + + filter_chain: list[filter.ContentFilter] + + def __init__(self, ap: app.Application) -> None: + self.ap = ap + self.filter_chain = [] + + async def initialize(self): + self.filter_chain.append(cntignore.ContentIgnore(self.ap)) + + if self.ap.cfg_mgr.data['sensitive_word_filter']: + self.filter_chain.append(banwords.BanWordFilter(self.ap)) + + if self.ap.cfg_mgr.data['baidu_check']: + self.filter_chain.append(baiduexamine.BaiduCloudExamine(self.ap)) + + for filter in self.filter_chain: + await filter.initialize() + + async def pre_process(self, message: str) -> entities.FilterManagerResult: + """请求llm前处理消息 + 只要有一个不通过就不放行,只放行 PASS 的消息 + """ + if not self.ap.cfg_mgr.data['income_msg_check']: # 不检查收到的消息,直接放行 + return entities.FilterManagerResult( + level=entities.ManagerResultLevel.CONTINUE, + replacement=message, + user_notice='', + console_notice='' + ) + else: + for filter in self.filter_chain: + if entities.EnableStage.PRE in filter.enable_stages: + result = await filter.process(message) + + if result.level in [ + entities.ResultLevel.BLOCK, + entities.ResultLevel.MASKED + ]: + return entities.FilterManagerResult( + level=entities.ManagerResultLevel.INTERRUPT, + replacement=result.replacement, + user_notice=result.user_notice, + console_notice=result.console_notice + ) + elif result.level == entities.ResultLevel.PASS: + message = result.replacement + + return entities.FilterManagerResult( + level=entities.ManagerResultLevel.CONTINUE, + replacement=message, + user_notice='', + console_notice='' + ) + + async def post_process(self, message: str) -> entities.FilterManagerResult: + """请求llm后处理响应 + 只要是 PASS 或者 MASKED 的就通过此 filter,将其 replacement 设置为message,进入下一个 filter + """ + for filter in self.filter_chain: + if entities.EnableStage.POST in filter.enable_stages: + result = await filter.process(message) + + if result.level == entities.ResultLevel.BLOCK: + return entities.FilterManagerResult( + level=entities.ManagerResultLevel.INTERRUPT, + replacement=result.replacement, + user_notice=result.user_notice, + console_notice=result.console_notice + ) + elif result.level in [ + entities.ResultLevel.PASS, + entities.ResultLevel.MASKED + ]: + message = result.replacement + + return entities.FilterManagerResult( + level=entities.ManagerResultLevel.CONTINUE, + replacement=message, + user_notice='', + console_notice='' + ) diff --git a/pkg/qqbot/cntfilter/entities.py b/pkg/qqbot/cntfilter/entities.py new file mode 100644 index 00000000..7ab05675 --- /dev/null +++ b/pkg/qqbot/cntfilter/entities.py @@ -0,0 +1,64 @@ + +import typing +import enum + +import pydantic + + +class ResultLevel(enum.Enum): + """结果等级""" + PASS = enum.auto() + """通过""" + + WARN = enum.auto() + """警告""" + + MASKED = enum.auto() + """已掩去""" + + BLOCK = enum.auto() + """阻止""" + + +class EnableStage(enum.Enum): + """启用阶段""" + PRE = enum.auto() + """预处理""" + + POST = enum.auto() + """后处理""" + + +class FilterResult(pydantic.BaseModel): + level: ResultLevel + + replacement: str + """替换后的消息""" + + user_notice: str + """不通过时,用户提示消息""" + + console_notice: str + """不通过时,控制台提示消息""" + + +class ManagerResultLevel(enum.Enum): + """处理器结果等级""" + CONTINUE = enum.auto() + """继续""" + + INTERRUPT = enum.auto() + """中断""" + +class FilterManagerResult(pydantic.BaseModel): + + level: ManagerResultLevel + + replacement: str + """替换后的消息""" + + user_notice: str + """用户提示消息""" + + console_notice: str + """控制台提示消息""" diff --git a/pkg/qqbot/cntfilter/filter.py b/pkg/qqbot/cntfilter/filter.py new file mode 100644 index 00000000..4d4cd79f --- /dev/null +++ b/pkg/qqbot/cntfilter/filter.py @@ -0,0 +1,34 @@ +# 内容过滤器的抽象类 +from __future__ import annotations +import abc + +from ...boot import app +from . import entities + + +class ContentFilter(metaclass=abc.ABCMeta): + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + @property + def enable_stages(self): + """启用的阶段 + """ + return [ + entities.EnableStage.PRE, + entities.EnableStage.POST + ] + + async def initialize(self): + """初始化过滤器 + """ + pass + + @abc.abstractmethod + async def process(self, message: str) -> entities.FilterResult: + """处理消息 + """ + raise NotImplementedError diff --git a/pkg/qqbot/cntfilter/filters/__init__.py b/pkg/qqbot/cntfilter/filters/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/qqbot/cntfilter/filters/baiduexamine.py b/pkg/qqbot/cntfilter/filters/baiduexamine.py new file mode 100644 index 00000000..a658897b --- /dev/null +++ b/pkg/qqbot/cntfilter/filters/baiduexamine.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +import aiohttp + +from .. import entities +from .. import filter as filter_model + + +BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}" +BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token" + + +class BaiduCloudExamine(filter_model.ContentFilter): + """百度云内容审核""" + + async def _get_token(self) -> str: + async with aiohttp.ClientSession() as session: + async with session.post( + BAIDU_EXAMINE_TOKEN_URL, + params={ + "grant_type": "client_credentials", + "client_id": self.ap.cfg_mgr.data['baidu_api_key'], + "client_secret": self.ap.cfg_mgr.data['baidu_secret_key'] + } + ) as resp: + return (await resp.json())['access_token'] + + async def process(self, message: str) -> entities.FilterResult: + + async with aiohttp.ClientSession() as session: + async with session.post( + BAIDU_EXAMINE_URL.format(await self._get_token()), + headers={'Content-Type': 'application/x-www-form-urlencoded', 'Accept': 'application/json'}, + data=f"text={message}".encode('utf-8') + ) as resp: + result = await resp.json() + + if "error_code" in result: + return entities.FilterResult( + level=entities.ResultLevel.BLOCK, + replacement=message, + user_notice='', + console_notice=f"百度云判定出错,错误信息:{result['error_msg']}" + ) + else: + conclusion = result["conclusion"] + + if conclusion in ("合规"): + return entities.FilterResult( + level=entities.ResultLevel.PASS, + replacement=message, + user_notice='', + console_notice=f"百度云判定结果:{conclusion}" + ) + else: + return entities.FilterResult( + level=entities.ResultLevel.BLOCK, + replacement=message, + user_notice=self.ap.cfg_mgr.data['inappropriate_message_tips'], + console_notice=f"百度云判定结果:{conclusion}" + ) \ No newline at end of file diff --git a/pkg/qqbot/cntfilter/filters/banwords.py b/pkg/qqbot/cntfilter/filters/banwords.py new file mode 100644 index 00000000..9451c7b8 --- /dev/null +++ b/pkg/qqbot/cntfilter/filters/banwords.py @@ -0,0 +1,44 @@ +from __future__ import annotations +import re + +from .. import filter as filter_model +from .. import entities +from ....config import manager as cfg_mgr + + +class BanWordFilter(filter_model.ContentFilter): + """根据内容禁言""" + + sensitive: cfg_mgr.ConfigManager + + async def initialize(self): + self.sensitive = await cfg_mgr.load_json_config( + "sensitive.json", + "res/templates/sensitive-template.json" + ) + + async def process(self, message: str) -> entities.FilterResult: + found = False + + for word in self.sensitive.data['words']: + match = re.findall(word, message) + + if len(match) > 0: + found = True + + for i in range(len(match)): + if self.sensitive.data['mask_word'] == "": + message = message.replace( + match[i], self.sensitive.data['mask'] * len(match[i]) + ) + else: + message = message.replace( + match[i], self.sensitive.data['mask_word'] + ) + + return entities.FilterResult( + level=entities.ResultLevel.MASKED if found else entities.ResultLevel.PASS, + replacement=message, + user_notice='[bot] 消息中存在不合适的内容, 请更换措辞' if found else '', + console_notice='' + ) \ No newline at end of file diff --git a/pkg/qqbot/cntfilter/filters/cntignore.py b/pkg/qqbot/cntfilter/filters/cntignore.py new file mode 100644 index 00000000..81408868 --- /dev/null +++ b/pkg/qqbot/cntfilter/filters/cntignore.py @@ -0,0 +1,43 @@ +from __future__ import annotations +import re + +from .. import entities +from .. import filter as filter_model + + +class ContentIgnore(filter_model.ContentFilter): + """根据内容忽略消息""" + + @property + def enable_stages(self): + return [ + entities.EnableStage.PRE, + ] + + async def process(self, message: str) -> entities.FilterResult: + if 'prefix' in self.ap.cfg_mgr.data['ignore_rules']: + for rule in self.ap.cfg_mgr.data['ignore_rules']['prefix']: + if message.startswith(rule): + return entities.FilterResult( + level=entities.ResultLevel.BLOCK, + replacement='', + user_notice='', + console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息' + ) + + if 'regexp' in self.ap.cfg_mgr.data['ignore_rules']: + for rule in self.ap.cfg_mgr.data['ignore_rules']['regexp']: + if re.search(rule, message): + return entities.FilterResult( + level=entities.ResultLevel.BLOCK, + replacement='', + user_notice='', + console_notice='根据 ignore_rules 中的 regexp 规则,忽略消息' + ) + + return entities.FilterResult( + level=entities.ResultLevel.PASS, + replacement=message, + user_notice='', + console_notice='' + ) \ No newline at end of file diff --git a/pkg/qqbot/filter.py b/pkg/qqbot/filter.py deleted file mode 100644 index c3a58093..00000000 --- a/pkg/qqbot/filter.py +++ /dev/null @@ -1,87 +0,0 @@ -# 敏感词过滤模块 -import re -import requests -import json -import logging - -from ..utils import context - - -class ReplyFilter: - sensitive_words = [] - mask = "*" - mask_word = "" - - # 默认值( 兼容性考虑 ) - baidu_check = False - baidu_api_key = "" - baidu_secret_key = "" - inappropriate_message_tips = "[百度云]请珍惜机器人,当前返回内容不合规" - - def __init__(self, sensitive_words: list, mask: str = "*", mask_word: str = ""): - self.sensitive_words = sensitive_words - self.mask = mask - self.mask_word = mask_word - - config = context.get_config_manager().data - - self.baidu_check = config['baidu_check'] - self.baidu_api_key = config['baidu_api_key'] - self.baidu_secret_key = config['baidu_secret_key'] - self.inappropriate_message_tips = config['inappropriate_message_tips'] - - def is_illegal(self, message: str) -> bool: - processed = self.process(message) - if processed != message: - return True - return False - - def process(self, message: str) -> str: - - # 本地关键词屏蔽 - for word in self.sensitive_words: - match = re.findall(word, message) - if len(match) > 0: - for i in range(len(match)): - if self.mask_word == "": - message = message.replace(match[i], self.mask * len(match[i])) - else: - message = message.replace(match[i], self.mask_word) - - # 百度云审核 - if self.baidu_check: - - # 百度云审核URL - baidu_url = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token=" + \ - str(requests.post("https://aip.baidubce.com/oauth/2.0/token", - params={"grant_type": "client_credentials", - "client_id": self.baidu_api_key, - "client_secret": self.baidu_secret_key}).json().get("access_token")) - - # 百度云审核 - payload = "text=" + message - logging.info("向百度云发送:" + payload) - headers = {'Content-Type': 'application/x-www-form-urlencoded', 'Accept': 'application/json'} - - if isinstance(payload, str): - payload = payload.encode('utf-8') - - response = requests.request("POST", baidu_url, headers=headers, data=payload) - response_dict = json.loads(response.text) - - if "error_code" in response_dict: - error_msg = response_dict.get("error_msg") - logging.warning(f"百度云判定出错,错误信息:{error_msg}") - conclusion = f"百度云判定出错,错误信息:{error_msg}\n以下是原消息:{message}" - else: - conclusion = response_dict["conclusion"] - if conclusion in ("合规"): - logging.info(f"百度云判定结果:{conclusion}") - return message - else: - logging.warning(f"百度云判定结果:{conclusion}") - conclusion = self.inappropriate_message_tips - # 返回百度云审核结果 - return conclusion - - return message diff --git a/pkg/qqbot/ignore.py b/pkg/qqbot/ignore.py deleted file mode 100644 index e1adc777..00000000 --- a/pkg/qqbot/ignore.py +++ /dev/null @@ -1,18 +0,0 @@ -import re - -from ..utils import context - - -def ignore(msg: str) -> bool: - """检查消息是否应该被忽略""" - config = context.get_config_manager().data - - if 'prefix' in config['ignore_rules']: - for rule in config['ignore_rules']['prefix']: - if msg.startswith(rule): - return True - - if 'regexp' in config['ignore_rules']: - for rule in config['ignore_rules']['regexp']: - if re.search(rule, msg): - return True diff --git a/pkg/qqbot/manager.py b/pkg/qqbot/manager.py index bfe86b9c..a801229d 100644 --- a/pkg/qqbot/manager.py +++ b/pkg/qqbot/manager.py @@ -12,7 +12,6 @@ import func_timeout from ..openai import session as openai_session -from ..qqbot import filter as qqbot_filter from ..qqbot import process as processor from ..utils import context from ..plugin import host as plugin_host @@ -21,6 +20,7 @@ import tips as tips_custom from ..qqbot import adapter as msadapter from . import resprule from .bansess import bansess +from .cntfilter import cntfilter from ..boot import app @@ -33,8 +33,6 @@ class QQBotManager: bot_account_id: int = 0 - reply_filter = None - enable_banlist = False enable_private = True @@ -47,18 +45,21 @@ class QQBotManager: ap: app.Application = None bansess_mgr: bansess.SessionBanManager = None + cntfilter_mgr: cntfilter.ContentFilterManager = None def __init__(self, first_time_init=True, ap: app.Application = None): config = context.get_config_manager().data self.ap = ap self.bansess_mgr = bansess.SessionBanManager(ap) + self.cntfilter_mgr = cntfilter.ContentFilterManager(ap) self.timeout = config['process_message_timeout'] self.retry = config['retry_times'] async def initialize(self): await self.bansess_mgr.initialize() + await self.cntfilter_mgr.initialize() config = context.get_config_manager().data @@ -174,20 +175,6 @@ class QQBotManager: self.unsubscribe_all = unsubscribe_all - config = context.get_config_manager().data - if os.path.exists("sensitive.json") \ - and config['sensitive_word_filter'] is not None \ - and config['sensitive_word_filter']: - with open("sensitive.json", "r", encoding="utf-8") as f: - sensitive_json = json.load(f) - self.reply_filter = qqbot_filter.ReplyFilter( - sensitive_words=sensitive_json['words'], - mask=sensitive_json['mask'] if 'mask' in sensitive_json else '*', - mask_word=sensitive_json['mask_word'] if 'mask_word' in sensitive_json else '' - ) - else: - self.reply_filter = qqbot_filter.ReplyFilter([]) - async def send(self, event, msg, check_quote=True, check_at_sender=True): config = context.get_config_manager().data diff --git a/pkg/qqbot/process.py b/pkg/qqbot/process.py index aa02315f..fedaddbd 100644 --- a/pkg/qqbot/process.py +++ b/pkg/qqbot/process.py @@ -14,10 +14,10 @@ from ..utils import context from ..plugin import host as plugin_host from ..plugin import models as plugin_models -from ..qqbot import ignore from ..qqbot import blob import tips as tips_custom from ..boot import app +from .cntfilter import entities processing = [] @@ -32,7 +32,7 @@ def is_admin(qq: int) -> bool: async def process_message(launcher_type: str, launcher_id: int, text_message: str, message_chain: mirai.MessageChain, - sender_id: int) -> mirai.MessageChain: + sender_id: int) -> list: global processing mgr = context.get_qqbot_manager() @@ -40,14 +40,10 @@ async def process_message(launcher_type: str, launcher_id: int, text_message: st reply = [] session_name = "{}_{}".format(launcher_type, launcher_id) - if ignore.ignore(text_message): - logging.info("根据忽略规则忽略消息: {}".format(text_message)) - return [] - config = context.get_config_manager().data if not config['wait_last_done'] and session_name in processing: - return mirai.MessageChain([mirai.Plain(tips_custom.message_drop_tip)]) + return [mirai.Plain(tips_custom.message_drop_tip)] # 检查是否被禁言 if launcher_type == 'group': @@ -56,9 +52,14 @@ async def process_message(launcher_type: str, launcher_id: int, text_message: st logging.info("机器人被禁言,跳过消息处理(group_{})".format(launcher_id)) return reply - if config['income_msg_check']: - if mgr.reply_filter.is_illegal(text_message): - return mirai.MessageChain(mirai.Plain("[bot] 消息中存在不合适的内容, 请更换措辞")) + cntfilter_res = await mgr.cntfilter_mgr.pre_process(text_message) + if cntfilter_res.level == entities.ManagerResultLevel.INTERRUPT: + if cntfilter_res.console_notice: + mgr.ap.logger.info(cntfilter_res.console_notice) + if cntfilter_res.user_notice: + return [mirai.Plain(cntfilter_res.user_notice)] + else: + return [] openai_session.get_session(session_name).acquire_response_lock() @@ -147,7 +148,16 @@ async def process_message(launcher_type: str, launcher_id: int, text_message: st reply[0][:min(100, len(reply[0]))] + ( "..." if len(reply[0]) > 100 else ""))) if msg_type == 'message': - reply = [mgr.reply_filter.process(reply[0])] + cntfilter_res = await mgr.cntfilter_mgr.post_process(reply[0]) + if cntfilter_res.level == entities.ManagerResultLevel.INTERRUPT: + if cntfilter_res.console_notice: + mgr.ap.logger.info(cntfilter_res.console_notice) + if cntfilter_res.user_notice: + return [mirai.Plain(cntfilter_res.user_notice)] + else: + return [] + else: + reply = [cntfilter_res.replacement] reply = blob.check_text(reply[0]) else: From ea9ae854283b8851a2c30013ecc86e9b92c42456 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Thu, 25 Jan 2024 17:05:09 +0800 Subject: [PATCH 08/10] =?UTF-8?q?refactor:=20=E7=8B=AC=E7=AB=8B=E9=95=BF?= =?UTF-8?q?=E6=B6=88=E6=81=AF=E5=A4=84=E7=90=86=E4=B8=BAlongtext=E5=8C=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/boot/boot.py | 5 - pkg/qqbot/blob.py | 100 ----------- pkg/qqbot/longtext/__init__.py | 0 pkg/qqbot/longtext/longtext.py | 56 ++++++ pkg/qqbot/longtext/strategies/__init__.py | 0 pkg/qqbot/longtext/strategies/forward.py | 62 +++++++ pkg/qqbot/longtext/strategies/image.py | 197 ++++++++++++++++++++ pkg/qqbot/longtext/strategy.py | 22 +++ pkg/qqbot/manager.py | 5 + pkg/qqbot/process.py | 5 +- pkg/utils/text2img.py | 208 ---------------------- 11 files changed, 344 insertions(+), 316 deletions(-) delete mode 100644 pkg/qqbot/blob.py create mode 100644 pkg/qqbot/longtext/__init__.py create mode 100644 pkg/qqbot/longtext/longtext.py create mode 100644 pkg/qqbot/longtext/strategies/__init__.py create mode 100644 pkg/qqbot/longtext/strategies/forward.py create mode 100644 pkg/qqbot/longtext/strategies/image.py create mode 100644 pkg/qqbot/longtext/strategy.py delete mode 100644 pkg/utils/text2img.py diff --git a/pkg/boot/boot.py b/pkg/boot/boot.py index a2b2d7c7..b7cc0f38 100644 --- a/pkg/boot/boot.py +++ b/pkg/boot/boot.py @@ -71,11 +71,6 @@ async def make_app() -> app.Application: "tips-custom-template.py" ) - # 初始化文字转图片 - from pkg.utils import text2img - # TODO make it async - text2img.initialize() - # 检查管理员QQ号 if cfg_mgr.data['admin_qq'] == 0: qcg_logger.warning("未设置管理员QQ号,将无法使用管理员命令,请在 config.py 中修改 admin_qq") diff --git a/pkg/qqbot/blob.py b/pkg/qqbot/blob.py deleted file mode 100644 index d8373cd8..00000000 --- a/pkg/qqbot/blob.py +++ /dev/null @@ -1,100 +0,0 @@ -# 长消息处理相关 -import os -import time -import base64 -import typing - -from mirai.models.message import MessageComponent, MessageChain, Image -from mirai.models.message import ForwardMessageNode -from mirai.models.base import MiraiBaseModel - -from ..utils import text2img -from ..utils import context - - -class ForwardMessageDiaplay(MiraiBaseModel): - title: str = "群聊的聊天记录" - brief: str = "[聊天记录]" - source: str = "聊天记录" - preview: typing.List[str] = [] - summary: str = "查看x条转发消息" - - -class Forward(MessageComponent): - """合并转发。""" - type: str = "Forward" - """消息组件类型。""" - display: ForwardMessageDiaplay - """显示信息""" - node_list: typing.List[ForwardMessageNode] - """转发消息节点列表。""" - def __init__(self, *args, **kwargs): - if len(args) == 1: - self.node_list = args[0] - super().__init__(**kwargs) - super().__init__(*args, **kwargs) - - def __str__(self): - return '[聊天记录]' - - -def text_to_image(text: str) -> MessageComponent: - """将文本转换成图片""" - # 检查temp文件夹是否存在 - if not os.path.exists('temp'): - os.mkdir('temp') - img_path = text2img.text_to_image(text_str=text, save_as='temp/{}.png'.format(int(time.time()))) - - compressed_path, size = text2img.compress_image(img_path, outfile="temp/{}_compressed.png".format(int(time.time()))) - # 读取图片,转换成base64 - with open(compressed_path, 'rb') as f: - img = f.read() - - b64 = base64.b64encode(img) - - # 删除图片 - os.remove(img_path) - - # 判断compressed_path是否存在 - if os.path.exists(compressed_path): - os.remove(compressed_path) - # 返回图片 - return Image(base64=b64.decode('utf-8')) - - -def check_text(text: str) -> list: - """检查文本是否为长消息,并转换成该使用的消息链组件""" - - config = context.get_config_manager().data - - if len(text) > config['blob_message_threshold']: - - # logging.info("长消息: {}".format(text)) - if config['blob_message_strategy'] == 'image': - # 转换成图片 - return [text_to_image(text)] - elif config['blob_message_strategy'] == 'forward': - - # 包装转发消息 - display = ForwardMessageDiaplay( - title='群聊的聊天记录', - brief='[聊天记录]', - source='聊天记录', - preview=["bot: "+text], - summary="查看1条转发消息" - ) - - node = ForwardMessageNode( - sender_id=config['mirai_http_api_config']['qq'], - sender_name='bot', - message_chain=MessageChain([text]) - ) - - forward = Forward( - display=display, - node_list=[node] - ) - - return [forward] - else: - return [text] \ No newline at end of file diff --git a/pkg/qqbot/longtext/__init__.py b/pkg/qqbot/longtext/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/qqbot/longtext/longtext.py b/pkg/qqbot/longtext/longtext.py new file mode 100644 index 00000000..21267880 --- /dev/null +++ b/pkg/qqbot/longtext/longtext.py @@ -0,0 +1,56 @@ +from __future__ import annotations +import os +import traceback + +from PIL import Image, ImageDraw, ImageFont +from mirai.models.message import MessageComponent, Plain + +from ...boot import app +from . import strategy +from .strategies import image, forward + + +class LongTextProcessor: + + ap: app.Application + + strategy_impl: strategy.LongTextStrategy + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + config = self.ap.cfg_mgr.data + if self.ap.cfg_mgr.data['blob_message_strategy'] == 'image': + use_font = config['font_path'] + try: + # 检查是否存在 + if not os.path.exists(use_font): + # 若是windows系统,使用微软雅黑 + if os.name == "nt": + use_font = "C:/Windows/Fonts/msyh.ttc" + if not os.path.exists(use_font): + self.ap.logger.warn("未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。") + config['blob_message_strategy'] = "forward" + else: + self.ap.logger.info("使用Windows自带字体:" + use_font) + self.ap.cfg_mgr.data['font_path'] = use_font + else: + self.ap.logger.warn("未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。") + self.ap.cfg_mgr.data['blob_message_strategy'] = "forward" + except: + traceback.print_exc() + self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。".format(use_font)) + self.ap.cfg_mgr.data['blob_message_strategy'] = "forward" + + if self.ap.cfg_mgr.data['blob_message_strategy'] == 'image': + self.strategy_impl = image.Text2ImageStrategy(self.ap) + elif self.ap.cfg_mgr.data['blob_message_strategy'] == 'forward': + self.strategy_impl = forward.ForwardComponentStrategy(self.ap) + await self.strategy_impl.initialize() + + async def check_and_process(self, message: str) -> list[MessageComponent]: + if len(message) > self.ap.cfg_mgr.data['blob_message_threshold']: + return await self.strategy_impl.process(message) + else: + return [Plain(message)] \ No newline at end of file diff --git a/pkg/qqbot/longtext/strategies/__init__.py b/pkg/qqbot/longtext/strategies/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/qqbot/longtext/strategies/forward.py b/pkg/qqbot/longtext/strategies/forward.py new file mode 100644 index 00000000..d1b5c36c --- /dev/null +++ b/pkg/qqbot/longtext/strategies/forward.py @@ -0,0 +1,62 @@ +# 转发消息组件 +from __future__ import annotations +import typing + +from mirai.models import MessageChain +from mirai.models.message import MessageComponent, ForwardMessageNode +from mirai.models.base import MiraiBaseModel + +from .. import strategy as strategy_model + + +class ForwardMessageDiaplay(MiraiBaseModel): + title: str = "群聊的聊天记录" + brief: str = "[聊天记录]" + source: str = "聊天记录" + preview: typing.List[str] = [] + summary: str = "查看x条转发消息" + + +class Forward(MessageComponent): + """合并转发。""" + type: str = "Forward" + """消息组件类型。""" + display: ForwardMessageDiaplay + """显示信息""" + node_list: typing.List[ForwardMessageNode] + """转发消息节点列表。""" + def __init__(self, *args, **kwargs): + if len(args) == 1: + self.node_list = args[0] + super().__init__(**kwargs) + super().__init__(*args, **kwargs) + + def __str__(self): + return '[聊天记录]' + + +class ForwardComponentStrategy(strategy_model.LongTextStrategy): + + async def process(self, message: str) -> list[MessageComponent]: + display = ForwardMessageDiaplay( + title="群聊的聊天记录", + brief="[聊天记录]", + source="聊天记录", + preview=["QQ用户: "+message], + summary="查看1条转发消息" + ) + + node_list = [ + ForwardMessageNode( + sender_id=self.ap.im_mgr.bot_account_id, + sender_name='QQ用户', + message_chain=MessageChain([message]) + ) + ] + + forward = Forward( + display=display, + node_list=node_list + ) + + return [forward] diff --git a/pkg/qqbot/longtext/strategies/image.py b/pkg/qqbot/longtext/strategies/image.py new file mode 100644 index 00000000..4f789098 --- /dev/null +++ b/pkg/qqbot/longtext/strategies/image.py @@ -0,0 +1,197 @@ +from __future__ import annotations + +import typing +import os +import base64 +import time +import re + +from PIL import Image, ImageDraw, ImageFont + +from mirai.models import MessageChain, Image as ImageComponent +from mirai.models.message import MessageComponent + +from .. import strategy as strategy_model + + +class Text2ImageStrategy(strategy_model.LongTextStrategy): + + text_render_font: ImageFont.FreeTypeFont + + async def initialize(self): + self.text_render_font = ImageFont.truetype(self.ap.cfg_mgr.data['font_path'], 32, encoding="utf-8") + + async def process(self, message: str) -> list[MessageComponent]: + img_path = self.text_to_image( + text_str=message, + save_as='temp/{}.png'.format(int(time.time())) + ) + + compressed_path, size = self.compress_image( + img_path, + outfile="temp/{}_compressed.png".format(int(time.time())) + ) + + with open(compressed_path, 'rb') as f: + img = f.read() + + b64 = base64.b64encode(img) + + # 删除图片 + os.remove(img_path) + + if os.path.exists(compressed_path): + os.remove(compressed_path) + + return [ + ImageComponent( + base64=b64.decode('utf-8'), + ) + ] + + def indexNumber(self, path=''): + """ + 查找字符串中数字所在串中的位置 + :param path:目标字符串 + :return:: : [['1', 16], ['2', 35], ['1', 51]] + """ + kv = [] + nums = [] + beforeDatas = re.findall('[\d]+', path) + for num in beforeDatas: + indexV = [] + times = path.count(num) + if times > 1: + if num not in nums: + indexs = re.finditer(num, path) + for index in indexs: + iV = [] + i = index.span()[0] + iV.append(num) + iV.append(i) + kv.append(iV) + nums.append(num) + else: + index = path.find(num) + indexV.append(num) + indexV.append(index) + kv.append(indexV) + # 根据数字位置排序 + indexSort = [] + resultIndex = [] + for vi in kv: + indexSort.append(vi[1]) + indexSort.sort() + for i in indexSort: + for v in kv: + if i == v[1]: + resultIndex.append(v) + return resultIndex + + + def get_size(self, file): + # 获取文件大小:KB + size = os.path.getsize(file) + return size / 1024 + + + def get_outfile(self, infile, outfile): + if outfile: + return outfile + dir, suffix = os.path.splitext(infile) + outfile = '{}-out{}'.format(dir, suffix) + return outfile + + + def compress_image(self, infile, outfile='', kb=100, step=20, quality=90): + """不改变图片尺寸压缩到指定大小 + :param infile: 压缩源文件 + :param outfile: 压缩文件保存地址 + :param mb: 压缩目标,KB + :param step: 每次调整的压缩比率 + :param quality: 初始压缩比率 + :return: 压缩文件地址,压缩文件大小 + """ + o_size = self.get_size(infile) + if o_size <= kb: + return infile, o_size + outfile = self.get_outfile(infile, outfile) + while o_size > kb: + im = Image.open(infile) + im.save(outfile, quality=quality) + if quality - step < 0: + break + quality -= step + o_size = self.get_size(outfile) + return outfile, self.get_size(outfile) + + + def text_to_image(self, text_str: str, save_as="temp.png", width=800): + + text_str = text_str.replace("\t", " ") + + # 分行 + lines = text_str.split('\n') + + # 计算并分割 + final_lines = [] + + text_width = width-80 + + self.ap.logger.debug("lines: {}, text_width: {}".format(lines, text_width)) + for line in lines: + # 如果长了就分割 + line_width = self.text_render_font.getlength(line) + self.ap.logger.debug("line_width: {}".format(line_width)) + if line_width < text_width: + final_lines.append(line) + continue + else: + rest_text = line + while True: + # 分割最前面的一行 + point = int(len(rest_text) * (text_width / line_width)) + + # 检查断点是否在数字中间 + numbers = self.indexNumber(rest_text) + + for number in numbers: + if number[1] < point < number[1] + len(number[0]) and number[1] != 0: + point = number[1] + break + + final_lines.append(rest_text[:point]) + rest_text = rest_text[point:] + line_width = self.text_render_font.getlength(rest_text) + if line_width < text_width: + final_lines.append(rest_text) + break + else: + continue + # 准备画布 + img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255)) + draw = ImageDraw.Draw(img, mode='RGBA') + + self.ap.logger.debug("正在绘制图片...") + # 绘制正文 + line_number = 0 + offset_x = 20 + offset_y = 30 + for final_line in final_lines: + draw.text((offset_x, offset_y + 35 * line_number), final_line, fill=(0, 0, 0), font=self.text_render_font) + # 遍历此行,检查是否有emoji + idx_in_line = 0 + for ch in final_line: + # 检查字符占位宽 + char_code = ord(ch) + if char_code >= 127: + idx_in_line += 1 + else: + idx_in_line += 0.5 + + line_number += 1 + + self.ap.logger.debug("正在保存图片...") + img.save(save_as) + + return save_as diff --git a/pkg/qqbot/longtext/strategy.py b/pkg/qqbot/longtext/strategy.py new file mode 100644 index 00000000..ef4cc1a5 --- /dev/null +++ b/pkg/qqbot/longtext/strategy.py @@ -0,0 +1,22 @@ +from __future__ import annotations +import abc +import typing + +import mirai +from mirai.models.message import MessageComponent + +from ...boot import app + + +class LongTextStrategy(metaclass=abc.ABCMeta): + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + pass + + @abc.abstractmethod + async def process(self, message: str) -> list[MessageComponent]: + return [] diff --git a/pkg/qqbot/manager.py b/pkg/qqbot/manager.py index a801229d..7c295f9d 100644 --- a/pkg/qqbot/manager.py +++ b/pkg/qqbot/manager.py @@ -21,6 +21,7 @@ from ..qqbot import adapter as msadapter from . import resprule from .bansess import bansess from .cntfilter import cntfilter +from .longtext import longtext from ..boot import app @@ -46,6 +47,7 @@ class QQBotManager: bansess_mgr: bansess.SessionBanManager = None cntfilter_mgr: cntfilter.ContentFilterManager = None + longtext_pcs: longtext.LongTextProcessor = None def __init__(self, first_time_init=True, ap: app.Application = None): config = context.get_config_manager().data @@ -53,6 +55,7 @@ class QQBotManager: self.ap = ap self.bansess_mgr = bansess.SessionBanManager(ap) self.cntfilter_mgr = cntfilter.ContentFilterManager(ap) + self.longtext_pcs = longtext.LongTextProcessor(ap) self.timeout = config['process_message_timeout'] self.retry = config['retry_times'] @@ -60,6 +63,7 @@ class QQBotManager: async def initialize(self): await self.bansess_mgr.initialize() await self.cntfilter_mgr.initialize() + await self.longtext_pcs.initialize() config = context.get_config_manager().data @@ -149,6 +153,7 @@ class QQBotManager: await self.on_group_message(event) asyncio.create_task(group_message_handler(event)) + self.adapter.register_listener( GroupMessage, on_group_message diff --git a/pkg/qqbot/process.py b/pkg/qqbot/process.py index fedaddbd..f6379c71 100644 --- a/pkg/qqbot/process.py +++ b/pkg/qqbot/process.py @@ -14,7 +14,6 @@ from ..utils import context from ..plugin import host as plugin_host from ..plugin import models as plugin_models -from ..qqbot import blob import tips as tips_custom from ..boot import app from .cntfilter import entities @@ -158,8 +157,8 @@ async def process_message(launcher_type: str, launcher_id: int, text_message: st return [] else: reply = [cntfilter_res.replacement] - - reply = blob.check_text(reply[0]) + + reply = await mgr.longtext_pcs.check_and_process(reply[0]) else: logging.info("回复[{}]消息".format(session_name)) diff --git a/pkg/utils/text2img.py b/pkg/utils/text2img.py deleted file mode 100644 index 5be723ed..00000000 --- a/pkg/utils/text2img.py +++ /dev/null @@ -1,208 +0,0 @@ -import logging -import re -import os -import traceback - -from PIL import Image, ImageDraw, ImageFont - -from ..utils import context - - -text_render_font: ImageFont = None - -def initialize(): - global text_render_font - logging.debug("初始化文字转图片模块...") - config = context.get_config_manager().data - - if config['blob_message_strategy'] == "image": # 仅在启用了image时才加载字体 - use_font = config['font_path'] - try: - - # 检查是否存在 - if not os.path.exists(use_font): - # 若是windows系统,使用微软雅黑 - if os.name == "nt": - use_font = "C:/Windows/Fonts/msyh.ttc" - if not os.path.exists(use_font): - logging.warn("未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。") - config['blob_message_strategy'] = "forward" - else: - logging.info("使用Windows自带字体:" + use_font) - text_render_font = ImageFont.truetype(use_font, 32, encoding="utf-8") - else: - logging.warn("未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。") - config['blob_message_strategy'] = "forward" - else: - text_render_font = ImageFont.truetype(use_font, 32, encoding="utf-8") - except: - traceback.print_exc() - logging.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。".format(use_font)) - config['blob_message_strategy'] = "forward" - - logging.debug("字体文件加载完成。") - - -def indexNumber(path=''): - """ - 查找字符串中数字所在串中的位置 - :param path:目标字符串 - :return:: : [['1', 16], ['2', 35], ['1', 51]] - """ - kv = [] - nums = [] - beforeDatas = re.findall('[\d]+', path) - for num in beforeDatas: - indexV = [] - times = path.count(num) - if times > 1: - if num not in nums: - indexs = re.finditer(num, path) - for index in indexs: - iV = [] - i = index.span()[0] - iV.append(num) - iV.append(i) - kv.append(iV) - nums.append(num) - else: - index = path.find(num) - indexV.append(num) - indexV.append(index) - kv.append(indexV) - # 根据数字位置排序 - indexSort = [] - resultIndex = [] - for vi in kv: - indexSort.append(vi[1]) - indexSort.sort() - for i in indexSort: - for v in kv: - if i == v[1]: - resultIndex.append(v) - return resultIndex - - -def get_size(file): - # 获取文件大小:KB - size = os.path.getsize(file) - return size / 1024 - - -def get_outfile(infile, outfile): - if outfile: - return outfile - dir, suffix = os.path.splitext(infile) - outfile = '{}-out{}'.format(dir, suffix) - return outfile - - -def compress_image(infile, outfile='', kb=100, step=20, quality=90): - """不改变图片尺寸压缩到指定大小 - :param infile: 压缩源文件 - :param outfile: 压缩文件保存地址 - :param mb: 压缩目标,KB - :param step: 每次调整的压缩比率 - :param quality: 初始压缩比率 - :return: 压缩文件地址,压缩文件大小 - """ - o_size = get_size(infile) - if o_size <= kb: - return infile, o_size - outfile = get_outfile(infile, outfile) - while o_size > kb: - im = Image.open(infile) - im.save(outfile, quality=quality) - if quality - step < 0: - break - quality -= step - o_size = get_size(outfile) - return outfile, get_size(outfile) - - -def text_to_image(text_str: str, save_as="temp.png", width=800): - global text_render_font - - logging.debug("正在将文本转换为图片...") - - text_str = text_str.replace("\t", " ") - - # 分行 - lines = text_str.split('\n') - - # 计算并分割 - final_lines = [] - - text_width = width-80 - - logging.debug("lines: {}, text_width: {}".format(lines, text_width)) - for line in lines: - logging.debug(type(text_render_font)) - # 如果长了就分割 - line_width = text_render_font.getlength(line) - logging.debug("line_width: {}".format(line_width)) - if line_width < text_width: - final_lines.append(line) - continue - else: - rest_text = line - while True: - # 分割最前面的一行 - point = int(len(rest_text) * (text_width / line_width)) - - # 检查断点是否在数字中间 - numbers = indexNumber(rest_text) - - for number in numbers: - if number[1] < point < number[1] + len(number[0]) and number[1] != 0: - point = number[1] - break - - final_lines.append(rest_text[:point]) - rest_text = rest_text[point:] - line_width = text_render_font.getlength(rest_text) - if line_width < text_width: - final_lines.append(rest_text) - break - else: - continue - # 准备画布 - img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255)) - draw = ImageDraw.Draw(img, mode='RGBA') - - logging.debug("正在绘制图片...") - # 绘制正文 - line_number = 0 - offset_x = 20 - offset_y = 30 - for final_line in final_lines: - draw.text((offset_x, offset_y + 35 * line_number), final_line, fill=(0, 0, 0), font=text_render_font) - # 遍历此行,检查是否有emoji - idx_in_line = 0 - for ch in final_line: - # if self.is_emoji(ch): - # emoji_img_valid = ensure_emoji(hex(ord(ch))[2:]) - # if emoji_img_valid: # emoji图像可用,绘制到指定位置 - # emoji_image = Image.open("emojis/{}.png".format(hex(ord(ch))[2:]), mode='r').convert('RGBA') - # emoji_image = emoji_image.resize((32, 32)) - - # x, y = emoji_image.size - - # final_emoji_img = Image.new('RGBA', emoji_image.size, (255, 255, 255)) - # final_emoji_img.paste(emoji_image, (0, 0, x, y), emoji_image) - - # img.paste(final_emoji_img, box=(int(offset_x + idx_in_line * 32), offset_y + 35 * line_number)) - - # 检查字符占位宽 - char_code = ord(ch) - if char_code >= 127: - idx_in_line += 1 - else: - idx_in_line += 0.5 - - line_number += 1 - - logging.debug("正在保存图片...") - img.save(save_as) - - return save_as From f4ead5ec5cae8e22b556fabf3f10dc656f24079c Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Thu, 25 Jan 2024 18:07:28 +0800 Subject: [PATCH 09/10] =?UTF-8?q?refactor:=20=E7=8B=AC=E7=AB=8Bresprule?= =?UTF-8?q?=E4=B8=BA=E5=8D=95=E7=8B=AC=E7=9A=84=E5=8C=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/qqbot/manager.py | 74 ++++++++++------------------ pkg/qqbot/resprule.py | 67 ------------------------- pkg/qqbot/resprule/__init__.py | 0 pkg/qqbot/resprule/entities.py | 10 ++++ pkg/qqbot/resprule/resprule.py | 58 ++++++++++++++++++++++ pkg/qqbot/resprule/rule.py | 31 ++++++++++++ pkg/qqbot/resprule/rules/__init__.py | 0 pkg/qqbot/resprule/rules/atbot.py | 28 +++++++++++ pkg/qqbot/resprule/rules/prefix.py | 29 +++++++++++ pkg/qqbot/resprule/rules/random.py | 22 +++++++++ pkg/qqbot/resprule/rules/regexp.py | 31 ++++++++++++ 11 files changed, 236 insertions(+), 114 deletions(-) delete mode 100644 pkg/qqbot/resprule.py create mode 100644 pkg/qqbot/resprule/__init__.py create mode 100644 pkg/qqbot/resprule/entities.py create mode 100644 pkg/qqbot/resprule/resprule.py create mode 100644 pkg/qqbot/resprule/rule.py create mode 100644 pkg/qqbot/resprule/rules/__init__.py create mode 100644 pkg/qqbot/resprule/rules/atbot.py create mode 100644 pkg/qqbot/resprule/rules/prefix.py create mode 100644 pkg/qqbot/resprule/rules/random.py create mode 100644 pkg/qqbot/resprule/rules/regexp.py diff --git a/pkg/qqbot/manager.py b/pkg/qqbot/manager.py index 7c295f9d..a973ab6d 100644 --- a/pkg/qqbot/manager.py +++ b/pkg/qqbot/manager.py @@ -18,7 +18,7 @@ from ..plugin import host as plugin_host from ..plugin import models as plugin_models import tips as tips_custom from ..qqbot import adapter as msadapter -from . import resprule +from .resprule import resprule from .bansess import bansess from .cntfilter import cntfilter from .longtext import longtext @@ -34,11 +34,6 @@ class QQBotManager: bot_account_id: int = 0 - enable_banlist = False - - enable_private = True - enable_group = True - ban_person = [] ban_group = [] @@ -48,6 +43,7 @@ class QQBotManager: bansess_mgr: bansess.SessionBanManager = None cntfilter_mgr: cntfilter.ContentFilterManager = None longtext_pcs: longtext.LongTextProcessor = None + resprule_chkr: resprule.GroupRespondRuleChecker = None def __init__(self, first_time_init=True, ap: app.Application = None): config = context.get_config_manager().data @@ -56,6 +52,7 @@ class QQBotManager: self.bansess_mgr = bansess.SessionBanManager(ap) self.cntfilter_mgr = cntfilter.ContentFilterManager(ap) self.longtext_pcs = longtext.LongTextProcessor(ap) + self.resprule_chkr = resprule.GroupRespondRuleChecker(ap) self.timeout = config['process_message_timeout'] self.retry = config['retry_times'] @@ -64,6 +61,7 @@ class QQBotManager: await self.bansess_mgr.initialize() await self.cntfilter_mgr.initialize() await self.longtext_pcs.initialize() + await self.resprule_chkr.initialize() config = context.get_config_manager().data @@ -251,17 +249,13 @@ class QQBotManager: async def on_person_message(self, event: MessageEvent): reply = '' - if not self.enable_private: - logging.debug("已在banlist.py中禁用所有私聊") - - else: - reply = await self.common_process( - launcher_type="person", - launcher_id=event.sender.id, - text_message=str(event.message_chain), - message_chain=event.message_chain, - sender_id=event.sender.id - ) + reply = await self.common_process( + launcher_type="person", + launcher_id=event.sender.id, + text_message=str(event.message_chain), + message_chain=event.message_chain, + sender_id=event.sender.id + ) if reply: await self.send(event, reply, check_quote=False, check_at_sender=False) @@ -269,39 +263,25 @@ class QQBotManager: # 群消息处理 async def on_group_message(self, event: GroupMessage): reply = '' - - if not self.enable_group: - logging.debug("已在banlist.py中禁用所有群聊") - else: - do_req = False - text = str(event.message_chain).strip() - if At(self.bot_account_id) in event.message_chain and resprule.response_at(event.group.id): - # 直接调用 - # reply = await process() - event.message_chain.remove(At(self.bot_account_id)) - text = str(event.message_chain).strip() - do_req = True - else: - check, result = resprule.check_response_rule(event.group.id, str(event.message_chain).strip()) + text = str(event.message_chain).strip() - if check: - do_req = True - text = result.strip() - # 检查是否随机响应 - elif resprule.random_responding(event.group.id): - logging.info("随机响应group_{}消息".format(event.group.id)) - # reply = await process() - do_req = True + rule_check_res = await self.resprule_chkr.check( + text, + event.message_chain, + event.group.id, + event.sender.id + ) - if do_req: - reply = await self.common_process( - launcher_type="group", - launcher_id=event.group.id, - text_message=text, - message_chain=event.message_chain, - sender_id=event.sender.id - ) + if rule_check_res.matching: + text = str(rule_check_res.replacement).strip() + reply = await self.common_process( + launcher_type="group", + launcher_id=event.group.id, + text_message=text, + message_chain=rule_check_res.replacement, + sender_id=event.sender.id + ) if reply: await self.send(event, reply) diff --git a/pkg/qqbot/resprule.py b/pkg/qqbot/resprule.py deleted file mode 100644 index 5c237024..00000000 --- a/pkg/qqbot/resprule.py +++ /dev/null @@ -1,67 +0,0 @@ -from ..utils import context - - -# 检查消息是否符合泛响应匹配机制 -def check_response_rule(group_id:int, text: str): - config = context.get_config_manager().data - - rules = config['response_rules'] - - # 检查是否有特定规则 - if 'prefix' not in config['response_rules']: - if str(group_id) in config['response_rules']: - rules = config['response_rules'][str(group_id)] - else: - rules = config['response_rules']['default'] - - # 检查前缀匹配 - if 'prefix' in rules: - for rule in rules['prefix']: - if text.startswith(rule): - return True, text.replace(rule, "", 1) - - # 检查正则表达式匹配 - if 'regexp' in rules: - for rule in rules['regexp']: - import re - match = re.match(rule, text) - if match: - return True, text - - return False, "" - - -def response_at(group_id: int): - config = context.get_config_manager().data - - use_response_rule = config['response_rules'] - - # 检查是否有特定规则 - if 'prefix' not in config['response_rules']: - if str(group_id) in config['response_rules']: - use_response_rule = config['response_rules'][str(group_id)] - else: - use_response_rule = config['response_rules']['default'] - - if 'at' not in use_response_rule: - return True - - return use_response_rule['at'] - - -def random_responding(group_id): - config = context.get_config_manager().data - - use_response_rule = config['response_rules'] - - # 检查是否有特定规则 - if 'prefix' not in config['response_rules']: - if str(group_id) in config['response_rules']: - use_response_rule = config['response_rules'][str(group_id)] - else: - use_response_rule = config['response_rules']['default'] - - if 'random_rate' in use_response_rule: - import random - return random.random() < use_response_rule['random_rate'] - return False diff --git a/pkg/qqbot/resprule/__init__.py b/pkg/qqbot/resprule/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/qqbot/resprule/entities.py b/pkg/qqbot/resprule/entities.py new file mode 100644 index 00000000..1cdd76f2 --- /dev/null +++ b/pkg/qqbot/resprule/entities.py @@ -0,0 +1,10 @@ +import pydantic +import mirai + + +class RuleJudgeResult(pydantic.BaseModel): + + matching: bool = False + + replacement: mirai.MessageChain = None + diff --git a/pkg/qqbot/resprule/resprule.py b/pkg/qqbot/resprule/resprule.py new file mode 100644 index 00000000..f0c51921 --- /dev/null +++ b/pkg/qqbot/resprule/resprule.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import mirai + +from ...boot import app +from . import entities, rule +from .rules import atbot, prefix, regexp, random + + +class GroupRespondRuleChecker: + """群组响应规则检查器 + """ + + ap: app.Application + + rule_matchers: list[rule.GroupRespondRule] + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + """初始化检查器 + """ + self.rule_matchers = [ + atbot.AtBotRule(self.ap), + prefix.PrefixRule(self.ap), + regexp.RegExpRule(self.ap), + random.RandomRespRule(self.ap), + ] + + for rule_matcher in self.rule_matchers: + await rule_matcher.initialize() + + async def check( + self, + message_text: str, + message_chain: mirai.MessageChain, + launcher_id: int, + sender_id: int, + ) -> entities.RuleJudgeResult: + """检查消息是否匹配规则 + """ + rules = self.ap.cfg_mgr.data['response_rules'] + + use_rule = rules['default'] + + if str(launcher_id) in use_rule: + use_rule = use_rule[str(launcher_id)] + + for rule_matcher in self.rule_matchers: + res = await rule_matcher.match(message_text, message_chain, use_rule) + if res.matching: + return res + + return entities.RuleJudgeResult( + matching=False, + replacement=message_chain + ) diff --git a/pkg/qqbot/resprule/rule.py b/pkg/qqbot/resprule/rule.py new file mode 100644 index 00000000..67af0204 --- /dev/null +++ b/pkg/qqbot/resprule/rule.py @@ -0,0 +1,31 @@ +from __future__ import annotations +import abc + +import mirai + +from ...boot import app +from . import entities + + +class GroupRespondRule(metaclass=abc.ABCMeta): + """群组响应规则的抽象类 + """ + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + pass + + @abc.abstractmethod + async def match( + self, + message_text: str, + message_chain: mirai.MessageChain, + rule_dict: dict + ) -> entities.RuleJudgeResult: + """判断消息是否匹配规则 + """ + raise NotImplementedError diff --git a/pkg/qqbot/resprule/rules/__init__.py b/pkg/qqbot/resprule/rules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/qqbot/resprule/rules/atbot.py b/pkg/qqbot/resprule/rules/atbot.py new file mode 100644 index 00000000..eefc4891 --- /dev/null +++ b/pkg/qqbot/resprule/rules/atbot.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +import mirai + +from .. import rule as rule_model +from .. import entities + + +class AtBotRule(rule_model.GroupRespondRule): + + async def match( + self, + message_text: str, + message_chain: mirai.MessageChain, + rule_dict: dict + ) -> entities.RuleJudgeResult: + + if message_chain.has(mirai.At(self.ap.im_mgr.bot_account_id)) and rule_dict['at']: + message_chain.remove(mirai.At(self.ap.im_mgr.bot_account_id)) + return entities.RuleJudgeResult( + matching=True, + replacement=message_chain, + ) + + return entities.RuleJudgeResult( + matching=False, + replacement = message_chain + ) diff --git a/pkg/qqbot/resprule/rules/prefix.py b/pkg/qqbot/resprule/rules/prefix.py new file mode 100644 index 00000000..31ead5ab --- /dev/null +++ b/pkg/qqbot/resprule/rules/prefix.py @@ -0,0 +1,29 @@ +import mirai + +from .. import rule as rule_model +from .. import entities + + +class PrefixRule(rule_model.GroupRespondRule): + + async def match( + self, + message_text: str, + message_chain: mirai.MessageChain, + rule_dict: dict + ) -> entities.RuleJudgeResult: + prefixes = rule_dict['prefix'] + + for prefix in prefixes: + if message_text.startswith(prefix): + return entities.RuleJudgeResult( + matching=True, + replacement=mirai.MessageChain([ + mirai.Plain(message_text[len(prefix):]) + ]), + ) + + return entities.RuleJudgeResult( + matching=False, + replacement=message_chain + ) diff --git a/pkg/qqbot/resprule/rules/random.py b/pkg/qqbot/resprule/rules/random.py new file mode 100644 index 00000000..1e8354b5 --- /dev/null +++ b/pkg/qqbot/resprule/rules/random.py @@ -0,0 +1,22 @@ +import random + +import mirai + +from .. import rule as rule_model +from .. import entities + + +class RandomRespRule(rule_model.GroupRespondRule): + + async def match( + self, + message_text: str, + message_chain: mirai.MessageChain, + rule_dict: dict + ) -> entities.RuleJudgeResult: + random_rate = rule_dict['random_rate'] + + return entities.RuleJudgeResult( + matching=random.random() < random_rate, + replacement=message_chain + ) \ No newline at end of file diff --git a/pkg/qqbot/resprule/rules/regexp.py b/pkg/qqbot/resprule/rules/regexp.py new file mode 100644 index 00000000..0d621fe4 --- /dev/null +++ b/pkg/qqbot/resprule/rules/regexp.py @@ -0,0 +1,31 @@ +import re + +import mirai + +from .. import rule as rule_model +from .. import entities + + +class RegExpRule(rule_model.GroupRespondRule): + + async def match( + self, + message_text: str, + message_chain: mirai.MessageChain, + rule_dict: dict + ) -> entities.RuleJudgeResult: + regexps = rule_dict['regexp'] + + for regexp in regexps: + match = re.match(regexp, message_text) + + if match: + return entities.RuleJudgeResult( + matching=True, + replacement=message_chain, + ) + + return entities.RuleJudgeResult( + matching=False, + replacement=message_chain + ) From b43882aad0b879b15d460a65666bfb4c1fe7ab11 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Thu, 25 Jan 2024 22:35:15 +0800 Subject: [PATCH 10/10] =?UTF-8?q?refactor:=20=E7=8B=AC=E7=AB=8Bratelimiter?= =?UTF-8?q?=E5=8C=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/qqbot/manager.py | 4 ++ pkg/qqbot/process.py | 15 +---- pkg/qqbot/ratelim/__init__.py | 0 pkg/qqbot/ratelim/algo.py | 24 ++++++++ pkg/qqbot/ratelim/algos/__init__.py | 0 pkg/qqbot/ratelim/algos/fixedwin.py | 85 +++++++++++++++++++++++++++ pkg/qqbot/ratelim/ratelim.py | 31 ++++++++++ pkg/qqbot/ratelimit.py | 89 ----------------------------- pkg/qqbot/resprule/entities.py | 1 - 9 files changed, 147 insertions(+), 102 deletions(-) create mode 100644 pkg/qqbot/ratelim/__init__.py create mode 100644 pkg/qqbot/ratelim/algo.py create mode 100644 pkg/qqbot/ratelim/algos/__init__.py create mode 100644 pkg/qqbot/ratelim/algos/fixedwin.py create mode 100644 pkg/qqbot/ratelim/ratelim.py delete mode 100644 pkg/qqbot/ratelimit.py diff --git a/pkg/qqbot/manager.py b/pkg/qqbot/manager.py index a973ab6d..5239604f 100644 --- a/pkg/qqbot/manager.py +++ b/pkg/qqbot/manager.py @@ -22,6 +22,7 @@ from .resprule import resprule from .bansess import bansess from .cntfilter import cntfilter from .longtext import longtext +from .ratelim import ratelim from ..boot import app @@ -44,6 +45,7 @@ class QQBotManager: cntfilter_mgr: cntfilter.ContentFilterManager = None longtext_pcs: longtext.LongTextProcessor = None resprule_chkr: resprule.GroupRespondRuleChecker = None + ratelimiter: ratelim.RateLimiter = None def __init__(self, first_time_init=True, ap: app.Application = None): config = context.get_config_manager().data @@ -53,6 +55,7 @@ class QQBotManager: self.cntfilter_mgr = cntfilter.ContentFilterManager(ap) self.longtext_pcs = longtext.LongTextProcessor(ap) self.resprule_chkr = resprule.GroupRespondRuleChecker(ap) + self.ratelimiter = ratelim.RateLimiter(ap) self.timeout = config['process_message_timeout'] self.retry = config['retry_times'] @@ -62,6 +65,7 @@ class QQBotManager: await self.cntfilter_mgr.initialize() await self.longtext_pcs.initialize() await self.resprule_chkr.initialize() + await self.ratelimiter.initialize() config = context.get_config_manager().data diff --git a/pkg/qqbot/process.py b/pkg/qqbot/process.py index f6379c71..e1673583 100644 --- a/pkg/qqbot/process.py +++ b/pkg/qqbot/process.py @@ -7,7 +7,6 @@ import traceback import mirai import logging -from ..qqbot import ratelimit from ..qqbot import command, message from ..openai import session as openai_session from ..utils import context @@ -103,12 +102,10 @@ async def process_message(launcher_type: str, launcher_id: int, text_message: st else: # 消息 msg_type = 'message' # 限速丢弃检查 - # print(ratelimit.__crt_minute_usage__[session_name]) - if config['rate_limit_strategy'] == "drop": - if ratelimit.is_reach_limit(session_name): - logging.info("根据限速策略丢弃[{}]消息: {}".format(session_name, text_message)) + if not await mgr.ratelimiter.require(launcher_type, launcher_id): + logging.info("根据限速策略丢弃[{}]消息: {}".format(session_name, text_message)) - return mirai.MessageChain(["[bot]"+tips_custom.rate_limit_drop_tip]) if tips_custom.rate_limit_drop_tip != "" else [] + return mirai.MessageChain(["[bot]"+tips_custom.rate_limit_drop_tip]) if tips_custom.rate_limit_drop_tip != "" else [] before = time.time() # 触发插件事件 @@ -133,12 +130,6 @@ async def process_message(launcher_type: str, launcher_id: int, text_message: st reply = message.process_normal_message(text_message, mgr, config, launcher_type, launcher_id, sender_id) - # 限速等待时间 - if config['rate_limit_strategy'] == "wait": - time.sleep(ratelimit.get_rest_wait_time(session_name, time.time() - before)) - - ratelimit.add_usage(session_name) - if reply is not None and len(reply) > 0 and (type(reply[0]) == str or type(reply[0]) == mirai.Plain): if type(reply[0]) == mirai.Plain: reply[0] = reply[0].text diff --git a/pkg/qqbot/ratelim/__init__.py b/pkg/qqbot/ratelim/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/qqbot/ratelim/algo.py b/pkg/qqbot/ratelim/algo.py new file mode 100644 index 00000000..10bbdd3a --- /dev/null +++ b/pkg/qqbot/ratelim/algo.py @@ -0,0 +1,24 @@ +from __future__ import annotations +import abc + +from ...boot import app + + +class ReteLimitAlgo(metaclass=abc.ABCMeta): + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + pass + + @abc.abstractmethod + async def require_access(self, launcher_type: str, launcher_id: int) -> bool: + raise NotImplementedError + + @abc.abstractmethod + async def release_access(self, launcher_type: str, launcher_id: int): + raise NotImplementedError + \ No newline at end of file diff --git a/pkg/qqbot/ratelim/algos/__init__.py b/pkg/qqbot/ratelim/algos/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/qqbot/ratelim/algos/fixedwin.py b/pkg/qqbot/ratelim/algos/fixedwin.py new file mode 100644 index 00000000..4996fbaa --- /dev/null +++ b/pkg/qqbot/ratelim/algos/fixedwin.py @@ -0,0 +1,85 @@ +# 固定窗口算法 +from __future__ import annotations + +import asyncio +import time + +from .. import algo + + +class SessionContainer: + + wait_lock: asyncio.Lock + + records: dict[int, int] + """访问记录,key为每分钟的起始时间戳,value为访问次数""" + + def __init__(self): + self.wait_lock = asyncio.Lock() + self.records = {} + + +class FixedWindowAlgo(algo.ReteLimitAlgo): + + containers_lock: asyncio.Lock + """访问记录容器锁""" + + containers: dict[str, SessionContainer] + """访问记录容器,key为launcher_type launcher_id""" + + async def initialize(self): + self.containers_lock = asyncio.Lock() + self.containers = {} + + async def require_access(self, launcher_type: str, launcher_id: int) -> bool: + # 加锁,找容器 + container: SessionContainer = None + + session_name = f'{launcher_type}_{launcher_id}' + + async with self.containers_lock: + container = self.containers.get(session_name) + + if container is None: + container = SessionContainer() + self.containers[session_name] = container + + # 等待锁 + async with container.wait_lock: + # 获取当前时间戳 + now = int(time.time()) + + # 获取当前分钟的起始时间戳 + now = now - now % 60 + + # 获取当前分钟的访问次数 + count = container.records.get(now, 0) + + limitation = self.ap.cfg_mgr.data['rate_limitation']['default'] + + if session_name in self.ap.cfg_mgr.data['rate_limitation']: + limitation = self.ap.cfg_mgr.data['rate_limitation'][session_name] + + # 如果访问次数超过了限制 + if count >= limitation: + if self.ap.cfg_mgr.data['rate_limit_strategy'] == 'drop': + return False + elif self.ap.cfg_mgr.data['rate_limit_strategy'] == 'wait': + # 等待下一分钟 + await asyncio.sleep(60 - time.time() % 60) + + now = int(time.time()) + now = now - now % 60 + + if now not in container.records: + container.records = {} + container.records[now] = 1 + else: + # 访问次数加一 + container.records[now] = count + 1 + + # 返回True + return True + + async def release_access(self, launcher_type: str, launcher_id: int): + pass diff --git a/pkg/qqbot/ratelim/ratelim.py b/pkg/qqbot/ratelim/ratelim.py new file mode 100644 index 00000000..ab23d714 --- /dev/null +++ b/pkg/qqbot/ratelim/ratelim.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from . import algo +from .algos import fixedwin +from ...boot import app + + +class RateLimiter: + """限速器 + """ + + ap: app.Application + + algo: algo.ReteLimitAlgo + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + self.algo = fixedwin.FixedWindowAlgo(self.ap) + await self.algo.initialize() + + async def require(self, launcher_type: str, launcher_id: int) -> bool: + """请求访问 + """ + return await self.algo.require_access(launcher_type, launcher_id) + + async def release(self, launcher_type: str, launcher_id: int): + """释放访问 + """ + return await self.algo.release_access(launcher_type, launcher_id) \ No newline at end of file diff --git a/pkg/qqbot/ratelimit.py b/pkg/qqbot/ratelimit.py deleted file mode 100644 index 96d289ff..00000000 --- a/pkg/qqbot/ratelimit.py +++ /dev/null @@ -1,89 +0,0 @@ -# 限速相关模块 -import time -import logging -import threading - -from ..utils import context - - -__crt_minute_usage__ = {} -"""当前分钟每个会话的对话次数""" - - -__timer_thr__: threading.Thread = None - - -def get_limitation(session_name: str) -> int: - """获取会话的限制次数""" - config = context.get_config_manager().data - - if session_name in config['rate_limitation']: - return config['rate_limitation'][session_name] - else: - return config['rate_limitation']["default"] - - -def add_usage(session_name: str): - """增加会话的对话次数""" - global __crt_minute_usage__ - if session_name in __crt_minute_usage__: - __crt_minute_usage__[session_name] += 1 - else: - __crt_minute_usage__[session_name] = 1 - - -def start_timer(): - """启动定时器""" - global __timer_thr__ - __timer_thr__ = threading.Thread(target=run_timer, daemon=True) - __timer_thr__.start() - - -def run_timer(): - """启动定时器,每分钟清空一次对话次数""" - global __crt_minute_usage__ - global __timer_thr__ - - # 等待直到整分钟 - time.sleep(60 - time.time() % 60) - - while True: - if __timer_thr__ != threading.current_thread(): - break - - logging.debug("清空当前分钟的对话次数") - __crt_minute_usage__ = {} - time.sleep(60) - - -def get_usage(session_name: str) -> int: - """获取会话的对话次数""" - global __crt_minute_usage__ - if session_name in __crt_minute_usage__: - return __crt_minute_usage__[session_name] - else: - return 0 - - -def get_rest_wait_time(session_name: str, spent: float) -> float: - """获取会话此回合的剩余等待时间""" - global __crt_minute_usage__ - - min_seconds_per_round = 60.0 / get_limitation(session_name) - - if session_name in __crt_minute_usage__: - return max(0, min_seconds_per_round - spent) - else: - return 0 - - -def is_reach_limit(session_name: str) -> bool: - """判断会话是否超过限制""" - global __crt_minute_usage__ - - if session_name in __crt_minute_usage__: - return __crt_minute_usage__[session_name] >= get_limitation(session_name) - else: - return False - -start_timer() diff --git a/pkg/qqbot/resprule/entities.py b/pkg/qqbot/resprule/entities.py index 1cdd76f2..ffee3081 100644 --- a/pkg/qqbot/resprule/entities.py +++ b/pkg/qqbot/resprule/entities.py @@ -7,4 +7,3 @@ class RuleJudgeResult(pydantic.BaseModel): matching: bool = False replacement: mirai.MessageChain = None -