style: introduce ruff as linter and formatter (#1356)

* style: remove necessary imports

* style: fix F841

* style: fix F401

* style: fix F811

* style: fix E402

* style: fix E721

* style: fix E722

* style: fix E722

* style: fix F541

* style: ruff format

* style: all passed

* style: add ruff in deps

* style: more ignores in ruff.toml

* style: add pre-commit
This commit is contained in:
Junyan Qin (Chin)
2025-04-29 17:24:07 +08:00
committed by GitHub
parent 09e70d70e9
commit 209f16af76
240 changed files with 5307 additions and 4689 deletions
+32 -46
View File
@@ -4,20 +4,21 @@ from ...core import app
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from . import filter as filter_model, entities as filter_entities
from .filters import cntignore, banwords, baiduexamine
from ...provider import entities as llm_entities
from ...platform.types import message as platform_message
from ...platform.types import events as platform_events
from ...platform.types import entities as platform_entities
from ...utils import importutil
from . import filters
importutil.import_modules_in_pkg(filters)
@stage.stage_class('PostContentFilterStage')
@stage.stage_class('PreContentFilterStage')
class ContentFilterStage(stage.PipelineStage):
"""内容过滤阶段
前置:
检查消息是否符合规则,不符合则拦截。
改写:
@@ -36,13 +37,12 @@ class ContentFilterStage(stage.PipelineStage):
super().__init__(ap)
async def initialize(self, pipeline_config: dict):
filters_required = [
"content-ignore",
'content-ignore',
]
if pipeline_config['safety']['content-filter']['check-sensitive-words']:
filters_required.append("ban-word-filter")
filters_required.append('ban-word-filter')
# TODO revert it
# if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']:
@@ -50,9 +50,7 @@ class ContentFilterStage(stage.PipelineStage):
for filter in filter_model.preregistered_filters:
if filter.name in filters_required:
self.filter_chain.append(
filter(self.ap)
)
self.filter_chain.append(filter(self.ap))
for filter in self.filter_chain:
await filter.initialize()
@@ -68,8 +66,7 @@ class ContentFilterStage(stage.PipelineStage):
if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg':
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
result_type=entities.ResultType.CONTINUE, new_query=query
)
else:
for filter in self.filter_chain:
@@ -78,26 +75,25 @@ class ContentFilterStage(stage.PipelineStage):
if result.level in [
filter_entities.ResultLevel.BLOCK,
filter_entities.ResultLevel.MASKED
filter_entities.ResultLevel.MASKED,
]:
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice=result.user_notice,
console_notice=result.console_notice
console_notice=result.console_notice,
)
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
message = result.replacement
query.message_chain = platform_message.MessageChain(
platform_message.Plain(message)
)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
result_type=entities.ResultType.CONTINUE, new_query=query
)
async def _post_process(
self,
message: str,
@@ -108,8 +104,7 @@ class ContentFilterStage(stage.PipelineStage):
"""
if query.pipeline_config['safety']['content-filter']['scope'] == 'income-msg':
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
result_type=entities.ResultType.CONTINUE, new_query=query
)
else:
message = message.strip()
@@ -122,30 +117,25 @@ class ContentFilterStage(stage.PipelineStage):
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice=result.user_notice,
console_notice=result.console_notice
console_notice=result.console_notice,
)
elif result.level in [
filter_entities.ResultLevel.PASS,
filter_entities.ResultLevel.MASKED
filter_entities.ResultLevel.MASKED,
]:
message = result.replacement
query.resp_messages[-1].content = message
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
result_type=entities.ResultType.CONTINUE, new_query=query
)
async def process(
self,
query: core_entities.Query,
stage_inst_name: str
self, query: core_entities.Query, stage_inst_name: str
) -> entities.StageProcessResult:
"""处理
"""
"""处理"""
if stage_inst_name == 'PreContentFilterStage':
contain_non_text = False
text_components = [platform_message.Plain, platform_message.Source]
@@ -156,28 +146,24 @@ class ContentFilterStage(stage.PipelineStage):
break
if contain_non_text:
self.ap.logger.debug(f"消息中包含非文本消息,跳过内容过滤器检查。")
self.ap.logger.debug('消息中包含非文本消息,跳过内容过滤器检查。')
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
result_type=entities.ResultType.CONTINUE, new_query=query
)
return await self._pre_process(
str(query.message_chain).strip(),
query
)
return await self._pre_process(str(query.message_chain).strip(), query)
elif stage_inst_name == 'PostContentFilterStage':
# 仅处理 query.resp_messages[-1].content 是 str 的情况
if isinstance(query.resp_messages[-1], llm_entities.Message) and isinstance(query.resp_messages[-1].content, str):
return await self._post_process(
query.resp_messages[-1].content,
query
)
if isinstance(query.resp_messages[-1], llm_entities.Message) and isinstance(
query.resp_messages[-1].content, str
):
return await self._post_process(query.resp_messages[-1].content, query)
else:
self.ap.logger.debug(f"resp_messages[-1] 不是 Message 类型或 query.resp_messages[-1].content 不是 str 类型,跳过内容过滤器检查。")
self.ap.logger.debug(
'resp_messages[-1] 不是 Message 类型或 query.resp_messages[-1].content 不是 str 类型,跳过内容过滤器检查。'
)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
result_type=entities.ResultType.CONTINUE, new_query=query
)
else:
raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}')
+4 -5
View File
@@ -1,14 +1,11 @@
import typing
import enum
import pydantic.v1 as pydantic
from ...provider import entities as llm_entities
class ResultLevel(enum.Enum):
"""结果等级"""
PASS = enum.auto()
"""通过"""
@@ -24,6 +21,7 @@ class ResultLevel(enum.Enum):
class EnableStage(enum.Enum):
"""启用阶段"""
PRE = enum.auto()
"""预处理"""
@@ -55,14 +53,15 @@ class FilterResult(pydantic.BaseModel):
class ManagerResultLevel(enum.Enum):
"""处理器结果等级"""
CONTINUE = enum.auto()
"""继续"""
INTERRUPT = enum.auto()
"""中断"""
class FilterManagerResult(pydantic.BaseModel):
class FilterManagerResult(pydantic.BaseModel):
level: ManagerResultLevel
replacement: str
+8 -10
View File
@@ -5,14 +5,13 @@ import typing
from ...core import app, entities as core_entities
from . import entities
from ...provider import entities as llm_entities
preregistered_filters: list[typing.Type[ContentFilter]] = []
def filter_class(
name: str
name: str,
) -> typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]:
"""内容过滤器类装饰器
@@ -22,6 +21,7 @@ def filter_class(
Returns:
typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]: 装饰器
"""
def decorator(cls: typing.Type[ContentFilter]) -> typing.Type[ContentFilter]:
assert issubclass(cls, ContentFilter)
@@ -53,23 +53,21 @@ class ContentFilter(metaclass=abc.ABCMeta):
entity.EnableStage.PRE: 消息请求AI前,此时需要检查的内容是用户的输入消息。
entity.EnableStage.POST: 消息请求AI后,此时需要检查的内容是AI的回复消息。
"""
return [
entities.EnableStage.PRE,
entities.EnableStage.POST
]
return [entities.EnableStage.PRE, entities.EnableStage.POST]
async def initialize(self):
"""初始化过滤器
"""
"""初始化过滤器"""
pass
@abc.abstractmethod
async def process(self, query: core_entities.Query, message: str=None, image_url=None) -> entities.FilterResult:
async def process(
self, query: core_entities.Query, message: str = None, image_url=None
) -> entities.FilterResult:
"""处理消息
分为前后阶段,具体取决于 enable_stages 的值。
对于内容过滤器来说,不需要考虑消息所处的阶段,只需要检查消息内容即可。
Args:
message (str): 需要检查的内容
image_url (str): 要检查的图片的 URL
+26 -18
View File
@@ -7,11 +7,11 @@ from .. import filter as filter_model
from ....core import entities as core_entities
BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}"
BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
BAIDU_EXAMINE_URL = 'https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}'
BAIDU_EXAMINE_TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token'
@filter_model.filter_class("baidu-cloud-examine")
@filter_model.filter_class('baidu-cloud-examine')
class BaiduCloudExamine(filter_model.ContentFilter):
"""百度云内容审核"""
@@ -20,44 +20,52 @@ class BaiduCloudExamine(filter_model.ContentFilter):
async with session.post(
BAIDU_EXAMINE_TOKEN_URL,
params={
"grant_type": "client_credentials",
"client_id": self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-key'],
"client_secret": self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-secret']
}
'grant_type': 'client_credentials',
'client_id': self.ap.pipeline_cfg.data['baidu-cloud-examine'][
'api-key'
],
'client_secret': self.ap.pipeline_cfg.data['baidu-cloud-examine'][
'api-secret'
],
},
) as resp:
return (await resp.json())['access_token']
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
async def process(
self, query: core_entities.Query, message: str
) -> entities.FilterResult:
async with aiohttp.ClientSession() as session:
async with session.post(
BAIDU_EXAMINE_URL.format(await self._get_token()),
headers={'Content-Type': 'application/x-www-form-urlencoded', 'Accept': 'application/json'},
data=f"text={message}".encode('utf-8')
headers={
'Content-Type': 'application/x-www-form-urlencoded',
'Accept': 'application/json',
},
data=f'text={message}'.encode('utf-8'),
) as resp:
result = await resp.json()
if "error_code" in result:
if 'error_code' in result:
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,
replacement=message,
user_notice='',
console_notice=f"百度云判定出错,错误信息:{result['error_msg']}"
console_notice=f'百度云判定出错,错误信息:{result["error_msg"]}',
)
else:
conclusion = result["conclusion"]
conclusion = result['conclusion']
if conclusion in ("合规"):
if conclusion in ('合规'):
return entities.FilterResult(
level=entities.ResultLevel.PASS,
replacement=message,
user_notice='',
console_notice=f"百度云判定结果:{conclusion}"
console_notice=f'百度云判定结果:{conclusion}',
)
else:
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,
replacement=message,
user_notice="消息中存在不合适的内容, 请修改",
console_notice=f"百度云判定结果:{conclusion}"
user_notice='消息中存在不合适的内容, 请修改',
console_notice=f'百度云判定结果:{conclusion}',
)
+9 -6
View File
@@ -6,14 +6,16 @@ from .. import entities
from ....core import entities as core_entities
@filter_model.filter_class("ban-word-filter")
@filter_model.filter_class('ban-word-filter')
class BanWordFilter(filter_model.ContentFilter):
"""根据内容过滤"""
async def initialize(self):
pass
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
async def process(
self, query: core_entities.Query, message: str
) -> entities.FilterResult:
found = False
for word in self.ap.sensitive_meta.data['words']:
@@ -23,9 +25,10 @@ class BanWordFilter(filter_model.ContentFilter):
found = True
for i in range(len(match)):
if self.ap.sensitive_meta.data['mask_word'] == "":
if self.ap.sensitive_meta.data['mask_word'] == '':
message = message.replace(
match[i], self.ap.sensitive_meta.data['mask'] * len(match[i])
match[i],
self.ap.sensitive_meta.data['mask'] * len(match[i]),
)
else:
message = message.replace(
@@ -36,5 +39,5 @@ class BanWordFilter(filter_model.ContentFilter):
level=entities.ResultLevel.MASKED if found else entities.ResultLevel.PASS,
replacement=message,
user_notice='消息中存在不合适的内容, 请修改' if found else '',
console_notice=''
)
console_notice='',
)
+9 -7
View File
@@ -6,7 +6,7 @@ from .. import filter as filter_model
from ....core import entities as core_entities
@filter_model.filter_class("content-ignore")
@filter_model.filter_class('content-ignore')
class ContentIgnore(filter_model.ContentFilter):
"""根据内容忽略消息"""
@@ -16,7 +16,9 @@ class ContentIgnore(filter_model.ContentFilter):
entities.EnableStage.PRE,
]
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
async def process(
self, query: core_entities.Query, message: str
) -> entities.FilterResult:
if 'prefix' in query.pipeline_config['trigger']['ignore-rules']:
for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']:
if message.startswith(rule):
@@ -24,9 +26,9 @@ class ContentIgnore(filter_model.ContentFilter):
level=entities.ResultLevel.BLOCK,
replacement='',
user_notice='',
console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息'
console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息',
)
if 'regexp' in query.pipeline_config['trigger']['ignore-rules']:
for rule in query.pipeline_config['trigger']['ignore-rules']['regexp']:
if re.search(rule, message):
@@ -34,12 +36,12 @@ class ContentIgnore(filter_model.ContentFilter):
level=entities.ResultLevel.BLOCK,
replacement='',
user_notice='',
console_notice='根据 ignore_rules 中的 regexp 规则,忽略消息'
console_notice='根据 ignore_rules 中的 regexp 规则,忽略消息',
)
return entities.FilterResult(
level=entities.ResultLevel.PASS,
replacement=message,
user_notice='',
console_notice=''
)
console_notice='',
)