refactor: 请求处理控制流基础架构

This commit is contained in:
RockChinQ
2024-01-26 15:51:49 +08:00
parent a064c24f60
commit 8d084427d2
55 changed files with 1430 additions and 146 deletions
View File
+128
View File
@@ -0,0 +1,128 @@
from __future__ import annotations
import mirai
from ...core import app
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from . import filter, entities as filter_entities
from .filters import cntignore, banwords, baiduexamine
@stage.stage_class('PostContentFilterStage')
@stage.stage_class('PreContentFilterStage')
class ContentFilterStage(stage.PipelineStage):
filter_chain: list[filter.ContentFilter]
def __init__(self, ap: app.Application):
self.filter_chain = []
super().__init__(ap)
async def initialize(self):
self.filter_chain.append(cntignore.ContentIgnore(self.ap))
if self.ap.cfg_mgr.data['sensitive_word_filter']:
self.filter_chain.append(banwords.BanWordFilter(self.ap))
if self.ap.cfg_mgr.data['baidu_check']:
self.filter_chain.append(baiduexamine.BaiduCloudExamine(self.ap))
for filter in self.filter_chain:
await filter.initialize()
async def _pre_process(
self,
message: str,
query: core_entities.Query,
) -> entities.StageProcessResult:
"""请求llm前处理消息
只要有一个不通过就不放行,只放行 PASS 的消息
"""
if not self.ap.cfg_mgr.data['income_msg_check']:
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
else:
for filter in self.filter_chain:
if filter_entities.EnableStage.PRE in filter.enable_stages:
result = await filter.process(message)
if result.level in [
filter_entities.ResultLevel.BLOCK,
filter_entities.ResultLevel.MASKED
]:
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice=result.user_notice,
console_notice=result.console_notice
)
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
message = result.replacement
query.message_chain = mirai.MessageChain(
mirai.Plain(message)
)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
async def _post_process(
self,
message: str,
query: core_entities.Query,
) -> entities.StageProcessResult:
"""请求llm后处理响应
只要是 PASS 或者 MASKED 的就通过此 filter,将其 replacement 设置为message,进入下一个 filter
"""
for filter in self.filter_chain:
if filter_entities.EnableStage.POST in filter.enable_stages:
result = await filter.process(message)
if result.level == filter_entities.ResultLevel.BLOCK:
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice=result.user_notice,
console_notice=result.console_notice
)
elif result.level in [
filter_entities.ResultLevel.PASS,
filter_entities.ResultLevel.MASKED
]:
message = result.replacement
query.message_chain = mirai.MessageChain(
mirai.Plain(message)
)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
async def process(
self,
query: core_entities.Query,
stage_inst_name: str
) -> entities.StageProcessResult:
"""处理
"""
if stage_inst_name == 'PreContentFilterStage':
return await self._pre_process(
str(query.message_chain).strip(),
query
)
elif stage_inst_name == 'PostContentFilterStage':
return await self._post_process(
str(query.message_chain).strip(),
query
)
else:
raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}')
+64
View File
@@ -0,0 +1,64 @@
import typing
import enum
import pydantic
class ResultLevel(enum.Enum):
"""结果等级"""
PASS = enum.auto()
"""通过"""
WARN = enum.auto()
"""警告"""
MASKED = enum.auto()
"""已掩去"""
BLOCK = enum.auto()
"""阻止"""
class EnableStage(enum.Enum):
"""启用阶段"""
PRE = enum.auto()
"""预处理"""
POST = enum.auto()
"""后处理"""
class FilterResult(pydantic.BaseModel):
level: ResultLevel
replacement: str
"""替换后的消息"""
user_notice: str
"""不通过时,用户提示消息"""
console_notice: str
"""不通过时,控制台提示消息"""
class ManagerResultLevel(enum.Enum):
"""处理器结果等级"""
CONTINUE = enum.auto()
"""继续"""
INTERRUPT = enum.auto()
"""中断"""
class FilterManagerResult(pydantic.BaseModel):
level: ManagerResultLevel
replacement: str
"""替换后的消息"""
user_notice: str
"""用户提示消息"""
console_notice: str
"""控制台提示消息"""
+34
View File
@@ -0,0 +1,34 @@
# 内容过滤器的抽象类
from __future__ import annotations
import abc
from ...core import app
from . import entities
class ContentFilter(metaclass=abc.ABCMeta):
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
@property
def enable_stages(self):
"""启用的阶段
"""
return [
entities.EnableStage.PRE,
entities.EnableStage.POST
]
async def initialize(self):
"""初始化过滤器
"""
pass
@abc.abstractmethod
async def process(self, message: str) -> entities.FilterResult:
"""处理消息
"""
raise NotImplementedError
@@ -0,0 +1,61 @@
from __future__ import annotations
import aiohttp
from .. import entities
from .. import filter as filter_model
BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}"
BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
class BaiduCloudExamine(filter_model.ContentFilter):
"""百度云内容审核"""
async def _get_token(self) -> str:
async with aiohttp.ClientSession() as session:
async with session.post(
BAIDU_EXAMINE_TOKEN_URL,
params={
"grant_type": "client_credentials",
"client_id": self.ap.cfg_mgr.data['baidu_api_key'],
"client_secret": self.ap.cfg_mgr.data['baidu_secret_key']
}
) as resp:
return (await resp.json())['access_token']
async def process(self, message: str) -> entities.FilterResult:
async with aiohttp.ClientSession() as session:
async with session.post(
BAIDU_EXAMINE_URL.format(await self._get_token()),
headers={'Content-Type': 'application/x-www-form-urlencoded', 'Accept': 'application/json'},
data=f"text={message}".encode('utf-8')
) as resp:
result = await resp.json()
if "error_code" in result:
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,
replacement=message,
user_notice='',
console_notice=f"百度云判定出错,错误信息:{result['error_msg']}"
)
else:
conclusion = result["conclusion"]
if conclusion in ("合规"):
return entities.FilterResult(
level=entities.ResultLevel.PASS,
replacement=message,
user_notice='',
console_notice=f"百度云判定结果:{conclusion}"
)
else:
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,
replacement=message,
user_notice=self.ap.cfg_mgr.data['inappropriate_message_tips'],
console_notice=f"百度云判定结果:{conclusion}"
)
@@ -0,0 +1,44 @@
from __future__ import annotations
import re
from .. import filter as filter_model
from .. import entities
from ....config import manager as cfg_mgr
class BanWordFilter(filter_model.ContentFilter):
"""根据内容禁言"""
sensitive: cfg_mgr.ConfigManager
async def initialize(self):
self.sensitive = await cfg_mgr.load_json_config(
"sensitive.json",
"res/templates/sensitive-template.json"
)
async def process(self, message: str) -> entities.FilterResult:
found = False
for word in self.sensitive.data['words']:
match = re.findall(word, message)
if len(match) > 0:
found = True
for i in range(len(match)):
if self.sensitive.data['mask_word'] == "":
message = message.replace(
match[i], self.sensitive.data['mask'] * len(match[i])
)
else:
message = message.replace(
match[i], self.sensitive.data['mask_word']
)
return entities.FilterResult(
level=entities.ResultLevel.MASKED if found else entities.ResultLevel.PASS,
replacement=message,
user_notice='[bot] 消息中存在不合适的内容, 请更换措辞' if found else '',
console_notice=''
)
@@ -0,0 +1,43 @@
from __future__ import annotations
import re
from .. import entities
from .. import filter as filter_model
class ContentIgnore(filter_model.ContentFilter):
"""根据内容忽略消息"""
@property
def enable_stages(self):
return [
entities.EnableStage.PRE,
]
async def process(self, message: str) -> entities.FilterResult:
if 'prefix' in self.ap.cfg_mgr.data['ignore_rules']:
for rule in self.ap.cfg_mgr.data['ignore_rules']['prefix']:
if message.startswith(rule):
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,
replacement='',
user_notice='',
console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息'
)
if 'regexp' in self.ap.cfg_mgr.data['ignore_rules']:
for rule in self.ap.cfg_mgr.data['ignore_rules']['regexp']:
if re.search(rule, message):
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,
replacement='',
user_notice='',
console_notice='根据 ignore_rules 中的 regexp 规则,忽略消息'
)
return entities.FilterResult(
level=entities.ResultLevel.PASS,
replacement=message,
user_notice='',
console_notice=''
)