From 8d084427d2b7c63fe62ff173d80ae7da8218a70e Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Fri, 26 Jan 2024 15:51:49 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E8=AF=B7=E6=B1=82=E5=A4=84?= =?UTF-8?q?=E7=90=86=E6=8E=A7=E5=88=B6=E6=B5=81=E5=9F=BA=E7=A1=80=E6=9E=B6?= =?UTF-8?q?=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/boot/log.py | 54 ----- pkg/{boot => core}/__init__.py | 0 pkg/{boot => core}/app.py | 19 +- pkg/{boot => core}/boot.py | 22 +- .../misc.py => core/bootutils/__init__.py} | 0 pkg/{boot => core/bootutils}/config.py | 4 +- pkg/{boot => core/bootutils}/deps.py | 0 pkg/{boot => core/bootutils}/files.py | 0 pkg/core/bootutils/log.py | 56 +++++ pkg/core/bootutils/misc.py | 0 pkg/core/controller.py | 84 ++++++++ pkg/core/entities.py | 41 ++++ pkg/core/pool.py | 52 +++++ pkg/openai/manager.py | 2 +- pkg/pipeline/__init__.py | 0 pkg/pipeline/bansess/__init__.py | 0 pkg/pipeline/bansess/bansess.py | 76 +++++++ pkg/pipeline/cntfilter/__init__.py | 0 pkg/pipeline/cntfilter/cntfilter.py | 128 ++++++++++++ pkg/pipeline/cntfilter/entities.py | 64 ++++++ pkg/pipeline/cntfilter/filter.py | 34 +++ pkg/pipeline/cntfilter/filters/__init__.py | 0 .../cntfilter/filters/baiduexamine.py | 61 ++++++ pkg/pipeline/cntfilter/filters/banwords.py | 44 ++++ pkg/pipeline/cntfilter/filters/cntignore.py | 43 ++++ pkg/pipeline/entities.py | 38 ++++ pkg/pipeline/longtext/__init__.py | 0 pkg/pipeline/longtext/longtext.py | 57 +++++ pkg/pipeline/longtext/strategies/__init__.py | 0 pkg/pipeline/longtext/strategies/forward.py | 62 ++++++ pkg/pipeline/longtext/strategies/image.py | 197 ++++++++++++++++++ pkg/pipeline/longtext/strategy.py | 22 ++ pkg/pipeline/resprule/__init__.py | 0 pkg/pipeline/resprule/entities.py | 9 + pkg/pipeline/resprule/resprule.py | 62 ++++++ pkg/pipeline/resprule/rule.py | 31 +++ pkg/pipeline/resprule/rules/__init__.py | 0 pkg/pipeline/resprule/rules/atbot.py | 28 +++ pkg/pipeline/resprule/rules/prefix.py | 29 +++ pkg/pipeline/resprule/rules/random.py | 22 ++ pkg/pipeline/resprule/rules/regexp.py | 31 +++ pkg/pipeline/stage.py | 43 ++++ pkg/pipeline/stagemgr.py | 47 +++++ pkg/qqbot/bansess/bansess.py | 2 +- pkg/qqbot/cntfilter/cntfilter.py | 2 +- pkg/qqbot/cntfilter/filter.py | 2 +- pkg/qqbot/longtext/longtext.py | 2 +- pkg/qqbot/longtext/strategy.py | 2 +- pkg/qqbot/manager.py | 92 ++------ pkg/qqbot/process.py | 2 +- pkg/qqbot/ratelim/algo.py | 2 +- pkg/qqbot/ratelim/ratelim.py | 2 +- pkg/qqbot/resprule/resprule.py | 2 +- pkg/qqbot/resprule/rule.py | 2 +- start.py | 2 +- 55 files changed, 1430 insertions(+), 146 deletions(-) delete mode 100644 pkg/boot/log.py rename pkg/{boot => core}/__init__.py (100%) rename pkg/{boot => core}/app.py (65%) rename pkg/{boot => core}/boot.py (88%) rename pkg/{boot/misc.py => core/bootutils/__init__.py} (100%) rename pkg/{boot => core/bootutils}/config.py (85%) rename pkg/{boot => core/bootutils}/deps.py (100%) rename pkg/{boot => core/bootutils}/files.py (100%) create mode 100644 pkg/core/bootutils/log.py create mode 100644 pkg/core/bootutils/misc.py create mode 100644 pkg/core/controller.py create mode 100644 pkg/core/entities.py create mode 100644 pkg/core/pool.py create mode 100644 pkg/pipeline/__init__.py create mode 100644 pkg/pipeline/bansess/__init__.py create mode 100644 pkg/pipeline/bansess/bansess.py create mode 100644 pkg/pipeline/cntfilter/__init__.py create mode 100644 pkg/pipeline/cntfilter/cntfilter.py create mode 100644 pkg/pipeline/cntfilter/entities.py create mode 100644 pkg/pipeline/cntfilter/filter.py create mode 100644 pkg/pipeline/cntfilter/filters/__init__.py create mode 100644 pkg/pipeline/cntfilter/filters/baiduexamine.py create mode 100644 pkg/pipeline/cntfilter/filters/banwords.py create mode 100644 pkg/pipeline/cntfilter/filters/cntignore.py create mode 100644 pkg/pipeline/entities.py create mode 100644 pkg/pipeline/longtext/__init__.py create mode 100644 pkg/pipeline/longtext/longtext.py create mode 100644 pkg/pipeline/longtext/strategies/__init__.py create mode 100644 pkg/pipeline/longtext/strategies/forward.py create mode 100644 pkg/pipeline/longtext/strategies/image.py create mode 100644 pkg/pipeline/longtext/strategy.py create mode 100644 pkg/pipeline/resprule/__init__.py create mode 100644 pkg/pipeline/resprule/entities.py create mode 100644 pkg/pipeline/resprule/resprule.py create mode 100644 pkg/pipeline/resprule/rule.py create mode 100644 pkg/pipeline/resprule/rules/__init__.py create mode 100644 pkg/pipeline/resprule/rules/atbot.py create mode 100644 pkg/pipeline/resprule/rules/prefix.py create mode 100644 pkg/pipeline/resprule/rules/random.py create mode 100644 pkg/pipeline/resprule/rules/regexp.py create mode 100644 pkg/pipeline/stage.py create mode 100644 pkg/pipeline/stagemgr.py diff --git a/pkg/boot/log.py b/pkg/boot/log.py deleted file mode 100644 index e0a15daa..00000000 --- a/pkg/boot/log.py +++ /dev/null @@ -1,54 +0,0 @@ -import logging -import os -import sys -import time - -import colorlog - - -log_colors_config = { - 'DEBUG': 'green', # cyan white - 'INFO': 'white', - 'WARNING': 'yellow', - 'ERROR': 'red', - 'CRITICAL': 'cyan', -} - - -async def init_logging() -> logging.Logger: - - level = logging.INFO - - if 'DEBUG' in os.environ and os.environ['DEBUG'] in ['true', '1']: - level = logging.DEBUG - - log_file_name = "logs/qcg-%s.log" % time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) - - qcg_logger = logging.getLogger("qcg") - - qcg_logger.setLevel(level) - - log_handlers: logging.Handler = [ - logging.StreamHandler(sys.stdout), - logging.FileHandler(log_file_name) - ] - - for handler in log_handlers: - handler.setLevel(level) - handler.setFormatter( - colorlog.ColoredFormatter( - fmt="[%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - log_colors=log_colors_config - ) - ) - qcg_logger.addHandler(handler) - - logging.basicConfig(level=level, # 设置日志输出格式 - format="[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s", - # 日志输出的格式 - # -8表示占位符,让输出左对齐,输出长度都为8位 - datefmt="%Y-%m-%d %H:%M:%S" # 时间输出的格式 - ) - - return qcg_logger \ No newline at end of file diff --git a/pkg/boot/__init__.py b/pkg/core/__init__.py similarity index 100% rename from pkg/boot/__init__.py rename to pkg/core/__init__.py diff --git a/pkg/boot/app.py b/pkg/core/app.py similarity index 65% rename from pkg/boot/app.py rename to pkg/core/app.py index df9e92b9..8c0a0c58 100644 --- a/pkg/boot/app.py +++ b/pkg/core/app.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import asyncio from ..qqbot import manager as qqbot_mgr from ..openai import manager as openai_mgr @@ -8,6 +9,8 @@ from ..config import manager as config_mgr from ..database import manager as database_mgr from ..utils.center import v2 as center_mgr from ..plugin import host as plugin_host +from . import pool, controller +from ..pipeline import stagemgr class Application: @@ -23,16 +26,24 @@ class Application: ctr_mgr: center_mgr.V2CenterAPI = None + query_pool: pool.QueryPool = None + + ctrl: controller.Controller = None + + stage_mgr: stagemgr.StageManager = None + logger: logging.Logger = None def __init__(self): pass - async def initialize(self): - await self.im_mgr.initialize() - async def run(self): # TODO make it async plugin_host.initialize_plugins() - await self.im_mgr.run() \ No newline at end of file + tasks = [ + asyncio.create_task(self.im_mgr.run()), + asyncio.create_task(self.ctrl.run()) + ] + + await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) diff --git a/pkg/boot/boot.py b/pkg/core/boot.py similarity index 88% rename from pkg/boot/boot.py rename to pkg/core/boot.py index b7cc0f38..10fc51b3 100644 --- a/pkg/boot/boot.py +++ b/pkg/core/boot.py @@ -3,12 +3,15 @@ from __future__ import print_function import os import sys -from . import files -from . import deps -from . import log -from . import config +from .bootutils import files +from .bootutils import deps +from .bootutils import log +from .bootutils import config from . import app +from . import pool +from . import controller +from ..pipeline import stagemgr from ..audit import identifier from ..database import manager as db_mgr from ..openai import manager as llm_mgr @@ -86,6 +89,8 @@ async def make_app() -> app.Application: ap.cfg_mgr = cfg_mgr ap.tips_mgr = tips_mgr + ap.query_pool = pool.QueryPool() + center_v2_api = center_v2.V2CenterAPI( basic_info={ "host_id": identifier.identifier['host_id'], @@ -111,8 +116,16 @@ async def make_app() -> app.Application: llm_session.load_sessions() im_mgr_inst = im_mgr.QQBotManager(first_time_init=True, ap=ap) + await im_mgr_inst.initialize() ap.im_mgr = im_mgr_inst + stage_mgr = stagemgr.StageManager(ap) + await stage_mgr.initialize() + ap.stage_mgr = stage_mgr + + ctrl = controller.Controller(ap) + ap.ctrl = ctrl + # TODO make it async plugin_host.load_plugins() # plugin_host.initialize_plugins() @@ -122,5 +135,4 @@ async def make_app() -> app.Application: async def main(): app_inst = await make_app() - await app_inst.initialize() await app_inst.run() diff --git a/pkg/boot/misc.py b/pkg/core/bootutils/__init__.py similarity index 100% rename from pkg/boot/misc.py rename to pkg/core/bootutils/__init__.py diff --git a/pkg/boot/config.py b/pkg/core/bootutils/config.py similarity index 85% rename from pkg/boot/config.py rename to pkg/core/bootutils/config.py index 3f796214..f1471ae5 100644 --- a/pkg/boot/config.py +++ b/pkg/core/bootutils/config.py @@ -1,7 +1,7 @@ import json -from ..config import manager as config_mgr -from ..config.impls import pymodule +from ...config import manager as config_mgr +from ...config.impls import pymodule load_python_module_config = config_mgr.load_python_module_config diff --git a/pkg/boot/deps.py b/pkg/core/bootutils/deps.py similarity index 100% rename from pkg/boot/deps.py rename to pkg/core/bootutils/deps.py diff --git a/pkg/boot/files.py b/pkg/core/bootutils/files.py similarity index 100% rename from pkg/boot/files.py rename to pkg/core/bootutils/files.py diff --git a/pkg/core/bootutils/log.py b/pkg/core/bootutils/log.py new file mode 100644 index 00000000..4bc0e4de --- /dev/null +++ b/pkg/core/bootutils/log.py @@ -0,0 +1,56 @@ +import logging +import os +import sys +import time + +import colorlog + + +log_colors_config = { + "DEBUG": "green", # cyan white + "INFO": "white", + "WARNING": "yellow", + "ERROR": "red", + "CRITICAL": "cyan", +} + + +async def init_logging() -> logging.Logger: + level = logging.INFO + + if "DEBUG" in os.environ and os.environ["DEBUG"] in ["true", "1"]: + level = logging.DEBUG + + log_file_name = "logs/qcg-%s.log" % time.strftime( + "%Y-%m-%d-%H-%M-%S", time.localtime() + ) + + qcg_logger = logging.getLogger("qcg") + + qcg_logger.setLevel(level) + + color_formatter = colorlog.ColoredFormatter( + fmt="%(log_color)s[%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + log_colors=log_colors_config, + ) + + stream_handler = logging.StreamHandler(sys.stdout) + + log_handlers: logging.Handler = [stream_handler, logging.FileHandler(log_file_name)] + + for handler in log_handlers: + handler.setLevel(level) + handler.setFormatter(color_formatter) + qcg_logger.addHandler(handler) + + logging.basicConfig( + level=logging.INFO, # 设置日志输出格式 + format="[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s", + # 日志输出的格式 + # -8表示占位符,让输出左对齐,输出长度都为8位 + datefmt="%Y-%m-%d %H:%M:%S", # 时间输出的格式 + handlers=[logging.NullHandler()], + ) + + return qcg_logger diff --git a/pkg/core/bootutils/misc.py b/pkg/core/bootutils/misc.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/core/controller.py b/pkg/core/controller.py new file mode 100644 index 00000000..2470cbbd --- /dev/null +++ b/pkg/core/controller.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import asyncio +import traceback + +from . import app, entities +from ..pipeline import entities as pipeline_entities + +DEFAULT_QUERY_CONCURRENCY = 10 + + +class Controller: + """总控制器 + """ + ap: app.Application + + semaphore: asyncio.Semaphore = None + """请求并发控制信号量""" + + def __init__(self, ap: app.Application): + self.ap = ap + self.semaphore = asyncio.Semaphore(DEFAULT_QUERY_CONCURRENCY) + + async def consumer(self): + """事件处理循环 + """ + while True: + selected_query: entities.Query = None + + # 取请求 + async with self.ap.query_pool: + queries: list[entities.Query] = self.ap.query_pool.queries + + if queries: + selected_query = queries.pop(0) # FCFS + else: + await self.ap.query_pool.condition.wait() + continue + + if selected_query: + async def _process_query(selected_query): + async with self.semaphore: + await self.process_query(selected_query) + + asyncio.create_task(_process_query(selected_query)) + + async def process_query(self, query: entities.Query): + """处理请求 + """ + self.ap.logger.debug(f"Processing query {query}") + + try: + for stage_container in self.ap.stage_mgr.stage_containers: + res = await stage_container.inst.process(query, stage_container.inst_name) + + self.ap.logger.debug(f"Stage {stage_container.inst_name} res {res}") + + if res.user_notice: + await self.ap.im_mgr.send( + query.message_event, + res.user_notice + ) + if res.debug_notice: + self.ap.logger.debug(res.debug_notice) + if res.console_notice: + self.ap.logger.info(res.console_notice) + + if res.result_type == pipeline_entities.ResultType.INTERRUPT: + self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}") + break + elif res.result_type == pipeline_entities.ResultType.CONTINUE: + query = res.new_query + continue + + except Exception as e: + self.ap.logger.error(f"处理请求时出错 {query}: {e}") + traceback.print_exc() + finally: + self.ap.logger.debug(f"Query {query} processed") + + async def run(self): + """运行控制器 + """ + await self.consumer() diff --git a/pkg/core/entities.py b/pkg/core/entities.py new file mode 100644 index 00000000..505112ff --- /dev/null +++ b/pkg/core/entities.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import enum +import typing + +import pydantic +import mirai + + +class LauncherTypes(enum.Enum): + + PERSON = 'person' + """私聊""" + + GROUP = 'group' + """群聊""" + + +class Query(pydantic.BaseModel): + """一次请求的信息封装""" + + query_id: int + """请求ID""" + + launcher_type: LauncherTypes + """会话类型""" + + launcher_id: int + """会话ID""" + + sender_id: int + """发送者ID""" + + message_event: mirai.MessageEvent + """事件""" + + message_chain: mirai.MessageChain + """消息链""" + + resp_message_chain: typing.Optional[mirai.MessageChain] = None + """回复消息链""" diff --git a/pkg/core/pool.py b/pkg/core/pool.py new file mode 100644 index 00000000..3d949292 --- /dev/null +++ b/pkg/core/pool.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import asyncio + +import mirai + +from . import entities + + +class QueryPool: + + query_id_counter: int = 0 + + pool_lock: asyncio.Lock + + queries: list[entities.Query] + + condition: asyncio.Condition + + def __init__(self): + self.query_id_counter = 0 + self.pool_lock = asyncio.Lock() + self.queries = [] + self.condition = asyncio.Condition(self.pool_lock) + + async def add_query( + self, + launcher_type: entities.LauncherTypes, + launcher_id: int, + sender_id: int, + message_event: mirai.MessageEvent, + message_chain: mirai.MessageChain + ) -> entities.Query: + async with self.condition: + query = entities.Query( + query_id=self.query_id_counter, + launcher_type=launcher_type, + launcher_id=launcher_id, + sender_id=sender_id, + message_event=message_event, + message_chain=message_chain + ) + self.queries.append(query) + self.query_id_counter += 1 + self.condition.notify_all() + + async def __aenter__(self): + await self.pool_lock.acquire() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.pool_lock.release() diff --git a/pkg/openai/manager.py b/pkg/openai/manager.py index 3fd53be6..e070a29f 100644 --- a/pkg/openai/manager.py +++ b/pkg/openai/manager.py @@ -10,7 +10,7 @@ from ..utils import context from ..audit import gatherer from ..openai import modelmgr from ..openai.api import model as api_model -from ..boot import app +from ..core import app class OpenAIInteract: diff --git a/pkg/pipeline/__init__.py b/pkg/pipeline/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/bansess/__init__.py b/pkg/pipeline/bansess/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/bansess/bansess.py b/pkg/pipeline/bansess/bansess.py new file mode 100644 index 00000000..a0f63c36 --- /dev/null +++ b/pkg/pipeline/bansess/bansess.py @@ -0,0 +1,76 @@ +from __future__ import annotations +import re + +from .. import stage, entities, stagemgr +from ...core import entities as core_entities +from ...config import manager as cfg_mgr + + +@stage.stage_class('BanSessionCheckStage') +class BanSessionCheckStage(stage.PipelineStage): + + banlist_mgr: cfg_mgr.ConfigManager + + async def initialize(self): + self.banlist_mgr = await cfg_mgr.load_python_module_config( + "banlist.py", + "res/templates/banlist-template.py" + ) + + async def process( + self, + query: core_entities.Query, + stage_inst_name: str + ) -> entities.StageProcessResult: + + if not self.banlist_mgr.data['enable']: + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + + result = False + + if query.launcher_type == 'group': + if not self.banlist_mgr.data['enable_group']: # 未启用群聊响应 + result = True + # 检查是否显式声明发起人QQ要被person忽略 + elif query.sender_id in self.banlist_mgr.data['person']: + result = True + else: + for group_rule in self.banlist_mgr.data['group']: + if type(group_rule) == int: + if group_rule == query.launcher_id: + result = True + elif type(group_rule) == str: + if group_rule.startswith('!'): + reg_str = group_rule[1:] + if re.match(reg_str, str(query.launcher_id)): + result = False + break + else: + if re.match(group_rule, str(query.launcher_id)): + result = True + elif query.launcher_type == 'person': + if not self.banlist_mgr.data['enable_private']: + result = True + else: + for person_rule in self.banlist_mgr.data['person']: + if type(person_rule) == int: + if person_rule == query.launcher_id: + result = True + elif type(person_rule) == str: + if person_rule.startswith('!'): + reg_str = person_rule[1:] + if re.match(reg_str, str(query.launcher_id)): + result = False + break + else: + if re.match(person_rule, str(query.launcher_id)): + result = True + + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE if not result else entities.ResultType.INTERRUPT, + new_query=query, + debug_notice=f'根据禁用列表忽略消息: {query.launcher_type}_{query.launcher_id}' if result else '' + ) diff --git a/pkg/pipeline/cntfilter/__init__.py b/pkg/pipeline/cntfilter/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/cntfilter/cntfilter.py b/pkg/pipeline/cntfilter/cntfilter.py new file mode 100644 index 00000000..0025b00a --- /dev/null +++ b/pkg/pipeline/cntfilter/cntfilter.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +import mirai + +from ...core import app + +from .. import stage, entities, stagemgr +from ...core import entities as core_entities +from ...config import manager as cfg_mgr +from . import filter, entities as filter_entities +from .filters import cntignore, banwords, baiduexamine + + +@stage.stage_class('PostContentFilterStage') +@stage.stage_class('PreContentFilterStage') +class ContentFilterStage(stage.PipelineStage): + + filter_chain: list[filter.ContentFilter] + + def __init__(self, ap: app.Application): + self.filter_chain = [] + super().__init__(ap) + + async def initialize(self): + self.filter_chain.append(cntignore.ContentIgnore(self.ap)) + + if self.ap.cfg_mgr.data['sensitive_word_filter']: + self.filter_chain.append(banwords.BanWordFilter(self.ap)) + + if self.ap.cfg_mgr.data['baidu_check']: + self.filter_chain.append(baiduexamine.BaiduCloudExamine(self.ap)) + + for filter in self.filter_chain: + await filter.initialize() + + async def _pre_process( + self, + message: str, + query: core_entities.Query, + ) -> entities.StageProcessResult: + """请求llm前处理消息 + 只要有一个不通过就不放行,只放行 PASS 的消息 + """ + if not self.ap.cfg_mgr.data['income_msg_check']: + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + else: + for filter in self.filter_chain: + if filter_entities.EnableStage.PRE in filter.enable_stages: + result = await filter.process(message) + + if result.level in [ + filter_entities.ResultLevel.BLOCK, + filter_entities.ResultLevel.MASKED + ]: + return entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query, + user_notice=result.user_notice, + console_notice=result.console_notice + ) + elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个 + message = result.replacement + + query.message_chain = mirai.MessageChain( + mirai.Plain(message) + ) + + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + + async def _post_process( + self, + message: str, + query: core_entities.Query, + ) -> entities.StageProcessResult: + """请求llm后处理响应 + 只要是 PASS 或者 MASKED 的就通过此 filter,将其 replacement 设置为message,进入下一个 filter + """ + for filter in self.filter_chain: + if filter_entities.EnableStage.POST in filter.enable_stages: + result = await filter.process(message) + + if result.level == filter_entities.ResultLevel.BLOCK: + return entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query, + user_notice=result.user_notice, + console_notice=result.console_notice + ) + elif result.level in [ + filter_entities.ResultLevel.PASS, + filter_entities.ResultLevel.MASKED + ]: + message = result.replacement + + query.message_chain = mirai.MessageChain( + mirai.Plain(message) + ) + + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + + async def process( + self, + query: core_entities.Query, + stage_inst_name: str + ) -> entities.StageProcessResult: + """处理 + """ + if stage_inst_name == 'PreContentFilterStage': + return await self._pre_process( + str(query.message_chain).strip(), + query + ) + elif stage_inst_name == 'PostContentFilterStage': + return await self._post_process( + str(query.message_chain).strip(), + query + ) + else: + raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}') diff --git a/pkg/pipeline/cntfilter/entities.py b/pkg/pipeline/cntfilter/entities.py new file mode 100644 index 00000000..7ab05675 --- /dev/null +++ b/pkg/pipeline/cntfilter/entities.py @@ -0,0 +1,64 @@ + +import typing +import enum + +import pydantic + + +class ResultLevel(enum.Enum): + """结果等级""" + PASS = enum.auto() + """通过""" + + WARN = enum.auto() + """警告""" + + MASKED = enum.auto() + """已掩去""" + + BLOCK = enum.auto() + """阻止""" + + +class EnableStage(enum.Enum): + """启用阶段""" + PRE = enum.auto() + """预处理""" + + POST = enum.auto() + """后处理""" + + +class FilterResult(pydantic.BaseModel): + level: ResultLevel + + replacement: str + """替换后的消息""" + + user_notice: str + """不通过时,用户提示消息""" + + console_notice: str + """不通过时,控制台提示消息""" + + +class ManagerResultLevel(enum.Enum): + """处理器结果等级""" + CONTINUE = enum.auto() + """继续""" + + INTERRUPT = enum.auto() + """中断""" + +class FilterManagerResult(pydantic.BaseModel): + + level: ManagerResultLevel + + replacement: str + """替换后的消息""" + + user_notice: str + """用户提示消息""" + + console_notice: str + """控制台提示消息""" diff --git a/pkg/pipeline/cntfilter/filter.py b/pkg/pipeline/cntfilter/filter.py new file mode 100644 index 00000000..57792145 --- /dev/null +++ b/pkg/pipeline/cntfilter/filter.py @@ -0,0 +1,34 @@ +# 内容过滤器的抽象类 +from __future__ import annotations +import abc + +from ...core import app +from . import entities + + +class ContentFilter(metaclass=abc.ABCMeta): + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + @property + def enable_stages(self): + """启用的阶段 + """ + return [ + entities.EnableStage.PRE, + entities.EnableStage.POST + ] + + async def initialize(self): + """初始化过滤器 + """ + pass + + @abc.abstractmethod + async def process(self, message: str) -> entities.FilterResult: + """处理消息 + """ + raise NotImplementedError diff --git a/pkg/pipeline/cntfilter/filters/__init__.py b/pkg/pipeline/cntfilter/filters/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/cntfilter/filters/baiduexamine.py b/pkg/pipeline/cntfilter/filters/baiduexamine.py new file mode 100644 index 00000000..a658897b --- /dev/null +++ b/pkg/pipeline/cntfilter/filters/baiduexamine.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +import aiohttp + +from .. import entities +from .. import filter as filter_model + + +BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}" +BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token" + + +class BaiduCloudExamine(filter_model.ContentFilter): + """百度云内容审核""" + + async def _get_token(self) -> str: + async with aiohttp.ClientSession() as session: + async with session.post( + BAIDU_EXAMINE_TOKEN_URL, + params={ + "grant_type": "client_credentials", + "client_id": self.ap.cfg_mgr.data['baidu_api_key'], + "client_secret": self.ap.cfg_mgr.data['baidu_secret_key'] + } + ) as resp: + return (await resp.json())['access_token'] + + async def process(self, message: str) -> entities.FilterResult: + + async with aiohttp.ClientSession() as session: + async with session.post( + BAIDU_EXAMINE_URL.format(await self._get_token()), + headers={'Content-Type': 'application/x-www-form-urlencoded', 'Accept': 'application/json'}, + data=f"text={message}".encode('utf-8') + ) as resp: + result = await resp.json() + + if "error_code" in result: + return entities.FilterResult( + level=entities.ResultLevel.BLOCK, + replacement=message, + user_notice='', + console_notice=f"百度云判定出错,错误信息:{result['error_msg']}" + ) + else: + conclusion = result["conclusion"] + + if conclusion in ("合规"): + return entities.FilterResult( + level=entities.ResultLevel.PASS, + replacement=message, + user_notice='', + console_notice=f"百度云判定结果:{conclusion}" + ) + else: + return entities.FilterResult( + level=entities.ResultLevel.BLOCK, + replacement=message, + user_notice=self.ap.cfg_mgr.data['inappropriate_message_tips'], + console_notice=f"百度云判定结果:{conclusion}" + ) \ No newline at end of file diff --git a/pkg/pipeline/cntfilter/filters/banwords.py b/pkg/pipeline/cntfilter/filters/banwords.py new file mode 100644 index 00000000..9451c7b8 --- /dev/null +++ b/pkg/pipeline/cntfilter/filters/banwords.py @@ -0,0 +1,44 @@ +from __future__ import annotations +import re + +from .. import filter as filter_model +from .. import entities +from ....config import manager as cfg_mgr + + +class BanWordFilter(filter_model.ContentFilter): + """根据内容禁言""" + + sensitive: cfg_mgr.ConfigManager + + async def initialize(self): + self.sensitive = await cfg_mgr.load_json_config( + "sensitive.json", + "res/templates/sensitive-template.json" + ) + + async def process(self, message: str) -> entities.FilterResult: + found = False + + for word in self.sensitive.data['words']: + match = re.findall(word, message) + + if len(match) > 0: + found = True + + for i in range(len(match)): + if self.sensitive.data['mask_word'] == "": + message = message.replace( + match[i], self.sensitive.data['mask'] * len(match[i]) + ) + else: + message = message.replace( + match[i], self.sensitive.data['mask_word'] + ) + + return entities.FilterResult( + level=entities.ResultLevel.MASKED if found else entities.ResultLevel.PASS, + replacement=message, + user_notice='[bot] 消息中存在不合适的内容, 请更换措辞' if found else '', + console_notice='' + ) \ No newline at end of file diff --git a/pkg/pipeline/cntfilter/filters/cntignore.py b/pkg/pipeline/cntfilter/filters/cntignore.py new file mode 100644 index 00000000..81408868 --- /dev/null +++ b/pkg/pipeline/cntfilter/filters/cntignore.py @@ -0,0 +1,43 @@ +from __future__ import annotations +import re + +from .. import entities +from .. import filter as filter_model + + +class ContentIgnore(filter_model.ContentFilter): + """根据内容忽略消息""" + + @property + def enable_stages(self): + return [ + entities.EnableStage.PRE, + ] + + async def process(self, message: str) -> entities.FilterResult: + if 'prefix' in self.ap.cfg_mgr.data['ignore_rules']: + for rule in self.ap.cfg_mgr.data['ignore_rules']['prefix']: + if message.startswith(rule): + return entities.FilterResult( + level=entities.ResultLevel.BLOCK, + replacement='', + user_notice='', + console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息' + ) + + if 'regexp' in self.ap.cfg_mgr.data['ignore_rules']: + for rule in self.ap.cfg_mgr.data['ignore_rules']['regexp']: + if re.search(rule, message): + return entities.FilterResult( + level=entities.ResultLevel.BLOCK, + replacement='', + user_notice='', + console_notice='根据 ignore_rules 中的 regexp 规则,忽略消息' + ) + + return entities.FilterResult( + level=entities.ResultLevel.PASS, + replacement=message, + user_notice='', + console_notice='' + ) \ No newline at end of file diff --git a/pkg/pipeline/entities.py b/pkg/pipeline/entities.py new file mode 100644 index 00000000..e687c082 --- /dev/null +++ b/pkg/pipeline/entities.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import enum +import typing + +import pydantic +import mirai +import mirai.models.message as mirai_message + +from ..core import entities + + +class ResultType(enum.Enum): + + CONTINUE = enum.auto() + """继续流水线""" + + INTERRUPT = enum.auto() + """中断流水线""" + + +class StageProcessResult(pydantic.BaseModel): + + result_type: ResultType + + new_query: entities.Query + + user_notice: typing.Optional[typing.Union[str, list[mirai_message.MessageComponent], mirai.MessageChain, None]] = [] + """只要设置了就会发送给用户""" + + admin_notice: typing.Optional[typing.Union[str, list[mirai_message.MessageComponent], mirai.MessageChain, None]] = [] + """只要设置了就会发送给管理员""" + + console_notice: typing.Optional[str] = '' + """只要设置了就会输出到控制台""" + + debug_notice: typing.Optional[str] = '' + diff --git a/pkg/pipeline/longtext/__init__.py b/pkg/pipeline/longtext/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/longtext/longtext.py b/pkg/pipeline/longtext/longtext.py new file mode 100644 index 00000000..11144891 --- /dev/null +++ b/pkg/pipeline/longtext/longtext.py @@ -0,0 +1,57 @@ +from __future__ import annotations +import os +import traceback + +from PIL import Image, ImageDraw, ImageFont +from mirai.models.message import MessageComponent, Plain, MessageChain + +from ...core import app +from . import strategy +from .strategies import image, forward +from .. import stage, entities, stagemgr +from ...core import entities as core_entities +from ...config import manager as cfg_mgr + + +@stage.stage_class("LongTextProcessStage") +class LongTextProcessStage(stage.PipelineStage): + + strategy_impl: strategy.LongTextStrategy + + async def initialize(self): + config = self.ap.cfg_mgr.data + if self.ap.cfg_mgr.data['blob_message_strategy'] == 'image': + use_font = config['font_path'] + try: + # 检查是否存在 + if not os.path.exists(use_font): + # 若是windows系统,使用微软雅黑 + if os.name == "nt": + use_font = "C:/Windows/Fonts/msyh.ttc" + if not os.path.exists(use_font): + self.ap.logger.warn("未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。") + config['blob_message_strategy'] = "forward" + else: + self.ap.logger.info("使用Windows自带字体:" + use_font) + self.ap.cfg_mgr.data['font_path'] = use_font + else: + self.ap.logger.warn("未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。") + self.ap.cfg_mgr.data['blob_message_strategy'] = "forward" + except: + traceback.print_exc() + self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。".format(use_font)) + self.ap.cfg_mgr.data['blob_message_strategy'] = "forward" + + if self.ap.cfg_mgr.data['blob_message_strategy'] == 'image': + self.strategy_impl = image.Text2ImageStrategy(self.ap) + elif self.ap.cfg_mgr.data['blob_message_strategy'] == 'forward': + self.strategy_impl = forward.ForwardComponentStrategy(self.ap) + await self.strategy_impl.initialize() + + async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: + if len(str(query.resp_message_chain)) > self.ap.cfg_mgr.data['blob_message_threshold']: + query.message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain))) + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) \ No newline at end of file diff --git a/pkg/pipeline/longtext/strategies/__init__.py b/pkg/pipeline/longtext/strategies/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/longtext/strategies/forward.py b/pkg/pipeline/longtext/strategies/forward.py new file mode 100644 index 00000000..d1b5c36c --- /dev/null +++ b/pkg/pipeline/longtext/strategies/forward.py @@ -0,0 +1,62 @@ +# 转发消息组件 +from __future__ import annotations +import typing + +from mirai.models import MessageChain +from mirai.models.message import MessageComponent, ForwardMessageNode +from mirai.models.base import MiraiBaseModel + +from .. import strategy as strategy_model + + +class ForwardMessageDiaplay(MiraiBaseModel): + title: str = "群聊的聊天记录" + brief: str = "[聊天记录]" + source: str = "聊天记录" + preview: typing.List[str] = [] + summary: str = "查看x条转发消息" + + +class Forward(MessageComponent): + """合并转发。""" + type: str = "Forward" + """消息组件类型。""" + display: ForwardMessageDiaplay + """显示信息""" + node_list: typing.List[ForwardMessageNode] + """转发消息节点列表。""" + def __init__(self, *args, **kwargs): + if len(args) == 1: + self.node_list = args[0] + super().__init__(**kwargs) + super().__init__(*args, **kwargs) + + def __str__(self): + return '[聊天记录]' + + +class ForwardComponentStrategy(strategy_model.LongTextStrategy): + + async def process(self, message: str) -> list[MessageComponent]: + display = ForwardMessageDiaplay( + title="群聊的聊天记录", + brief="[聊天记录]", + source="聊天记录", + preview=["QQ用户: "+message], + summary="查看1条转发消息" + ) + + node_list = [ + ForwardMessageNode( + sender_id=self.ap.im_mgr.bot_account_id, + sender_name='QQ用户', + message_chain=MessageChain([message]) + ) + ] + + forward = Forward( + display=display, + node_list=node_list + ) + + return [forward] diff --git a/pkg/pipeline/longtext/strategies/image.py b/pkg/pipeline/longtext/strategies/image.py new file mode 100644 index 00000000..4f789098 --- /dev/null +++ b/pkg/pipeline/longtext/strategies/image.py @@ -0,0 +1,197 @@ +from __future__ import annotations + +import typing +import os +import base64 +import time +import re + +from PIL import Image, ImageDraw, ImageFont + +from mirai.models import MessageChain, Image as ImageComponent +from mirai.models.message import MessageComponent + +from .. import strategy as strategy_model + + +class Text2ImageStrategy(strategy_model.LongTextStrategy): + + text_render_font: ImageFont.FreeTypeFont + + async def initialize(self): + self.text_render_font = ImageFont.truetype(self.ap.cfg_mgr.data['font_path'], 32, encoding="utf-8") + + async def process(self, message: str) -> list[MessageComponent]: + img_path = self.text_to_image( + text_str=message, + save_as='temp/{}.png'.format(int(time.time())) + ) + + compressed_path, size = self.compress_image( + img_path, + outfile="temp/{}_compressed.png".format(int(time.time())) + ) + + with open(compressed_path, 'rb') as f: + img = f.read() + + b64 = base64.b64encode(img) + + # 删除图片 + os.remove(img_path) + + if os.path.exists(compressed_path): + os.remove(compressed_path) + + return [ + ImageComponent( + base64=b64.decode('utf-8'), + ) + ] + + def indexNumber(self, path=''): + """ + 查找字符串中数字所在串中的位置 + :param path:目标字符串 + :return:: : [['1', 16], ['2', 35], ['1', 51]] + """ + kv = [] + nums = [] + beforeDatas = re.findall('[\d]+', path) + for num in beforeDatas: + indexV = [] + times = path.count(num) + if times > 1: + if num not in nums: + indexs = re.finditer(num, path) + for index in indexs: + iV = [] + i = index.span()[0] + iV.append(num) + iV.append(i) + kv.append(iV) + nums.append(num) + else: + index = path.find(num) + indexV.append(num) + indexV.append(index) + kv.append(indexV) + # 根据数字位置排序 + indexSort = [] + resultIndex = [] + for vi in kv: + indexSort.append(vi[1]) + indexSort.sort() + for i in indexSort: + for v in kv: + if i == v[1]: + resultIndex.append(v) + return resultIndex + + + def get_size(self, file): + # 获取文件大小:KB + size = os.path.getsize(file) + return size / 1024 + + + def get_outfile(self, infile, outfile): + if outfile: + return outfile + dir, suffix = os.path.splitext(infile) + outfile = '{}-out{}'.format(dir, suffix) + return outfile + + + def compress_image(self, infile, outfile='', kb=100, step=20, quality=90): + """不改变图片尺寸压缩到指定大小 + :param infile: 压缩源文件 + :param outfile: 压缩文件保存地址 + :param mb: 压缩目标,KB + :param step: 每次调整的压缩比率 + :param quality: 初始压缩比率 + :return: 压缩文件地址,压缩文件大小 + """ + o_size = self.get_size(infile) + if o_size <= kb: + return infile, o_size + outfile = self.get_outfile(infile, outfile) + while o_size > kb: + im = Image.open(infile) + im.save(outfile, quality=quality) + if quality - step < 0: + break + quality -= step + o_size = self.get_size(outfile) + return outfile, self.get_size(outfile) + + + def text_to_image(self, text_str: str, save_as="temp.png", width=800): + + text_str = text_str.replace("\t", " ") + + # 分行 + lines = text_str.split('\n') + + # 计算并分割 + final_lines = [] + + text_width = width-80 + + self.ap.logger.debug("lines: {}, text_width: {}".format(lines, text_width)) + for line in lines: + # 如果长了就分割 + line_width = self.text_render_font.getlength(line) + self.ap.logger.debug("line_width: {}".format(line_width)) + if line_width < text_width: + final_lines.append(line) + continue + else: + rest_text = line + while True: + # 分割最前面的一行 + point = int(len(rest_text) * (text_width / line_width)) + + # 检查断点是否在数字中间 + numbers = self.indexNumber(rest_text) + + for number in numbers: + if number[1] < point < number[1] + len(number[0]) and number[1] != 0: + point = number[1] + break + + final_lines.append(rest_text[:point]) + rest_text = rest_text[point:] + line_width = self.text_render_font.getlength(rest_text) + if line_width < text_width: + final_lines.append(rest_text) + break + else: + continue + # 准备画布 + img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255)) + draw = ImageDraw.Draw(img, mode='RGBA') + + self.ap.logger.debug("正在绘制图片...") + # 绘制正文 + line_number = 0 + offset_x = 20 + offset_y = 30 + for final_line in final_lines: + draw.text((offset_x, offset_y + 35 * line_number), final_line, fill=(0, 0, 0), font=self.text_render_font) + # 遍历此行,检查是否有emoji + idx_in_line = 0 + for ch in final_line: + # 检查字符占位宽 + char_code = ord(ch) + if char_code >= 127: + idx_in_line += 1 + else: + idx_in_line += 0.5 + + line_number += 1 + + self.ap.logger.debug("正在保存图片...") + img.save(save_as) + + return save_as diff --git a/pkg/pipeline/longtext/strategy.py b/pkg/pipeline/longtext/strategy.py new file mode 100644 index 00000000..5c6bfb9c --- /dev/null +++ b/pkg/pipeline/longtext/strategy.py @@ -0,0 +1,22 @@ +from __future__ import annotations +import abc +import typing + +import mirai +from mirai.models.message import MessageComponent + +from ...core import app + + +class LongTextStrategy(metaclass=abc.ABCMeta): + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + pass + + @abc.abstractmethod + async def process(self, message: str) -> list[MessageComponent]: + return [] diff --git a/pkg/pipeline/resprule/__init__.py b/pkg/pipeline/resprule/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/resprule/entities.py b/pkg/pipeline/resprule/entities.py new file mode 100644 index 00000000..ffee3081 --- /dev/null +++ b/pkg/pipeline/resprule/entities.py @@ -0,0 +1,9 @@ +import pydantic +import mirai + + +class RuleJudgeResult(pydantic.BaseModel): + + matching: bool = False + + replacement: mirai.MessageChain = None diff --git a/pkg/pipeline/resprule/resprule.py b/pkg/pipeline/resprule/resprule.py new file mode 100644 index 00000000..6335a7d4 --- /dev/null +++ b/pkg/pipeline/resprule/resprule.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import mirai + +from ...core import app +from . import entities as rule_entities, rule +from .rules import atbot, prefix, regexp, random + +from .. import stage, entities, stagemgr +from ...core import entities as core_entities +from ...config import manager as cfg_mgr + + +@stage.stage_class("GroupRespondRuleCheckStage") +class GroupRespondRuleCheckStage(stage.PipelineStage): + """群组响应规则检查器 + """ + + rule_matchers: list[rule.GroupRespondRule] + + async def initialize(self): + """初始化检查器 + """ + self.rule_matchers = [ + atbot.AtBotRule(self.ap), + prefix.PrefixRule(self.ap), + regexp.RegExpRule(self.ap), + random.RandomRespRule(self.ap), + ] + + for rule_matcher in self.rule_matchers: + await rule_matcher.initialize() + + async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: + + if query.launcher_type != 'group': + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + + rules = self.ap.cfg_mgr.data['response_rules'] + + use_rule = rules['default'] + + if str(query.launcher_id) in use_rule: + use_rule = use_rule[str(query.launcher_id)] + + for rule_matcher in self.rule_matchers: # 任意一个匹配就放行 + res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule) + if res.matching: + query.message_chain = res.replacement + + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query, + ) + + return entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query + ) diff --git a/pkg/pipeline/resprule/rule.py b/pkg/pipeline/resprule/rule.py new file mode 100644 index 00000000..e530d063 --- /dev/null +++ b/pkg/pipeline/resprule/rule.py @@ -0,0 +1,31 @@ +from __future__ import annotations +import abc + +import mirai + +from ...core import app +from . import entities + + +class GroupRespondRule(metaclass=abc.ABCMeta): + """群组响应规则的抽象类 + """ + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + pass + + @abc.abstractmethod + async def match( + self, + message_text: str, + message_chain: mirai.MessageChain, + rule_dict: dict + ) -> entities.RuleJudgeResult: + """判断消息是否匹配规则 + """ + raise NotImplementedError diff --git a/pkg/pipeline/resprule/rules/__init__.py b/pkg/pipeline/resprule/rules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/resprule/rules/atbot.py b/pkg/pipeline/resprule/rules/atbot.py new file mode 100644 index 00000000..eefc4891 --- /dev/null +++ b/pkg/pipeline/resprule/rules/atbot.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +import mirai + +from .. import rule as rule_model +from .. import entities + + +class AtBotRule(rule_model.GroupRespondRule): + + async def match( + self, + message_text: str, + message_chain: mirai.MessageChain, + rule_dict: dict + ) -> entities.RuleJudgeResult: + + if message_chain.has(mirai.At(self.ap.im_mgr.bot_account_id)) and rule_dict['at']: + message_chain.remove(mirai.At(self.ap.im_mgr.bot_account_id)) + return entities.RuleJudgeResult( + matching=True, + replacement=message_chain, + ) + + return entities.RuleJudgeResult( + matching=False, + replacement = message_chain + ) diff --git a/pkg/pipeline/resprule/rules/prefix.py b/pkg/pipeline/resprule/rules/prefix.py new file mode 100644 index 00000000..31ead5ab --- /dev/null +++ b/pkg/pipeline/resprule/rules/prefix.py @@ -0,0 +1,29 @@ +import mirai + +from .. import rule as rule_model +from .. import entities + + +class PrefixRule(rule_model.GroupRespondRule): + + async def match( + self, + message_text: str, + message_chain: mirai.MessageChain, + rule_dict: dict + ) -> entities.RuleJudgeResult: + prefixes = rule_dict['prefix'] + + for prefix in prefixes: + if message_text.startswith(prefix): + return entities.RuleJudgeResult( + matching=True, + replacement=mirai.MessageChain([ + mirai.Plain(message_text[len(prefix):]) + ]), + ) + + return entities.RuleJudgeResult( + matching=False, + replacement=message_chain + ) diff --git a/pkg/pipeline/resprule/rules/random.py b/pkg/pipeline/resprule/rules/random.py new file mode 100644 index 00000000..1e8354b5 --- /dev/null +++ b/pkg/pipeline/resprule/rules/random.py @@ -0,0 +1,22 @@ +import random + +import mirai + +from .. import rule as rule_model +from .. import entities + + +class RandomRespRule(rule_model.GroupRespondRule): + + async def match( + self, + message_text: str, + message_chain: mirai.MessageChain, + rule_dict: dict + ) -> entities.RuleJudgeResult: + random_rate = rule_dict['random_rate'] + + return entities.RuleJudgeResult( + matching=random.random() < random_rate, + replacement=message_chain + ) \ No newline at end of file diff --git a/pkg/pipeline/resprule/rules/regexp.py b/pkg/pipeline/resprule/rules/regexp.py new file mode 100644 index 00000000..0d621fe4 --- /dev/null +++ b/pkg/pipeline/resprule/rules/regexp.py @@ -0,0 +1,31 @@ +import re + +import mirai + +from .. import rule as rule_model +from .. import entities + + +class RegExpRule(rule_model.GroupRespondRule): + + async def match( + self, + message_text: str, + message_chain: mirai.MessageChain, + rule_dict: dict + ) -> entities.RuleJudgeResult: + regexps = rule_dict['regexp'] + + for regexp in regexps: + match = re.match(regexp, message_text) + + if match: + return entities.RuleJudgeResult( + matching=True, + replacement=message_chain, + ) + + return entities.RuleJudgeResult( + matching=False, + replacement=message_chain + ) diff --git a/pkg/pipeline/stage.py b/pkg/pipeline/stage.py new file mode 100644 index 00000000..84a0339d --- /dev/null +++ b/pkg/pipeline/stage.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import abc + +from ..core import app, entities as core_entities +from . import entities + + +_stage_classes: dict[str, PipelineStage] = {} + + +def stage_class(name: str): + + def decorator(cls): + _stage_classes[name] = cls + return cls + + return decorator + + +class PipelineStage(metaclass=abc.ABCMeta): + """流水线阶段 + """ + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + """初始化 + """ + pass + + @abc.abstractmethod + async def process( + self, + query: core_entities.Query, + stage_inst_name: str, + ) -> entities.StageProcessResult: + """处理 + """ + raise NotImplementedError diff --git a/pkg/pipeline/stagemgr.py b/pkg/pipeline/stagemgr.py new file mode 100644 index 00000000..f5407a2e --- /dev/null +++ b/pkg/pipeline/stagemgr.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import pydantic + +from ..core import app +from . import stage +from .resprule import resprule +from .bansess import bansess +from .cntfilter import cntfilter +from .longtext import longtext + + +class StageInstContainer(): + """阶段实例容器 + """ + + inst_name: str + + inst: stage.PipelineStage + + def __init__(self, inst_name: str, inst: stage.PipelineStage): + self.inst_name = inst_name + self.inst = inst + + +class StageManager: + ap: app.Application + + stage_containers: list[StageInstContainer] + + def __init__(self, ap: app.Application): + self.ap = ap + + self.stage_containers = [] + + async def initialize(self): + """初始化 + """ + + for name, cls in stage._stage_classes.items(): + self.stage_containers.append(StageInstContainer( + inst_name=name, + inst=cls(self.ap) + )) + + for stage_containers in self.stage_containers: + await stage_containers.inst.initialize() diff --git a/pkg/qqbot/bansess/bansess.py b/pkg/qqbot/bansess/bansess.py index d8ef4958..74ffd3f7 100644 --- a/pkg/qqbot/bansess/bansess.py +++ b/pkg/qqbot/bansess/bansess.py @@ -3,7 +3,7 @@ from __future__ import annotations import re -from ...boot import app +from ...core import app from ...config import manager as cfg_mgr diff --git a/pkg/qqbot/cntfilter/cntfilter.py b/pkg/qqbot/cntfilter/cntfilter.py index 2d690b57..4c7305c0 100644 --- a/pkg/qqbot/cntfilter/cntfilter.py +++ b/pkg/qqbot/cntfilter/cntfilter.py @@ -1,6 +1,6 @@ from __future__ import annotations -from ...boot import app +from ...core import app from . import entities from . import filter from .filters import cntignore, banwords, baiduexamine diff --git a/pkg/qqbot/cntfilter/filter.py b/pkg/qqbot/cntfilter/filter.py index 4d4cd79f..57792145 100644 --- a/pkg/qqbot/cntfilter/filter.py +++ b/pkg/qqbot/cntfilter/filter.py @@ -2,7 +2,7 @@ from __future__ import annotations import abc -from ...boot import app +from ...core import app from . import entities diff --git a/pkg/qqbot/longtext/longtext.py b/pkg/qqbot/longtext/longtext.py index 21267880..697f65e4 100644 --- a/pkg/qqbot/longtext/longtext.py +++ b/pkg/qqbot/longtext/longtext.py @@ -5,7 +5,7 @@ import traceback from PIL import Image, ImageDraw, ImageFont from mirai.models.message import MessageComponent, Plain -from ...boot import app +from ...core import app from . import strategy from .strategies import image, forward diff --git a/pkg/qqbot/longtext/strategy.py b/pkg/qqbot/longtext/strategy.py index ef4cc1a5..5c6bfb9c 100644 --- a/pkg/qqbot/longtext/strategy.py +++ b/pkg/qqbot/longtext/strategy.py @@ -5,7 +5,7 @@ import typing import mirai from mirai.models.message import MessageComponent -from ...boot import app +from ...core import app class LongTextStrategy(metaclass=abc.ABCMeta): diff --git a/pkg/qqbot/manager.py b/pkg/qqbot/manager.py index 5239604f..b16450e8 100644 --- a/pkg/qqbot/manager.py +++ b/pkg/qqbot/manager.py @@ -24,7 +24,7 @@ from .cntfilter import cntfilter from .longtext import longtext from .ratelim import ratelim -from ..boot import app +from ..core import app, entities as core_entities # 控制QQ消息输入输出的类 @@ -91,45 +91,29 @@ class QQBotManager: # Caution: 注册新的事件处理器之后,请务必在unsubscribe_all中编写相应的取消订阅代码 async def on_friend_message(event: FriendMessage): - async def friend_message_handler(): - # 触发事件 - args = { - "launcher_type": "person", - "launcher_id": event.sender.id, - "sender_id": event.sender.id, - "message_chain": event.message_chain, - } - plugin_event = plugin_host.emit(plugin_models.PersonMessageReceived, **args) + await self.ap.query_pool.add_query( + launcher_type=core_entities.LauncherTypes.PERSON, + launcher_id=event.sender.id, + sender_id=event.sender.id, + message_event=event, + message_chain=event.message_chain + ) - if plugin_event.is_prevented_default(): - return - - await self.on_person_message(event) - - asyncio.create_task(friend_message_handler()) self.adapter.register_listener( FriendMessage, on_friend_message ) async def on_stranger_message(event: StrangerMessage): + + await self.ap.query_pool.add_query( + launcher_type=core_entities.LauncherTypes.PERSON, + launcher_id=event.sender.id, + sender_id=event.sender.id, + message_event=event, + message_chain=event.message_chain + ) - async def stranger_message_handler(): - # 触发事件 - args = { - "launcher_type": "person", - "launcher_id": event.sender.id, - "sender_id": event.sender.id, - "message_chain": event.message_chain, - } - plugin_event = plugin_host.emit(plugin_models.PersonMessageReceived, **args) - - if plugin_event.is_prevented_default(): - return - - await self.on_person_message(event) - - asyncio.create_task(stranger_message_handler()) # nakuru不区分好友和陌生人,故仅为yirimirai注册陌生人事件 if config['msg_source_adapter'] == 'yirimirai': self.adapter.register_listener( @@ -139,49 +123,19 @@ class QQBotManager: async def on_group_message(event: GroupMessage): - async def group_message_handler(event: GroupMessage): - # 触发事件 - args = { - "launcher_type": "group", - "launcher_id": event.group.id, - "sender_id": event.sender.id, - "message_chain": event.message_chain, - } - plugin_event = plugin_host.emit(plugin_models.GroupMessageReceived, **args) - - if plugin_event.is_prevented_default(): - return - - await self.on_group_message(event) - - asyncio.create_task(group_message_handler(event)) + await self.ap.query_pool.add_query( + launcher_type=core_entities.LauncherTypes.GROUP, + launcher_id=event.group.id, + sender_id=event.sender.id, + message_event=event, + message_chain=event.message_chain + ) self.adapter.register_listener( GroupMessage, on_group_message ) - def unsubscribe_all(): - """取消所有订阅 - - 用于在热重载流程中卸载所有事件处理器 - """ - self.adapter.unregister_listener( - FriendMessage, - on_friend_message - ) - if config['msg_source_adapter'] == 'yirimirai': - self.adapter.unregister_listener( - StrangerMessage, - on_stranger_message - ) - self.adapter.unregister_listener( - GroupMessage, - on_group_message - ) - - self.unsubscribe_all = unsubscribe_all - async def send(self, event, msg, check_quote=True, check_at_sender=True): config = context.get_config_manager().data diff --git a/pkg/qqbot/process.py b/pkg/qqbot/process.py index e1673583..65de8d52 100644 --- a/pkg/qqbot/process.py +++ b/pkg/qqbot/process.py @@ -14,7 +14,7 @@ from ..utils import context from ..plugin import host as plugin_host from ..plugin import models as plugin_models import tips as tips_custom -from ..boot import app +from ..core import app from .cntfilter import entities processing = [] diff --git a/pkg/qqbot/ratelim/algo.py b/pkg/qqbot/ratelim/algo.py index 10bbdd3a..b6d9ba7b 100644 --- a/pkg/qqbot/ratelim/algo.py +++ b/pkg/qqbot/ratelim/algo.py @@ -1,7 +1,7 @@ from __future__ import annotations import abc -from ...boot import app +from ...core import app class ReteLimitAlgo(metaclass=abc.ABCMeta): diff --git a/pkg/qqbot/ratelim/ratelim.py b/pkg/qqbot/ratelim/ratelim.py index ab23d714..68fe0316 100644 --- a/pkg/qqbot/ratelim/ratelim.py +++ b/pkg/qqbot/ratelim/ratelim.py @@ -2,7 +2,7 @@ from __future__ import annotations from . import algo from .algos import fixedwin -from ...boot import app +from ...core import app class RateLimiter: diff --git a/pkg/qqbot/resprule/resprule.py b/pkg/qqbot/resprule/resprule.py index f0c51921..9ea8321d 100644 --- a/pkg/qqbot/resprule/resprule.py +++ b/pkg/qqbot/resprule/resprule.py @@ -2,7 +2,7 @@ from __future__ import annotations import mirai -from ...boot import app +from ...core import app from . import entities, rule from .rules import atbot, prefix, regexp, random diff --git a/pkg/qqbot/resprule/rule.py b/pkg/qqbot/resprule/rule.py index 67af0204..e530d063 100644 --- a/pkg/qqbot/resprule/rule.py +++ b/pkg/qqbot/resprule/rule.py @@ -3,7 +3,7 @@ import abc import mirai -from ...boot import app +from ...core import app from . import entities diff --git a/start.py b/start.py index f22012ee..b56ea9e9 100644 --- a/start.py +++ b/start.py @@ -1,6 +1,6 @@ import asyncio -from pkg.boot import boot +from pkg.core import boot if __name__ == '__main__':