diff --git a/pkg/pipeline/cntfilter/cntfilter.py b/pkg/pipeline/cntfilter/cntfilter.py index 5e2aa4d2..fee2cd3f 100644 --- a/pkg/pipeline/cntfilter/cntfilter.py +++ b/pkg/pipeline/cntfilter/cntfilter.py @@ -25,14 +25,14 @@ class ContentFilterStage(stage.PipelineStage): async def initialize(self): filters_required = [ - "ContentIgnore" + "content-filter" ] if self.ap.pipeline_cfg.data['check-sensitive-words']: - filters_required.append("BanWordFilter") + filters_required.append("ban-word-filter") if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']: - filters_required.append("BaiduCloudExamine") + filters_required.append("baidu-cloud-examine") for filter in filter_model.preregistered_filters: if filter.name in filters_required: diff --git a/pkg/pipeline/cntfilter/filters/baiduexamine.py b/pkg/pipeline/cntfilter/filters/baiduexamine.py index faa4bb6b..8c5b77cd 100644 --- a/pkg/pipeline/cntfilter/filters/baiduexamine.py +++ b/pkg/pipeline/cntfilter/filters/baiduexamine.py @@ -10,7 +10,7 @@ BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token" -@filter_model.filter_class("BaiduCloudExamine") +@filter_model.filter_class("baidu-cloud-examine") class BaiduCloudExamine(filter_model.ContentFilter): """百度云内容审核""" diff --git a/pkg/pipeline/cntfilter/filters/banwords.py b/pkg/pipeline/cntfilter/filters/banwords.py index c94374c8..9391971c 100644 --- a/pkg/pipeline/cntfilter/filters/banwords.py +++ b/pkg/pipeline/cntfilter/filters/banwords.py @@ -6,7 +6,7 @@ from .. import entities from ....config import manager as cfg_mgr -@filter_model.filter_class("BanWordFilter") +@filter_model.filter_class("ban-word-filter") class BanWordFilter(filter_model.ContentFilter): """根据内容禁言""" diff --git a/pkg/pipeline/cntfilter/filters/cntignore.py b/pkg/pipeline/cntfilter/filters/cntignore.py index baafeef0..781f6397 100644 --- a/pkg/pipeline/cntfilter/filters/cntignore.py +++ b/pkg/pipeline/cntfilter/filters/cntignore.py @@ -5,7 +5,7 @@ from .. import entities from .. import filter as filter_model -@filter_model.filter_class("ContentIgnore") +@filter_model.filter_class("content-ignore") class ContentIgnore(filter_model.ContentFilter): """根据内容忽略消息""" diff --git a/pkg/pipeline/longtext/longtext.py b/pkg/pipeline/longtext/longtext.py index 2962ae28..2095845d 100644 --- a/pkg/pipeline/longtext/longtext.py +++ b/pkg/pipeline/longtext/longtext.py @@ -45,11 +45,14 @@ class LongTextProcessStage(stage.PipelineStage): self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。".format(use_font)) self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward" - - if config['strategy'] == 'image': - self.strategy_impl = image.Text2ImageStrategy(self.ap) - elif config['strategy'] == 'forward': - self.strategy_impl = forward.ForwardComponentStrategy(self.ap) + + for strategy_cls in strategy.preregistered_strategies: + if strategy_cls.name == config['strategy']: + self.strategy_impl = strategy_cls(self.ap) + break + else: + raise ValueError(f"未找到名为 {config['strategy']} 的长消息处理策略") + await self.strategy_impl.initialize() async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: diff --git a/pkg/pipeline/longtext/strategies/forward.py b/pkg/pipeline/longtext/strategies/forward.py index cfab49d9..4a790313 100644 --- a/pkg/pipeline/longtext/strategies/forward.py +++ b/pkg/pipeline/longtext/strategies/forward.py @@ -36,6 +36,7 @@ class Forward(MessageComponent): return '[聊天记录]' +@strategy_model.strategy_class("forward") class ForwardComponentStrategy(strategy_model.LongTextStrategy): async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]: diff --git a/pkg/pipeline/longtext/strategies/image.py b/pkg/pipeline/longtext/strategies/image.py index af34f4e6..f96f03c5 100644 --- a/pkg/pipeline/longtext/strategies/image.py +++ b/pkg/pipeline/longtext/strategies/image.py @@ -15,6 +15,7 @@ from .. import strategy as strategy_model from ....core import entities as core_entities +@strategy_model.strategy_class("image") class Text2ImageStrategy(strategy_model.LongTextStrategy): text_render_font: ImageFont.FreeTypeFont diff --git a/pkg/pipeline/longtext/strategy.py b/pkg/pipeline/longtext/strategy.py index a1f8a94f..296c5b4c 100644 --- a/pkg/pipeline/longtext/strategy.py +++ b/pkg/pipeline/longtext/strategy.py @@ -9,7 +9,30 @@ from ...core import app from ...core import entities as core_entities +preregistered_strategies: list[typing.Type[LongTextStrategy]] = [] + + +def strategy_class( + name: str +) -> typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]: + def decorator(cls: typing.Type[LongTextStrategy]) -> typing.Type[LongTextStrategy]: + assert issubclass(cls, LongTextStrategy) + + cls.name = name + + preregistered_strategies.append(cls) + + return cls + + return decorator + + class LongTextStrategy(metaclass=abc.ABCMeta): + """长文本处理策略抽象类 + """ + + name: str + ap: app.Application def __init__(self, ap: app.Application):