mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-18 03:34:20 +00:00
style: introduce ruff as linter and formatter (#1356)
* style: remove necessary imports * style: fix F841 * style: fix F401 * style: fix F811 * style: fix E402 * style: fix E721 * style: fix E722 * style: fix E722 * style: fix F541 * style: ruff format * style: all passed * style: add ruff in deps * style: more ignores in ruff.toml * style: add pre-commit
This commit is contained in:
committed by
GitHub
parent
09e70d70e9
commit
209f16af76
@@ -4,20 +4,21 @@ from ...core import app
|
||||
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
from . import filter as filter_model, entities as filter_entities
|
||||
from .filters import cntignore, banwords, baiduexamine
|
||||
from ...provider import entities as llm_entities
|
||||
from ...platform.types import message as platform_message
|
||||
from ...platform.types import events as platform_events
|
||||
from ...platform.types import entities as platform_entities
|
||||
from ...utils import importutil
|
||||
|
||||
from . import filters
|
||||
|
||||
importutil.import_modules_in_pkg(filters)
|
||||
|
||||
|
||||
@stage.stage_class('PostContentFilterStage')
|
||||
@stage.stage_class('PreContentFilterStage')
|
||||
class ContentFilterStage(stage.PipelineStage):
|
||||
"""内容过滤阶段
|
||||
|
||||
|
||||
前置:
|
||||
检查消息是否符合规则,不符合则拦截。
|
||||
改写:
|
||||
@@ -36,13 +37,12 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
super().__init__(ap)
|
||||
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
|
||||
filters_required = [
|
||||
"content-ignore",
|
||||
'content-ignore',
|
||||
]
|
||||
|
||||
if pipeline_config['safety']['content-filter']['check-sensitive-words']:
|
||||
filters_required.append("ban-word-filter")
|
||||
filters_required.append('ban-word-filter')
|
||||
|
||||
# TODO revert it
|
||||
# if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']:
|
||||
@@ -50,9 +50,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
|
||||
for filter in filter_model.preregistered_filters:
|
||||
if filter.name in filters_required:
|
||||
self.filter_chain.append(
|
||||
filter(self.ap)
|
||||
)
|
||||
self.filter_chain.append(filter(self.ap))
|
||||
|
||||
for filter in self.filter_chain:
|
||||
await filter.initialize()
|
||||
@@ -68,8 +66,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
|
||||
if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg':
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
else:
|
||||
for filter in self.filter_chain:
|
||||
@@ -78,26 +75,25 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
|
||||
if result.level in [
|
||||
filter_entities.ResultLevel.BLOCK,
|
||||
filter_entities.ResultLevel.MASKED
|
||||
filter_entities.ResultLevel.MASKED,
|
||||
]:
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query,
|
||||
user_notice=result.user_notice,
|
||||
console_notice=result.console_notice
|
||||
console_notice=result.console_notice,
|
||||
)
|
||||
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
|
||||
message = result.replacement
|
||||
|
||||
|
||||
query.message_chain = platform_message.MessageChain(
|
||||
platform_message.Plain(message)
|
||||
)
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
|
||||
|
||||
async def _post_process(
|
||||
self,
|
||||
message: str,
|
||||
@@ -108,8 +104,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
"""
|
||||
if query.pipeline_config['safety']['content-filter']['scope'] == 'income-msg':
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
else:
|
||||
message = message.strip()
|
||||
@@ -122,30 +117,25 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query,
|
||||
user_notice=result.user_notice,
|
||||
console_notice=result.console_notice
|
||||
console_notice=result.console_notice,
|
||||
)
|
||||
elif result.level in [
|
||||
filter_entities.ResultLevel.PASS,
|
||||
filter_entities.ResultLevel.MASKED
|
||||
filter_entities.ResultLevel.MASKED,
|
||||
]:
|
||||
message = result.replacement
|
||||
|
||||
query.resp_messages[-1].content = message
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
|
||||
async def process(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
stage_inst_name: str
|
||||
self, query: core_entities.Query, stage_inst_name: str
|
||||
) -> entities.StageProcessResult:
|
||||
"""处理
|
||||
"""
|
||||
"""处理"""
|
||||
if stage_inst_name == 'PreContentFilterStage':
|
||||
|
||||
contain_non_text = False
|
||||
|
||||
text_components = [platform_message.Plain, platform_message.Source]
|
||||
@@ -156,28 +146,24 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
break
|
||||
|
||||
if contain_non_text:
|
||||
self.ap.logger.debug(f"消息中包含非文本消息,跳过内容过滤器检查。")
|
||||
self.ap.logger.debug('消息中包含非文本消息,跳过内容过滤器检查。')
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
|
||||
return await self._pre_process(
|
||||
str(query.message_chain).strip(),
|
||||
query
|
||||
)
|
||||
return await self._pre_process(str(query.message_chain).strip(), query)
|
||||
elif stage_inst_name == 'PostContentFilterStage':
|
||||
# 仅处理 query.resp_messages[-1].content 是 str 的情况
|
||||
if isinstance(query.resp_messages[-1], llm_entities.Message) and isinstance(query.resp_messages[-1].content, str):
|
||||
return await self._post_process(
|
||||
query.resp_messages[-1].content,
|
||||
query
|
||||
)
|
||||
if isinstance(query.resp_messages[-1], llm_entities.Message) and isinstance(
|
||||
query.resp_messages[-1].content, str
|
||||
):
|
||||
return await self._post_process(query.resp_messages[-1].content, query)
|
||||
else:
|
||||
self.ap.logger.debug(f"resp_messages[-1] 不是 Message 类型或 query.resp_messages[-1].content 不是 str 类型,跳过内容过滤器检查。")
|
||||
self.ap.logger.debug(
|
||||
'resp_messages[-1] 不是 Message 类型或 query.resp_messages[-1].content 不是 str 类型,跳过内容过滤器检查。'
|
||||
)
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}')
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
|
||||
import typing
|
||||
import enum
|
||||
|
||||
import pydantic.v1 as pydantic
|
||||
|
||||
from ...provider import entities as llm_entities
|
||||
|
||||
|
||||
class ResultLevel(enum.Enum):
|
||||
"""结果等级"""
|
||||
|
||||
PASS = enum.auto()
|
||||
"""通过"""
|
||||
|
||||
@@ -24,6 +21,7 @@ class ResultLevel(enum.Enum):
|
||||
|
||||
class EnableStage(enum.Enum):
|
||||
"""启用阶段"""
|
||||
|
||||
PRE = enum.auto()
|
||||
"""预处理"""
|
||||
|
||||
@@ -55,14 +53,15 @@ class FilterResult(pydantic.BaseModel):
|
||||
|
||||
class ManagerResultLevel(enum.Enum):
|
||||
"""处理器结果等级"""
|
||||
|
||||
CONTINUE = enum.auto()
|
||||
"""继续"""
|
||||
|
||||
INTERRUPT = enum.auto()
|
||||
"""中断"""
|
||||
|
||||
class FilterManagerResult(pydantic.BaseModel):
|
||||
|
||||
class FilterManagerResult(pydantic.BaseModel):
|
||||
level: ManagerResultLevel
|
||||
|
||||
replacement: str
|
||||
|
||||
@@ -5,14 +5,13 @@ import typing
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from . import entities
|
||||
from ...provider import entities as llm_entities
|
||||
|
||||
|
||||
preregistered_filters: list[typing.Type[ContentFilter]] = []
|
||||
|
||||
|
||||
def filter_class(
|
||||
name: str
|
||||
name: str,
|
||||
) -> typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]:
|
||||
"""内容过滤器类装饰器
|
||||
|
||||
@@ -22,6 +21,7 @@ def filter_class(
|
||||
Returns:
|
||||
typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]: 装饰器
|
||||
"""
|
||||
|
||||
def decorator(cls: typing.Type[ContentFilter]) -> typing.Type[ContentFilter]:
|
||||
assert issubclass(cls, ContentFilter)
|
||||
|
||||
@@ -53,23 +53,21 @@ class ContentFilter(metaclass=abc.ABCMeta):
|
||||
entity.EnableStage.PRE: 消息请求AI前,此时需要检查的内容是用户的输入消息。
|
||||
entity.EnableStage.POST: 消息请求AI后,此时需要检查的内容是AI的回复消息。
|
||||
"""
|
||||
return [
|
||||
entities.EnableStage.PRE,
|
||||
entities.EnableStage.POST
|
||||
]
|
||||
return [entities.EnableStage.PRE, entities.EnableStage.POST]
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化过滤器
|
||||
"""
|
||||
"""初始化过滤器"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def process(self, query: core_entities.Query, message: str=None, image_url=None) -> entities.FilterResult:
|
||||
async def process(
|
||||
self, query: core_entities.Query, message: str = None, image_url=None
|
||||
) -> entities.FilterResult:
|
||||
"""处理消息
|
||||
|
||||
分为前后阶段,具体取决于 enable_stages 的值。
|
||||
对于内容过滤器来说,不需要考虑消息所处的阶段,只需要检查消息内容即可。
|
||||
|
||||
|
||||
Args:
|
||||
message (str): 需要检查的内容
|
||||
image_url (str): 要检查的图片的 URL
|
||||
|
||||
@@ -7,11 +7,11 @@ from .. import filter as filter_model
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}"
|
||||
BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
|
||||
BAIDU_EXAMINE_URL = 'https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}'
|
||||
BAIDU_EXAMINE_TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token'
|
||||
|
||||
|
||||
@filter_model.filter_class("baidu-cloud-examine")
|
||||
@filter_model.filter_class('baidu-cloud-examine')
|
||||
class BaiduCloudExamine(filter_model.ContentFilter):
|
||||
"""百度云内容审核"""
|
||||
|
||||
@@ -20,44 +20,52 @@ class BaiduCloudExamine(filter_model.ContentFilter):
|
||||
async with session.post(
|
||||
BAIDU_EXAMINE_TOKEN_URL,
|
||||
params={
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-key'],
|
||||
"client_secret": self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-secret']
|
||||
}
|
||||
'grant_type': 'client_credentials',
|
||||
'client_id': self.ap.pipeline_cfg.data['baidu-cloud-examine'][
|
||||
'api-key'
|
||||
],
|
||||
'client_secret': self.ap.pipeline_cfg.data['baidu-cloud-examine'][
|
||||
'api-secret'
|
||||
],
|
||||
},
|
||||
) as resp:
|
||||
return (await resp.json())['access_token']
|
||||
|
||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
||||
|
||||
async def process(
|
||||
self, query: core_entities.Query, message: str
|
||||
) -> entities.FilterResult:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
BAIDU_EXAMINE_URL.format(await self._get_token()),
|
||||
headers={'Content-Type': 'application/x-www-form-urlencoded', 'Accept': 'application/json'},
|
||||
data=f"text={message}".encode('utf-8')
|
||||
headers={
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
'Accept': 'application/json',
|
||||
},
|
||||
data=f'text={message}'.encode('utf-8'),
|
||||
) as resp:
|
||||
result = await resp.json()
|
||||
|
||||
if "error_code" in result:
|
||||
if 'error_code' in result:
|
||||
return entities.FilterResult(
|
||||
level=entities.ResultLevel.BLOCK,
|
||||
replacement=message,
|
||||
user_notice='',
|
||||
console_notice=f"百度云判定出错,错误信息:{result['error_msg']}"
|
||||
console_notice=f'百度云判定出错,错误信息:{result["error_msg"]}',
|
||||
)
|
||||
else:
|
||||
conclusion = result["conclusion"]
|
||||
conclusion = result['conclusion']
|
||||
|
||||
if conclusion in ("合规"):
|
||||
if conclusion in ('合规'):
|
||||
return entities.FilterResult(
|
||||
level=entities.ResultLevel.PASS,
|
||||
replacement=message,
|
||||
user_notice='',
|
||||
console_notice=f"百度云判定结果:{conclusion}"
|
||||
console_notice=f'百度云判定结果:{conclusion}',
|
||||
)
|
||||
else:
|
||||
return entities.FilterResult(
|
||||
level=entities.ResultLevel.BLOCK,
|
||||
replacement=message,
|
||||
user_notice="消息中存在不合适的内容, 请修改",
|
||||
console_notice=f"百度云判定结果:{conclusion}"
|
||||
user_notice='消息中存在不合适的内容, 请修改',
|
||||
console_notice=f'百度云判定结果:{conclusion}',
|
||||
)
|
||||
|
||||
@@ -6,14 +6,16 @@ from .. import entities
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
@filter_model.filter_class("ban-word-filter")
|
||||
@filter_model.filter_class('ban-word-filter')
|
||||
class BanWordFilter(filter_model.ContentFilter):
|
||||
"""根据内容过滤"""
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
||||
async def process(
|
||||
self, query: core_entities.Query, message: str
|
||||
) -> entities.FilterResult:
|
||||
found = False
|
||||
|
||||
for word in self.ap.sensitive_meta.data['words']:
|
||||
@@ -23,9 +25,10 @@ class BanWordFilter(filter_model.ContentFilter):
|
||||
found = True
|
||||
|
||||
for i in range(len(match)):
|
||||
if self.ap.sensitive_meta.data['mask_word'] == "":
|
||||
if self.ap.sensitive_meta.data['mask_word'] == '':
|
||||
message = message.replace(
|
||||
match[i], self.ap.sensitive_meta.data['mask'] * len(match[i])
|
||||
match[i],
|
||||
self.ap.sensitive_meta.data['mask'] * len(match[i]),
|
||||
)
|
||||
else:
|
||||
message = message.replace(
|
||||
@@ -36,5 +39,5 @@ class BanWordFilter(filter_model.ContentFilter):
|
||||
level=entities.ResultLevel.MASKED if found else entities.ResultLevel.PASS,
|
||||
replacement=message,
|
||||
user_notice='消息中存在不合适的内容, 请修改' if found else '',
|
||||
console_notice=''
|
||||
)
|
||||
console_notice='',
|
||||
)
|
||||
|
||||
@@ -6,7 +6,7 @@ from .. import filter as filter_model
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
@filter_model.filter_class("content-ignore")
|
||||
@filter_model.filter_class('content-ignore')
|
||||
class ContentIgnore(filter_model.ContentFilter):
|
||||
"""根据内容忽略消息"""
|
||||
|
||||
@@ -16,7 +16,9 @@ class ContentIgnore(filter_model.ContentFilter):
|
||||
entities.EnableStage.PRE,
|
||||
]
|
||||
|
||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
||||
async def process(
|
||||
self, query: core_entities.Query, message: str
|
||||
) -> entities.FilterResult:
|
||||
if 'prefix' in query.pipeline_config['trigger']['ignore-rules']:
|
||||
for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']:
|
||||
if message.startswith(rule):
|
||||
@@ -24,9 +26,9 @@ class ContentIgnore(filter_model.ContentFilter):
|
||||
level=entities.ResultLevel.BLOCK,
|
||||
replacement='',
|
||||
user_notice='',
|
||||
console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息'
|
||||
console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息',
|
||||
)
|
||||
|
||||
|
||||
if 'regexp' in query.pipeline_config['trigger']['ignore-rules']:
|
||||
for rule in query.pipeline_config['trigger']['ignore-rules']['regexp']:
|
||||
if re.search(rule, message):
|
||||
@@ -34,12 +36,12 @@ class ContentIgnore(filter_model.ContentFilter):
|
||||
level=entities.ResultLevel.BLOCK,
|
||||
replacement='',
|
||||
user_notice='',
|
||||
console_notice='根据 ignore_rules 中的 regexp 规则,忽略消息'
|
||||
console_notice='根据 ignore_rules 中的 regexp 规则,忽略消息',
|
||||
)
|
||||
|
||||
return entities.FilterResult(
|
||||
level=entities.ResultLevel.PASS,
|
||||
replacement=message,
|
||||
user_notice='',
|
||||
console_notice=''
|
||||
)
|
||||
console_notice='',
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user