diff --git a/config-template.py b/config-template.py index fb60ff1e..2c1d0e99 100644 --- a/config-template.py +++ b/config-template.py @@ -167,6 +167,8 @@ response_rules = { # 此设置优先级高于response_rules # 用以过滤mirai等其他层级的命令 # @see https://github.com/RockChinQ/QChatGPT/issues/165 +# +# *需要同时开启下方 income_msg_check 才会生效 ignore_rules = { "prefix": ["/"], "regexp": [] diff --git a/pkg/boot/config.py b/pkg/boot/config.py index 1d891da0..3f796214 100644 --- a/pkg/boot/config.py +++ b/pkg/boot/config.py @@ -4,17 +4,8 @@ from ..config import manager as config_mgr from ..config.impls import pymodule -async def load_python_module_config(config_name: str, template_name: str) -> config_mgr.ConfigManager: - """加载Python模块配置文件""" - cfg_inst = pymodule.PythonModuleConfigFile( - config_name, - template_name - ) - - cfg_mgr = config_mgr.ConfigManager(cfg_inst) - await cfg_mgr.load_config() - - return cfg_mgr +load_python_module_config = config_mgr.load_python_module_config +load_json_config = config_mgr.load_json_config async def override_config_manager(cfg_mgr: config_mgr.ConfigManager) -> list[str]: diff --git a/pkg/config/impls/json.py b/pkg/config/impls/json.py new file mode 100644 index 00000000..cfc284cb --- /dev/null +++ b/pkg/config/impls/json.py @@ -0,0 +1,44 @@ +import os +import shutil +import json + +from .. import model as file_model + + +class JSONConfigFile(file_model.ConfigFile): + """JSON配置文件""" + + config_file_name: str = None + """配置文件名""" + + template_file_name: str = None + """模板文件名""" + + def __init__(self, config_file_name: str, template_file_name: str) -> None: + self.config_file_name = config_file_name + self.template_file_name = template_file_name + + def exists(self) -> bool: + return os.path.exists(self.config_file_name) + + async def create(self): + shutil.copyfile(self.template_file_name, self.config_file_name) + + async def load(self) -> dict: + + with open(self.config_file_name, 'r', encoding='utf-8') as f: + cfg = json.load(f) + + # 从模板文件中进行补全 + with open(self.template_file_name, 'r', encoding='utf-8') as f: + template_cfg = json.load(f) + + for key in template_cfg: + if key not in cfg: + cfg[key] = template_cfg[key] + + return cfg + + async def save(self, cfg: dict): + with open(self.config_file_name, 'w', encoding='utf-8') as f: + json.dump(cfg, f, indent=4, ensure_ascii=False) \ No newline at end of file diff --git a/pkg/config/manager.py b/pkg/config/manager.py index 5893ff0b..e343b0c2 100644 --- a/pkg/config/manager.py +++ b/pkg/config/manager.py @@ -1,5 +1,6 @@ from . import model as file_model from ..utils import context +from .impls import pymodule, json as json_file class ConfigManager: @@ -20,3 +21,29 @@ class ConfigManager: async def dump_config(self): await self.file.save(self.data) + + +async def load_python_module_config(config_name: str, template_name: str) -> ConfigManager: + """加载Python模块配置文件""" + cfg_inst = pymodule.PythonModuleConfigFile( + config_name, + template_name + ) + + cfg_mgr = ConfigManager(cfg_inst) + await cfg_mgr.load_config() + + return cfg_mgr + + +async def load_json_config(config_name: str, template_name: str) -> ConfigManager: + """加载JSON配置文件""" + cfg_inst = json_file.JSONConfigFile( + config_name, + template_name + ) + + cfg_mgr = ConfigManager(cfg_inst) + await cfg_mgr.load_config() + + return cfg_mgr \ No newline at end of file diff --git a/pkg/qqbot/bansess/bansess.py b/pkg/qqbot/bansess/bansess.py index e037fde8..d8ef4958 100644 --- a/pkg/qqbot/bansess/bansess.py +++ b/pkg/qqbot/bansess/bansess.py @@ -4,7 +4,6 @@ from __future__ import annotations import re from ...boot import app -from ...boot import config as config_util from ...config import manager as cfg_mgr @@ -18,7 +17,7 @@ class SessionBanManager: self.ap = ap async def initialize(self): - self.banlist_mgr = await config_util.load_python_module_config( + self.banlist_mgr = await cfg_mgr.load_python_module_config( "banlist.py", "res/templates/banlist-template.py" ) diff --git a/pkg/qqbot/cntfilter/__init__.py b/pkg/qqbot/cntfilter/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/qqbot/cntfilter/cntfilter.py b/pkg/qqbot/cntfilter/cntfilter.py new file mode 100644 index 00000000..2d690b57 --- /dev/null +++ b/pkg/qqbot/cntfilter/cntfilter.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +from ...boot import app +from . import entities +from . import filter +from .filters import cntignore, banwords, baiduexamine + + +class ContentFilterManager: + + ao: app.Application + + filter_chain: list[filter.ContentFilter] + + def __init__(self, ap: app.Application) -> None: + self.ap = ap + self.filter_chain = [] + + async def initialize(self): + self.filter_chain.append(cntignore.ContentIgnore(self.ap)) + + if self.ap.cfg_mgr.data['sensitive_word_filter']: + self.filter_chain.append(banwords.BanWordFilter(self.ap)) + + if self.ap.cfg_mgr.data['baidu_check']: + self.filter_chain.append(baiduexamine.BaiduCloudExamine(self.ap)) + + for filter in self.filter_chain: + await filter.initialize() + + async def pre_process(self, message: str) -> entities.FilterManagerResult: + """请求llm前处理消息 + 只要有一个不通过就不放行,只放行 PASS 的消息 + """ + if not self.ap.cfg_mgr.data['income_msg_check']: # 不检查收到的消息,直接放行 + return entities.FilterManagerResult( + level=entities.ManagerResultLevel.CONTINUE, + replacement=message, + user_notice='', + console_notice='' + ) + else: + for filter in self.filter_chain: + if entities.EnableStage.PRE in filter.enable_stages: + result = await filter.process(message) + + if result.level in [ + entities.ResultLevel.BLOCK, + entities.ResultLevel.MASKED + ]: + return entities.FilterManagerResult( + level=entities.ManagerResultLevel.INTERRUPT, + replacement=result.replacement, + user_notice=result.user_notice, + console_notice=result.console_notice + ) + elif result.level == entities.ResultLevel.PASS: + message = result.replacement + + return entities.FilterManagerResult( + level=entities.ManagerResultLevel.CONTINUE, + replacement=message, + user_notice='', + console_notice='' + ) + + async def post_process(self, message: str) -> entities.FilterManagerResult: + """请求llm后处理响应 + 只要是 PASS 或者 MASKED 的就通过此 filter,将其 replacement 设置为message,进入下一个 filter + """ + for filter in self.filter_chain: + if entities.EnableStage.POST in filter.enable_stages: + result = await filter.process(message) + + if result.level == entities.ResultLevel.BLOCK: + return entities.FilterManagerResult( + level=entities.ManagerResultLevel.INTERRUPT, + replacement=result.replacement, + user_notice=result.user_notice, + console_notice=result.console_notice + ) + elif result.level in [ + entities.ResultLevel.PASS, + entities.ResultLevel.MASKED + ]: + message = result.replacement + + return entities.FilterManagerResult( + level=entities.ManagerResultLevel.CONTINUE, + replacement=message, + user_notice='', + console_notice='' + ) diff --git a/pkg/qqbot/cntfilter/entities.py b/pkg/qqbot/cntfilter/entities.py new file mode 100644 index 00000000..7ab05675 --- /dev/null +++ b/pkg/qqbot/cntfilter/entities.py @@ -0,0 +1,64 @@ + +import typing +import enum + +import pydantic + + +class ResultLevel(enum.Enum): + """结果等级""" + PASS = enum.auto() + """通过""" + + WARN = enum.auto() + """警告""" + + MASKED = enum.auto() + """已掩去""" + + BLOCK = enum.auto() + """阻止""" + + +class EnableStage(enum.Enum): + """启用阶段""" + PRE = enum.auto() + """预处理""" + + POST = enum.auto() + """后处理""" + + +class FilterResult(pydantic.BaseModel): + level: ResultLevel + + replacement: str + """替换后的消息""" + + user_notice: str + """不通过时,用户提示消息""" + + console_notice: str + """不通过时,控制台提示消息""" + + +class ManagerResultLevel(enum.Enum): + """处理器结果等级""" + CONTINUE = enum.auto() + """继续""" + + INTERRUPT = enum.auto() + """中断""" + +class FilterManagerResult(pydantic.BaseModel): + + level: ManagerResultLevel + + replacement: str + """替换后的消息""" + + user_notice: str + """用户提示消息""" + + console_notice: str + """控制台提示消息""" diff --git a/pkg/qqbot/cntfilter/filter.py b/pkg/qqbot/cntfilter/filter.py new file mode 100644 index 00000000..4d4cd79f --- /dev/null +++ b/pkg/qqbot/cntfilter/filter.py @@ -0,0 +1,34 @@ +# 内容过滤器的抽象类 +from __future__ import annotations +import abc + +from ...boot import app +from . import entities + + +class ContentFilter(metaclass=abc.ABCMeta): + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + @property + def enable_stages(self): + """启用的阶段 + """ + return [ + entities.EnableStage.PRE, + entities.EnableStage.POST + ] + + async def initialize(self): + """初始化过滤器 + """ + pass + + @abc.abstractmethod + async def process(self, message: str) -> entities.FilterResult: + """处理消息 + """ + raise NotImplementedError diff --git a/pkg/qqbot/cntfilter/filters/__init__.py b/pkg/qqbot/cntfilter/filters/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/qqbot/cntfilter/filters/baiduexamine.py b/pkg/qqbot/cntfilter/filters/baiduexamine.py new file mode 100644 index 00000000..a658897b --- /dev/null +++ b/pkg/qqbot/cntfilter/filters/baiduexamine.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +import aiohttp + +from .. import entities +from .. import filter as filter_model + + +BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}" +BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token" + + +class BaiduCloudExamine(filter_model.ContentFilter): + """百度云内容审核""" + + async def _get_token(self) -> str: + async with aiohttp.ClientSession() as session: + async with session.post( + BAIDU_EXAMINE_TOKEN_URL, + params={ + "grant_type": "client_credentials", + "client_id": self.ap.cfg_mgr.data['baidu_api_key'], + "client_secret": self.ap.cfg_mgr.data['baidu_secret_key'] + } + ) as resp: + return (await resp.json())['access_token'] + + async def process(self, message: str) -> entities.FilterResult: + + async with aiohttp.ClientSession() as session: + async with session.post( + BAIDU_EXAMINE_URL.format(await self._get_token()), + headers={'Content-Type': 'application/x-www-form-urlencoded', 'Accept': 'application/json'}, + data=f"text={message}".encode('utf-8') + ) as resp: + result = await resp.json() + + if "error_code" in result: + return entities.FilterResult( + level=entities.ResultLevel.BLOCK, + replacement=message, + user_notice='', + console_notice=f"百度云判定出错,错误信息:{result['error_msg']}" + ) + else: + conclusion = result["conclusion"] + + if conclusion in ("合规"): + return entities.FilterResult( + level=entities.ResultLevel.PASS, + replacement=message, + user_notice='', + console_notice=f"百度云判定结果:{conclusion}" + ) + else: + return entities.FilterResult( + level=entities.ResultLevel.BLOCK, + replacement=message, + user_notice=self.ap.cfg_mgr.data['inappropriate_message_tips'], + console_notice=f"百度云判定结果:{conclusion}" + ) \ No newline at end of file diff --git a/pkg/qqbot/cntfilter/filters/banwords.py b/pkg/qqbot/cntfilter/filters/banwords.py new file mode 100644 index 00000000..9451c7b8 --- /dev/null +++ b/pkg/qqbot/cntfilter/filters/banwords.py @@ -0,0 +1,44 @@ +from __future__ import annotations +import re + +from .. import filter as filter_model +from .. import entities +from ....config import manager as cfg_mgr + + +class BanWordFilter(filter_model.ContentFilter): + """根据内容禁言""" + + sensitive: cfg_mgr.ConfigManager + + async def initialize(self): + self.sensitive = await cfg_mgr.load_json_config( + "sensitive.json", + "res/templates/sensitive-template.json" + ) + + async def process(self, message: str) -> entities.FilterResult: + found = False + + for word in self.sensitive.data['words']: + match = re.findall(word, message) + + if len(match) > 0: + found = True + + for i in range(len(match)): + if self.sensitive.data['mask_word'] == "": + message = message.replace( + match[i], self.sensitive.data['mask'] * len(match[i]) + ) + else: + message = message.replace( + match[i], self.sensitive.data['mask_word'] + ) + + return entities.FilterResult( + level=entities.ResultLevel.MASKED if found else entities.ResultLevel.PASS, + replacement=message, + user_notice='[bot] 消息中存在不合适的内容, 请更换措辞' if found else '', + console_notice='' + ) \ No newline at end of file diff --git a/pkg/qqbot/cntfilter/filters/cntignore.py b/pkg/qqbot/cntfilter/filters/cntignore.py new file mode 100644 index 00000000..81408868 --- /dev/null +++ b/pkg/qqbot/cntfilter/filters/cntignore.py @@ -0,0 +1,43 @@ +from __future__ import annotations +import re + +from .. import entities +from .. import filter as filter_model + + +class ContentIgnore(filter_model.ContentFilter): + """根据内容忽略消息""" + + @property + def enable_stages(self): + return [ + entities.EnableStage.PRE, + ] + + async def process(self, message: str) -> entities.FilterResult: + if 'prefix' in self.ap.cfg_mgr.data['ignore_rules']: + for rule in self.ap.cfg_mgr.data['ignore_rules']['prefix']: + if message.startswith(rule): + return entities.FilterResult( + level=entities.ResultLevel.BLOCK, + replacement='', + user_notice='', + console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息' + ) + + if 'regexp' in self.ap.cfg_mgr.data['ignore_rules']: + for rule in self.ap.cfg_mgr.data['ignore_rules']['regexp']: + if re.search(rule, message): + return entities.FilterResult( + level=entities.ResultLevel.BLOCK, + replacement='', + user_notice='', + console_notice='根据 ignore_rules 中的 regexp 规则,忽略消息' + ) + + return entities.FilterResult( + level=entities.ResultLevel.PASS, + replacement=message, + user_notice='', + console_notice='' + ) \ No newline at end of file diff --git a/pkg/qqbot/filter.py b/pkg/qqbot/filter.py deleted file mode 100644 index c3a58093..00000000 --- a/pkg/qqbot/filter.py +++ /dev/null @@ -1,87 +0,0 @@ -# 敏感词过滤模块 -import re -import requests -import json -import logging - -from ..utils import context - - -class ReplyFilter: - sensitive_words = [] - mask = "*" - mask_word = "" - - # 默认值( 兼容性考虑 ) - baidu_check = False - baidu_api_key = "" - baidu_secret_key = "" - inappropriate_message_tips = "[百度云]请珍惜机器人,当前返回内容不合规" - - def __init__(self, sensitive_words: list, mask: str = "*", mask_word: str = ""): - self.sensitive_words = sensitive_words - self.mask = mask - self.mask_word = mask_word - - config = context.get_config_manager().data - - self.baidu_check = config['baidu_check'] - self.baidu_api_key = config['baidu_api_key'] - self.baidu_secret_key = config['baidu_secret_key'] - self.inappropriate_message_tips = config['inappropriate_message_tips'] - - def is_illegal(self, message: str) -> bool: - processed = self.process(message) - if processed != message: - return True - return False - - def process(self, message: str) -> str: - - # 本地关键词屏蔽 - for word in self.sensitive_words: - match = re.findall(word, message) - if len(match) > 0: - for i in range(len(match)): - if self.mask_word == "": - message = message.replace(match[i], self.mask * len(match[i])) - else: - message = message.replace(match[i], self.mask_word) - - # 百度云审核 - if self.baidu_check: - - # 百度云审核URL - baidu_url = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token=" + \ - str(requests.post("https://aip.baidubce.com/oauth/2.0/token", - params={"grant_type": "client_credentials", - "client_id": self.baidu_api_key, - "client_secret": self.baidu_secret_key}).json().get("access_token")) - - # 百度云审核 - payload = "text=" + message - logging.info("向百度云发送:" + payload) - headers = {'Content-Type': 'application/x-www-form-urlencoded', 'Accept': 'application/json'} - - if isinstance(payload, str): - payload = payload.encode('utf-8') - - response = requests.request("POST", baidu_url, headers=headers, data=payload) - response_dict = json.loads(response.text) - - if "error_code" in response_dict: - error_msg = response_dict.get("error_msg") - logging.warning(f"百度云判定出错,错误信息:{error_msg}") - conclusion = f"百度云判定出错,错误信息:{error_msg}\n以下是原消息:{message}" - else: - conclusion = response_dict["conclusion"] - if conclusion in ("合规"): - logging.info(f"百度云判定结果:{conclusion}") - return message - else: - logging.warning(f"百度云判定结果:{conclusion}") - conclusion = self.inappropriate_message_tips - # 返回百度云审核结果 - return conclusion - - return message diff --git a/pkg/qqbot/ignore.py b/pkg/qqbot/ignore.py deleted file mode 100644 index e1adc777..00000000 --- a/pkg/qqbot/ignore.py +++ /dev/null @@ -1,18 +0,0 @@ -import re - -from ..utils import context - - -def ignore(msg: str) -> bool: - """检查消息是否应该被忽略""" - config = context.get_config_manager().data - - if 'prefix' in config['ignore_rules']: - for rule in config['ignore_rules']['prefix']: - if msg.startswith(rule): - return True - - if 'regexp' in config['ignore_rules']: - for rule in config['ignore_rules']['regexp']: - if re.search(rule, msg): - return True diff --git a/pkg/qqbot/manager.py b/pkg/qqbot/manager.py index bfe86b9c..a801229d 100644 --- a/pkg/qqbot/manager.py +++ b/pkg/qqbot/manager.py @@ -12,7 +12,6 @@ import func_timeout from ..openai import session as openai_session -from ..qqbot import filter as qqbot_filter from ..qqbot import process as processor from ..utils import context from ..plugin import host as plugin_host @@ -21,6 +20,7 @@ import tips as tips_custom from ..qqbot import adapter as msadapter from . import resprule from .bansess import bansess +from .cntfilter import cntfilter from ..boot import app @@ -33,8 +33,6 @@ class QQBotManager: bot_account_id: int = 0 - reply_filter = None - enable_banlist = False enable_private = True @@ -47,18 +45,21 @@ class QQBotManager: ap: app.Application = None bansess_mgr: bansess.SessionBanManager = None + cntfilter_mgr: cntfilter.ContentFilterManager = None def __init__(self, first_time_init=True, ap: app.Application = None): config = context.get_config_manager().data self.ap = ap self.bansess_mgr = bansess.SessionBanManager(ap) + self.cntfilter_mgr = cntfilter.ContentFilterManager(ap) self.timeout = config['process_message_timeout'] self.retry = config['retry_times'] async def initialize(self): await self.bansess_mgr.initialize() + await self.cntfilter_mgr.initialize() config = context.get_config_manager().data @@ -174,20 +175,6 @@ class QQBotManager: self.unsubscribe_all = unsubscribe_all - config = context.get_config_manager().data - if os.path.exists("sensitive.json") \ - and config['sensitive_word_filter'] is not None \ - and config['sensitive_word_filter']: - with open("sensitive.json", "r", encoding="utf-8") as f: - sensitive_json = json.load(f) - self.reply_filter = qqbot_filter.ReplyFilter( - sensitive_words=sensitive_json['words'], - mask=sensitive_json['mask'] if 'mask' in sensitive_json else '*', - mask_word=sensitive_json['mask_word'] if 'mask_word' in sensitive_json else '' - ) - else: - self.reply_filter = qqbot_filter.ReplyFilter([]) - async def send(self, event, msg, check_quote=True, check_at_sender=True): config = context.get_config_manager().data diff --git a/pkg/qqbot/process.py b/pkg/qqbot/process.py index aa02315f..fedaddbd 100644 --- a/pkg/qqbot/process.py +++ b/pkg/qqbot/process.py @@ -14,10 +14,10 @@ from ..utils import context from ..plugin import host as plugin_host from ..plugin import models as plugin_models -from ..qqbot import ignore from ..qqbot import blob import tips as tips_custom from ..boot import app +from .cntfilter import entities processing = [] @@ -32,7 +32,7 @@ def is_admin(qq: int) -> bool: async def process_message(launcher_type: str, launcher_id: int, text_message: str, message_chain: mirai.MessageChain, - sender_id: int) -> mirai.MessageChain: + sender_id: int) -> list: global processing mgr = context.get_qqbot_manager() @@ -40,14 +40,10 @@ async def process_message(launcher_type: str, launcher_id: int, text_message: st reply = [] session_name = "{}_{}".format(launcher_type, launcher_id) - if ignore.ignore(text_message): - logging.info("根据忽略规则忽略消息: {}".format(text_message)) - return [] - config = context.get_config_manager().data if not config['wait_last_done'] and session_name in processing: - return mirai.MessageChain([mirai.Plain(tips_custom.message_drop_tip)]) + return [mirai.Plain(tips_custom.message_drop_tip)] # 检查是否被禁言 if launcher_type == 'group': @@ -56,9 +52,14 @@ async def process_message(launcher_type: str, launcher_id: int, text_message: st logging.info("机器人被禁言,跳过消息处理(group_{})".format(launcher_id)) return reply - if config['income_msg_check']: - if mgr.reply_filter.is_illegal(text_message): - return mirai.MessageChain(mirai.Plain("[bot] 消息中存在不合适的内容, 请更换措辞")) + cntfilter_res = await mgr.cntfilter_mgr.pre_process(text_message) + if cntfilter_res.level == entities.ManagerResultLevel.INTERRUPT: + if cntfilter_res.console_notice: + mgr.ap.logger.info(cntfilter_res.console_notice) + if cntfilter_res.user_notice: + return [mirai.Plain(cntfilter_res.user_notice)] + else: + return [] openai_session.get_session(session_name).acquire_response_lock() @@ -147,7 +148,16 @@ async def process_message(launcher_type: str, launcher_id: int, text_message: st reply[0][:min(100, len(reply[0]))] + ( "..." if len(reply[0]) > 100 else ""))) if msg_type == 'message': - reply = [mgr.reply_filter.process(reply[0])] + cntfilter_res = await mgr.cntfilter_mgr.post_process(reply[0]) + if cntfilter_res.level == entities.ManagerResultLevel.INTERRUPT: + if cntfilter_res.console_notice: + mgr.ap.logger.info(cntfilter_res.console_notice) + if cntfilter_res.user_notice: + return [mirai.Plain(cntfilter_res.user_notice)] + else: + return [] + else: + reply = [cntfilter_res.replacement] reply = blob.check_text(reply[0]) else: