refactor: filter和ignore独立成新的cntfilter包

This commit is contained in:
RockChinQ
2024-01-25 15:28:23 +08:00
parent f4ae9df3bf
commit a9a798b19d
17 changed files with 440 additions and 146 deletions

View File

@@ -167,6 +167,8 @@ response_rules = {
# 此设置优先级高于response_rules
# 用以过滤mirai等其他层级的命令
# @see https://github.com/RockChinQ/QChatGPT/issues/165
#
# *需要同时开启下方 income_msg_check 才会生效
ignore_rules = {
"prefix": ["/"],
"regexp": []

View File

@@ -4,17 +4,8 @@ from ..config import manager as config_mgr
from ..config.impls import pymodule
async def load_python_module_config(config_name: str, template_name: str) -> config_mgr.ConfigManager:
    """Load a Python-module configuration file.

    Args:
        config_name: Name/path of the configuration module file.
        template_name: Name/path of the template file passed to the
            config-file backend.

    Returns:
        A ``ConfigManager`` on which ``load_config()`` has already been awaited.
    """
    manager = config_mgr.ConfigManager(
        pymodule.PythonModuleConfigFile(config_name, template_name)
    )
    await manager.load_config()
    return manager
load_python_module_config = config_mgr.load_python_module_config
load_json_config = config_mgr.load_json_config
async def override_config_manager(cfg_mgr: config_mgr.ConfigManager) -> list[str]:

44
pkg/config/impls/json.py Normal file
View File

@@ -0,0 +1,44 @@
import os
import shutil
import json
from .. import model as file_model
class JSONConfigFile(file_model.ConfigFile):
    """Configuration file backend that stores its data as JSON."""

    # Path of the JSON configuration file.
    config_file_name: str = None

    # Path of the JSON template file (copied on create, used to fill gaps on load).
    template_file_name: str = None

    def __init__(self, config_file_name: str, template_file_name: str) -> None:
        self.config_file_name = config_file_name
        self.template_file_name = template_file_name

    def exists(self) -> bool:
        """Return whether the configuration file is present on disk."""
        return os.path.exists(self.config_file_name)

    async def create(self):
        """Create the configuration file by copying the template file."""
        shutil.copyfile(self.template_file_name, self.config_file_name)

    async def load(self) -> dict:
        """Read the configuration and complete it from the template.

        Only top-level keys missing from the configuration are filled in;
        existing values are never overwritten.
        """
        with open(self.config_file_name, 'r', encoding='utf-8') as cfg_fp:
            loaded = json.load(cfg_fp)

        with open(self.template_file_name, 'r', encoding='utf-8') as tpl_fp:
            defaults = json.load(tpl_fp)

        for name, value in defaults.items():
            loaded.setdefault(name, value)

        return loaded

    async def save(self, cfg: dict):
        """Write ``cfg`` to the configuration file as indented, non-ASCII-escaped JSON."""
        with open(self.config_file_name, 'w', encoding='utf-8') as out_fp:
            json.dump(cfg, out_fp, indent=4, ensure_ascii=False)

View File

@@ -1,5 +1,6 @@
from . import model as file_model
from ..utils import context
from .impls import pymodule, json as json_file
class ConfigManager:
@@ -20,3 +21,29 @@ class ConfigManager:
async def dump_config(self):
await self.file.save(self.data)
async def load_python_module_config(config_name: str, template_name: str) -> ConfigManager:
    """Load a Python-module configuration file.

    Args:
        config_name: Name/path of the configuration module file.
        template_name: Name/path of the template file passed to the
            config-file backend.

    Returns:
        A ``ConfigManager`` on which ``load_config()`` has already been awaited.
    """
    manager = ConfigManager(
        pymodule.PythonModuleConfigFile(config_name, template_name)
    )
    await manager.load_config()
    return manager
async def load_json_config(config_name: str, template_name: str) -> ConfigManager:
    """Load a JSON configuration file.

    Args:
        config_name: Name/path of the JSON configuration file.
        template_name: Name/path of the JSON template file passed to the
            config-file backend.

    Returns:
        A ``ConfigManager`` on which ``load_config()`` has already been awaited.
    """
    manager = ConfigManager(
        json_file.JSONConfigFile(config_name, template_name)
    )
    await manager.load_config()
    return manager

View File

@@ -4,7 +4,6 @@ from __future__ import annotations
import re
from ...boot import app
from ...boot import config as config_util
from ...config import manager as cfg_mgr
@@ -18,7 +17,7 @@ class SessionBanManager:
self.ap = ap
async def initialize(self):
self.banlist_mgr = await config_util.load_python_module_config(
self.banlist_mgr = await cfg_mgr.load_python_module_config(
"banlist.py",
"res/templates/banlist-template.py"
)

View File

View File

@@ -0,0 +1,93 @@
from __future__ import annotations
from ...boot import app
from . import entities
from . import filter
from .filters import cntignore, banwords, baiduexamine
class ContentFilterManager:
    """Runs a chain of content filters over incoming and outgoing messages."""

    # Application object; its ``cfg_mgr.data`` mapping supplies the settings
    # read below. (Previously mis-annotated as ``ao``.)
    ap: app.Application

    # Filters in the order they are applied.
    filter_chain: list[filter.ContentFilter]

    def __init__(self, ap: app.Application) -> None:
        self.ap = ap
        self.filter_chain = []

    async def initialize(self):
        """Build the filter chain from the configuration and initialize each filter.

        The ignore filter is always installed; the ban-word and Baidu filters
        are added only when ``sensitive_word_filter`` / ``baidu_check`` are
        truthy in the configuration.
        """
        cfg = self.ap.cfg_mgr.data

        self.filter_chain.append(cntignore.ContentIgnore(self.ap))

        if cfg['sensitive_word_filter']:
            self.filter_chain.append(banwords.BanWordFilter(self.ap))

        if cfg['baidu_check']:
            self.filter_chain.append(baiduexamine.BaiduCloudExamine(self.ap))

        # Named ``content_filter`` so the imported ``filter`` module is not shadowed.
        for content_filter in self.filter_chain:
            await content_filter.initialize()

    @staticmethod
    def _continue(message: str) -> entities.FilterManagerResult:
        """Build a CONTINUE result carrying ``message`` and no notices."""
        return entities.FilterManagerResult(
            level=entities.ManagerResultLevel.CONTINUE,
            replacement=message,
            user_notice='',
            console_notice=''
        )

    @staticmethod
    def _interrupt(result: entities.FilterResult) -> entities.FilterManagerResult:
        """Build an INTERRUPT result that forwards a filter's replacement and notices."""
        return entities.FilterManagerResult(
            level=entities.ManagerResultLevel.INTERRUPT,
            replacement=result.replacement,
            user_notice=result.user_notice,
            console_notice=result.console_notice
        )

    async def pre_process(self, message: str) -> entities.FilterManagerResult:
        """Filter a message before it is sent to the LLM.

        When ``income_msg_check`` is falsy the message is passed through
        untouched. Otherwise every filter enabled for the PRE stage is run in
        order: a BLOCK or MASKED result interrupts immediately, a PASS result
        replaces the message for the next filter, and any other level leaves
        the message unchanged.
        """
        if not self.ap.cfg_mgr.data['income_msg_check']:
            # Incoming-message checking is disabled: let everything through.
            return self._continue(message)

        for content_filter in self.filter_chain:
            if entities.EnableStage.PRE not in content_filter.enable_stages:
                continue

            result = await content_filter.process(message)

            if result.level in (
                entities.ResultLevel.BLOCK,
                entities.ResultLevel.MASKED
            ):
                return self._interrupt(result)
            if result.level == entities.ResultLevel.PASS:
                message = result.replacement

        return self._continue(message)

    async def post_process(self, message: str) -> entities.FilterManagerResult:
        """Filter an LLM response before it is returned.

        Every filter enabled for the POST stage is run in order: a BLOCK
        result interrupts immediately, while PASS and MASKED results replace
        the message that is handed to the next filter.
        """
        for content_filter in self.filter_chain:
            if entities.EnableStage.POST not in content_filter.enable_stages:
                continue

            result = await content_filter.process(message)

            if result.level == entities.ResultLevel.BLOCK:
                return self._interrupt(result)
            if result.level in (
                entities.ResultLevel.PASS,
                entities.ResultLevel.MASKED
            ):
                message = result.replacement

        return self._continue(message)

View File

@@ -0,0 +1,64 @@
import typing
import enum
import pydantic
class ResultLevel(enum.Enum):
    """Outcome level reported by a single content filter."""

    # The message passed the filter.
    PASS = enum.auto()

    # Warning.
    WARN = enum.auto()

    # Parts of the message were masked.
    MASKED = enum.auto()

    # The message is blocked.
    BLOCK = enum.auto()
class EnableStage(enum.Enum):
    """Processing stage in which a content filter is enabled."""

    # Pre-processing stage.
    PRE = enum.auto()

    # Post-processing stage.
    POST = enum.auto()
class FilterResult(pydantic.BaseModel):
    """Result returned by a single content filter."""

    # Outcome level of this filter.
    level: ResultLevel

    # The message after replacement.
    replacement: str

    # Notice shown to the user when the message does not pass.
    user_notice: str

    # Notice written to the console when the message does not pass.
    console_notice: str
class ManagerResultLevel(enum.Enum):
    """Outcome level reported by the filter manager."""

    # Continue processing the message.
    CONTINUE = enum.auto()

    # Interrupt processing of the message.
    INTERRUPT = enum.auto()
class FilterManagerResult(pydantic.BaseModel):
    """Aggregated result returned by the filter manager."""

    # Whether processing should continue or be interrupted.
    level: ManagerResultLevel

    # The message after replacement.
    replacement: str

    # Notice for the user.
    user_notice: str

    # Notice for the console.
    console_notice: str

View File

@@ -0,0 +1,34 @@
# 内容过滤器的抽象类
from __future__ import annotations
import abc
from ...boot import app
from . import entities
class ContentFilter(metaclass=abc.ABCMeta):
    """Abstract base class for content filters."""

    # Application object handed to the filter at construction time.
    ap: app.Application

    def __init__(self, ap: app.Application):
        self.ap = ap

    @property
    def enable_stages(self):
        """Stages in which this filter runs; both PRE and POST unless overridden."""
        stages = [entities.EnableStage.PRE]
        stages.append(entities.EnableStage.POST)
        return stages

    async def initialize(self):
        """Initialize the filter; the base implementation does nothing."""

    @abc.abstractmethod
    async def process(self, message: str) -> entities.FilterResult:
        """Process ``message`` and return the filter result; must be overridden."""
        raise NotImplementedError

View File

View File

@@ -0,0 +1,61 @@
from __future__ import annotations
import aiohttp
from .. import entities
from .. import filter as filter_model
BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}"
BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
class BaiduCloudExamine(filter_model.ContentFilter):
    """Content filter backed by the Baidu Cloud text-censoring API."""

    async def _get_token(self) -> str:
        """Request an access token using the configured API key and secret key.

        Raises:
            KeyError: if the token response contains no ``access_token``.
        """
        async with aiohttp.ClientSession() as session:
            async with session.post(
                BAIDU_EXAMINE_TOKEN_URL,
                params={
                    "grant_type": "client_credentials",
                    "client_id": self.ap.cfg_mgr.data['baidu_api_key'],
                    "client_secret": self.ap.cfg_mgr.data['baidu_secret_key']
                }
            ) as resp:
                return (await resp.json())['access_token']

    async def process(self, message: str) -> entities.FilterResult:
        """Submit ``message`` to Baidu Cloud and map the verdict to a filter result.

        Returns BLOCK when the API reports an error or a non-compliant
        conclusion, and PASS only when the conclusion is exactly "合规".
        """
        token = await self._get_token()

        async with aiohttp.ClientSession() as session:
            async with session.post(
                BAIDU_EXAMINE_URL.format(token),
                headers={'Content-Type': 'application/x-www-form-urlencoded', 'Accept': 'application/json'},
                # A dict is form-encoded by aiohttp, so characters such as
                # '&', '%' and '+' in the message are escaped instead of
                # corrupting the request body.
                data={'text': message}
            ) as resp:
                result = await resp.json()

        if "error_code" in result:
            return entities.FilterResult(
                level=entities.ResultLevel.BLOCK,
                replacement=message,
                user_notice='',
                console_notice=f"百度云判定出错,错误信息:{result.get('error_msg')}"
            )

        conclusion = result["conclusion"]

        # Exact comparison: `conclusion in ("合规")` was a substring test
        # against a plain string, which also accepted "", "合" and "规".
        if conclusion == "合规":
            return entities.FilterResult(
                level=entities.ResultLevel.PASS,
                replacement=message,
                user_notice='',
                console_notice=f"百度云判定结果:{conclusion}"
            )

        return entities.FilterResult(
            level=entities.ResultLevel.BLOCK,
            replacement=message,
            user_notice=self.ap.cfg_mgr.data['inappropriate_message_tips'],
            console_notice=f"百度云判定结果:{conclusion}"
        )

View File

@@ -0,0 +1,44 @@
from __future__ import annotations
import re
from .. import filter as filter_model
from .. import entities
from ....config import manager as cfg_mgr
class BanWordFilter(filter_model.ContentFilter):
    """Content filter that masks configured sensitive words in a message."""

    # Manager of the sensitive-word configuration; assigned in initialize().
    sensitive: cfg_mgr.ConfigManager

    async def initialize(self):
        """Load the sensitive-word configuration from ``sensitive.json``."""
        self.sensitive = await cfg_mgr.load_json_config(
            "sensitive.json",
            "res/templates/sensitive-template.json"
        )

    async def process(self, message: str) -> entities.FilterResult:
        """Mask every match of the configured word patterns in ``message``.

        Each entry of ``words`` is used as a regular expression. A match is
        replaced by ``mask_word`` when that is non-empty, otherwise by
        ``mask`` repeated once per matched character.

        Returns:
            A MASKED result (with a user notice) when at least one non-empty
            match was replaced, otherwise a PASS result with the message
            unchanged.
        """
        settings = self.sensitive.data
        mask = settings['mask']
        mask_word = settings['mask_word']
        found = False

        def _mask(hit):
            # Works on the whole match, so patterns containing groups are
            # handled (re.findall would yield group text or tuples instead).
            nonlocal found
            text = hit.group(0)
            if not text:
                # Zero-width match: nothing to hide, leave the message alone.
                return text
            found = True
            return mask_word if mask_word != "" else mask * len(text)

        for word in settings['words']:
            # A callable replacement is inserted literally, so backslashes in
            # the mask are never interpreted as escapes.
            message = re.sub(word, _mask, message)

        return entities.FilterResult(
            level=entities.ResultLevel.MASKED if found else entities.ResultLevel.PASS,
            replacement=message,
            user_notice='[bot] 消息中存在不合适的内容, 请更换措辞' if found else '',
            console_notice=''
        )

View File

@@ -0,0 +1,43 @@
from __future__ import annotations
import re
from .. import entities
from .. import filter as filter_model
class ContentIgnore(filter_model.ContentFilter):
    """Content filter that drops messages matching the ``ignore_rules`` setting."""

    @property
    def enable_stages(self):
        """This filter runs in the PRE stage only."""
        return [entities.EnableStage.PRE]

    async def process(self, message: str) -> entities.FilterResult:
        """Block ``message`` if it matches a prefix rule or a regexp rule.

        Prefix rules are checked first, then regexp rules (via ``re.search``).
        A blocked message gets an empty replacement and a console notice
        naming the rule kind; anything else passes unchanged.
        """
        rules = self.ap.cfg_mgr.data['ignore_rules']

        if 'prefix' in rules and any(
            message.startswith(prefix) for prefix in rules['prefix']
        ):
            notice = '根据 ignore_rules 中的 prefix 规则,忽略消息'
        elif 'regexp' in rules and any(
            re.search(pattern, message) for pattern in rules['regexp']
        ):
            notice = '根据 ignore_rules 中的 regexp 规则,忽略消息'
        else:
            return entities.FilterResult(
                level=entities.ResultLevel.PASS,
                replacement=message,
                user_notice='',
                console_notice=''
            )

        return entities.FilterResult(
            level=entities.ResultLevel.BLOCK,
            replacement='',
            user_notice='',
            console_notice=notice
        )

View File

@@ -1,87 +0,0 @@
# 敏感词过滤模块
import re
import requests
import json
import logging
from ..utils import context
class ReplyFilter:
    """Sensitive-word masking with optional Baidu Cloud text censoring."""

    sensitive_words = []
    mask = "*"
    mask_word = ""

    # Defaults (kept for compatibility)
    baidu_check = False
    baidu_api_key = ""
    baidu_secret_key = ""
    inappropriate_message_tips = "[百度云]请珍惜机器人,当前返回内容不合规"

    def __init__(self, sensitive_words: list, mask: str = "*", mask_word: str = ""):
        """Store the masking settings and read the Baidu settings from the config.

        Args:
            sensitive_words: Regular-expression patterns to mask.
            mask: Character repeated over a match when ``mask_word`` is empty.
            mask_word: Fixed replacement for a match; empty disables it.
        """
        self.sensitive_words = sensitive_words
        self.mask = mask
        self.mask_word = mask_word

        config = context.get_config_manager().data

        self.baidu_check = config['baidu_check']
        self.baidu_api_key = config['baidu_api_key']
        self.baidu_secret_key = config['baidu_secret_key']
        self.inappropriate_message_tips = config['inappropriate_message_tips']

    def is_illegal(self, message: str) -> bool:
        """Return True when processing changes ``message`` in any way."""
        return self.process(message) != message

    def _mask_match(self, hit) -> str:
        """Return the replacement text for one regex match."""
        text = hit.group(0)
        if not text:
            # Zero-width match: nothing to hide.
            return text
        return self.mask_word if self.mask_word != "" else self.mask * len(text)

    def process(self, message: str) -> str:
        """Mask sensitive words, then optionally check the text with Baidu Cloud.

        Returns:
            The (possibly masked) message when Baidu checking is disabled or
            the verdict is exactly "合规"; an error description when the API
            reports an error; otherwise ``inappropriate_message_tips``.
        """
        # Local keyword masking. re.sub with a callable replaces whole
        # matches, so patterns with groups no longer break str.replace.
        for word in self.sensitive_words:
            message = re.sub(word, self._mask_match, message)

        if not self.baidu_check:
            return message

        # Baidu Cloud censoring URL, including a freshly requested access token
        baidu_url = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token=" + \
            str(requests.post("https://aip.baidubce.com/oauth/2.0/token",
                              params={"grant_type": "client_credentials",
                                      "client_id": self.baidu_api_key,
                                      "client_secret": self.baidu_secret_key}).json().get("access_token"))

        logging.info("向百度云发送:" + "text=" + message)
        headers = {'Content-Type': 'application/x-www-form-urlencoded', 'Accept': 'application/json'}

        # A dict is form-encoded by requests, so '&', '%' and '+' in the
        # message are escaped instead of corrupting the request body.
        response = requests.request("POST", baidu_url, headers=headers, data={"text": message})
        response_dict = json.loads(response.text)

        if "error_code" in response_dict:
            error_msg = response_dict.get("error_msg")
            logging.warning(f"百度云判定出错,错误信息:{error_msg}")
            return f"百度云判定出错,错误信息:{error_msg}\n以下是原消息:{message}"

        conclusion = response_dict["conclusion"]

        # Exact comparison: `conclusion in ("合规")` was a substring test
        # against a plain string, which also accepted "", "合" and "规".
        if conclusion == "合规":
            logging.info(f"百度云判定结果:{conclusion}")
            return message

        logging.warning(f"百度云判定结果:{conclusion}")
        return self.inappropriate_message_tips

View File

@@ -1,18 +0,0 @@
import re
from ..utils import context
def ignore(msg: str) -> bool:
    """Check whether a message should be ignored.

    A message is ignored when it starts with one of the configured
    ``ignore_rules['prefix']`` strings or matches (``re.search``) one of the
    ``ignore_rules['regexp']`` patterns.

    Returns:
        True when a rule matches, False otherwise (previously the function
        fell off the end and returned None in that case).
    """
    rules = context.get_config_manager().data['ignore_rules']

    if 'prefix' in rules and any(msg.startswith(prefix) for prefix in rules['prefix']):
        return True

    if 'regexp' in rules and any(re.search(pattern, msg) for pattern in rules['regexp']):
        return True

    return False

View File

@@ -12,7 +12,6 @@ import func_timeout
from ..openai import session as openai_session
from ..qqbot import filter as qqbot_filter
from ..qqbot import process as processor
from ..utils import context
from ..plugin import host as plugin_host
@@ -21,6 +20,7 @@ import tips as tips_custom
from ..qqbot import adapter as msadapter
from . import resprule
from .bansess import bansess
from .cntfilter import cntfilter
from ..boot import app
@@ -33,8 +33,6 @@ class QQBotManager:
bot_account_id: int = 0
reply_filter = None
enable_banlist = False
enable_private = True
@@ -47,18 +45,21 @@ class QQBotManager:
ap: app.Application = None
bansess_mgr: bansess.SessionBanManager = None
cntfilter_mgr: cntfilter.ContentFilterManager = None
def __init__(self, first_time_init=True, ap: app.Application = None):
config = context.get_config_manager().data
self.ap = ap
self.bansess_mgr = bansess.SessionBanManager(ap)
self.cntfilter_mgr = cntfilter.ContentFilterManager(ap)
self.timeout = config['process_message_timeout']
self.retry = config['retry_times']
async def initialize(self):
await self.bansess_mgr.initialize()
await self.cntfilter_mgr.initialize()
config = context.get_config_manager().data
@@ -174,20 +175,6 @@ class QQBotManager:
self.unsubscribe_all = unsubscribe_all
config = context.get_config_manager().data
if os.path.exists("sensitive.json") \
and config['sensitive_word_filter'] is not None \
and config['sensitive_word_filter']:
with open("sensitive.json", "r", encoding="utf-8") as f:
sensitive_json = json.load(f)
self.reply_filter = qqbot_filter.ReplyFilter(
sensitive_words=sensitive_json['words'],
mask=sensitive_json['mask'] if 'mask' in sensitive_json else '*',
mask_word=sensitive_json['mask_word'] if 'mask_word' in sensitive_json else ''
)
else:
self.reply_filter = qqbot_filter.ReplyFilter([])
async def send(self, event, msg, check_quote=True, check_at_sender=True):
config = context.get_config_manager().data

View File

@@ -14,10 +14,10 @@ from ..utils import context
from ..plugin import host as plugin_host
from ..plugin import models as plugin_models
from ..qqbot import ignore
from ..qqbot import blob
import tips as tips_custom
from ..boot import app
from .cntfilter import entities
processing = []
@@ -32,7 +32,7 @@ def is_admin(qq: int) -> bool:
async def process_message(launcher_type: str, launcher_id: int, text_message: str, message_chain: mirai.MessageChain,
sender_id: int) -> mirai.MessageChain:
sender_id: int) -> list:
global processing
mgr = context.get_qqbot_manager()
@@ -40,14 +40,10 @@ async def process_message(launcher_type: str, launcher_id: int, text_message: st
reply = []
session_name = "{}_{}".format(launcher_type, launcher_id)
if ignore.ignore(text_message):
logging.info("根据忽略规则忽略消息: {}".format(text_message))
return []
config = context.get_config_manager().data
if not config['wait_last_done'] and session_name in processing:
return mirai.MessageChain([mirai.Plain(tips_custom.message_drop_tip)])
return [mirai.Plain(tips_custom.message_drop_tip)]
# 检查是否被禁言
if launcher_type == 'group':
@@ -56,9 +52,14 @@ async def process_message(launcher_type: str, launcher_id: int, text_message: st
logging.info("机器人被禁言,跳过消息处理(group_{})".format(launcher_id))
return reply
if config['income_msg_check']:
if mgr.reply_filter.is_illegal(text_message):
return mirai.MessageChain(mirai.Plain("[bot] 消息中存在不合适的内容, 请更换措辞"))
cntfilter_res = await mgr.cntfilter_mgr.pre_process(text_message)
if cntfilter_res.level == entities.ManagerResultLevel.INTERRUPT:
if cntfilter_res.console_notice:
mgr.ap.logger.info(cntfilter_res.console_notice)
if cntfilter_res.user_notice:
return [mirai.Plain(cntfilter_res.user_notice)]
else:
return []
openai_session.get_session(session_name).acquire_response_lock()
@@ -147,7 +148,16 @@ async def process_message(launcher_type: str, launcher_id: int, text_message: st
reply[0][:min(100, len(reply[0]))] + (
"..." if len(reply[0]) > 100 else "")))
if msg_type == 'message':
reply = [mgr.reply_filter.process(reply[0])]
cntfilter_res = await mgr.cntfilter_mgr.post_process(reply[0])
if cntfilter_res.level == entities.ManagerResultLevel.INTERRUPT:
if cntfilter_res.console_notice:
mgr.ap.logger.info(cntfilter_res.console_notice)
if cntfilter_res.user_notice:
return [mirai.Plain(cntfilter_res.user_notice)]
else:
return []
else:
reply = [cntfilter_res.replacement]
reply = blob.check_text(reply[0])
else: