From 22cb8a6a0678bf441eba9cdb6dd8f59a6b097a4c Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Fri, 8 Mar 2024 20:22:06 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=86=85=E5=AE=B9=E8=BF=87=E6=BB=A4?= =?UTF-8?q?=E5=99=A8=E7=9A=84=E5=8F=AF=E6=89=A9=E5=B1=95=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/command/operator.py | 2 +- pkg/pipeline/cntfilter/cntfilter.py | 21 +++++++++---- pkg/pipeline/cntfilter/filter.py | 30 +++++++++++++++++++ .../cntfilter/filters/baiduexamine.py | 1 + pkg/pipeline/cntfilter/filters/banwords.py | 1 + pkg/pipeline/cntfilter/filters/cntignore.py | 1 + pkg/platform/manager.py | 19 ------------ pkg/platform/sources/nakuru.py | 2 ++ pkg/platform/sources/qqbotpy.py | 2 ++ 9 files changed, 53 insertions(+), 26 deletions(-) diff --git a/pkg/command/operator.py b/pkg/command/operator.py index 641a8cf5..307e9fbe 100644 --- a/pkg/command/operator.py +++ b/pkg/command/operator.py @@ -30,7 +30,7 @@ def operator_class( parent_class (typing.Type[CommandOperator], optional): 父节点,若为None则为顶级命令. Defaults to None. Returns: - typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]: 注册后的命令类 + typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]: 装饰器 """ def decorator(cls: typing.Type[CommandOperator]) -> typing.Type[CommandOperator]: diff --git a/pkg/pipeline/cntfilter/cntfilter.py b/pkg/pipeline/cntfilter/cntfilter.py index 92157bdd..5e2aa4d2 100644 --- a/pkg/pipeline/cntfilter/cntfilter.py +++ b/pkg/pipeline/cntfilter/cntfilter.py @@ -7,7 +7,7 @@ from ...core import app from .. import stage, entities, stagemgr from ...core import entities as core_entities from ...config import manager as cfg_mgr -from . import filter, entities as filter_entities +from . import filter as filter_model, entities as filter_entities from .filters import cntignore, banwords, baiduexamine @@ -16,20 +16,29 @@ from .filters import cntignore, banwords, baiduexamine class ContentFilterStage(stage.PipelineStage): """内容过滤阶段""" - filter_chain: list[filter.ContentFilter] + filter_chain: list[filter_model.ContentFilter] def __init__(self, ap: app.Application): self.filter_chain = [] super().__init__(ap) async def initialize(self): - self.filter_chain.append(cntignore.ContentIgnore(self.ap)) + + filters_required = [ + "ContentIgnore" + ] if self.ap.pipeline_cfg.data['check-sensitive-words']: - self.filter_chain.append(banwords.BanWordFilter(self.ap)) - + filters_required.append("BanWordFilter") + if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']: - self.filter_chain.append(baiduexamine.BaiduCloudExamine(self.ap)) + filters_required.append("BaiduCloudExamine") + + for filter in filter_model.preregistered_filters: + if filter.name in filters_required: + self.filter_chain.append( + filter(self.ap) + ) for filter in self.filter_chain: await filter.initialize() diff --git a/pkg/pipeline/cntfilter/filter.py b/pkg/pipeline/cntfilter/filter.py index 57792145..23471392 100644 --- a/pkg/pipeline/cntfilter/filter.py +++ b/pkg/pipeline/cntfilter/filter.py @@ -1,12 +1,42 @@ # 内容过滤器的抽象类 from __future__ import annotations import abc +import typing from ...core import app from . import entities +preregistered_filters: list[typing.Type[ContentFilter]] = [] + + +def filter_class( + name: str +) -> typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]: + """内容过滤器类装饰器 + + Args: + name (str): 过滤器名称 + + Returns: + typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]: 装饰器 + """ + def decorator(cls: typing.Type[ContentFilter]) -> typing.Type[ContentFilter]: + assert issubclass(cls, ContentFilter) + + cls.name = name + + preregistered_filters.append(cls) + + return cls + + return decorator + + class ContentFilter(metaclass=abc.ABCMeta): + """内容过滤器抽象类""" + + name: str ap: app.Application diff --git a/pkg/pipeline/cntfilter/filters/baiduexamine.py b/pkg/pipeline/cntfilter/filters/baiduexamine.py index f72fe960..faa4bb6b 100644 --- a/pkg/pipeline/cntfilter/filters/baiduexamine.py +++ b/pkg/pipeline/cntfilter/filters/baiduexamine.py @@ -10,6 +10,7 @@ BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token" +@filter_model.filter_class("BaiduCloudExamine") class BaiduCloudExamine(filter_model.ContentFilter): """百度云内容审核""" diff --git a/pkg/pipeline/cntfilter/filters/banwords.py b/pkg/pipeline/cntfilter/filters/banwords.py index 587f81c3..c94374c8 100644 --- a/pkg/pipeline/cntfilter/filters/banwords.py +++ b/pkg/pipeline/cntfilter/filters/banwords.py @@ -6,6 +6,7 @@ from .. import entities from ....config import manager as cfg_mgr +@filter_model.filter_class("BanWordFilter") class BanWordFilter(filter_model.ContentFilter): """根据内容禁言""" diff --git a/pkg/pipeline/cntfilter/filters/cntignore.py b/pkg/pipeline/cntfilter/filters/cntignore.py index 92fe94e8..baafeef0 100644 --- a/pkg/pipeline/cntfilter/filters/cntignore.py +++ b/pkg/pipeline/cntfilter/filters/cntignore.py @@ -5,6 +5,7 @@ from .. import entities from .. import filter as filter_model +@filter_model.filter_class("ContentIgnore") class ContentIgnore(filter_model.ContentFilter): """根据内容忽略消息""" diff --git a/pkg/platform/manager.py b/pkg/platform/manager.py index 3d73c198..7b40f2ab 100644 --- a/pkg/platform/manager.py +++ b/pkg/platform/manager.py @@ -163,25 +163,6 @@ class PlatformManager: quote_origin=True if self.ap.platform_cfg.data['quote-origin'] and check_quote else False ) - # 通知系统管理员 - # TODO delete - # async def notify_admin(self, message: str): - # await self.notify_admin_message_chain(MessageChain([Plain("[bot]{}".format(message))])) - - # async def notify_admin_message_chain(self, message: mirai.MessageChain): - # if self.ap.system_cfg.data['admin-sessions'] != []: - - # admin_list = [] - # for admin in self.ap.system_cfg.data['admin-sessions']: - # admin_list.append(admin) - - # for adm in admin_list: - # self.adapter.send_message( - # adm.split("_")[0], - # adm.split("_")[1], - # message - # ) - async def run(self): try: tasks = [] diff --git a/pkg/platform/sources/nakuru.py b/pkg/platform/sources/nakuru.py index 0a419a06..0b3b8c09 100644 --- a/pkg/platform/sources/nakuru.py +++ b/pkg/platform/sources/nakuru.py @@ -24,6 +24,8 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter): msg_list = message_chain.__root__ elif type(message_chain) is list: msg_list = message_chain + elif type(message_chain) is str: + msg_list = [mirai.Plain(message_chain)] else: raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain))) diff --git a/pkg/platform/sources/qqbotpy.py b/pkg/platform/sources/qqbotpy.py index 6d74d0ea..313249a0 100644 --- a/pkg/platform/sources/qqbotpy.py +++ b/pkg/platform/sources/qqbotpy.py @@ -89,6 +89,8 @@ class OfficialMessageConverter(adapter_model.MessageConverter): msg_list = message_chain.__root__ elif type(message_chain) is list: msg_list = message_chain + elif type(message_chain) is str: + msg_list = [mirai.Plain(text=message_chain)] else: raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain)))