style: introduce ruff as linter and formatter (#1356)

* style: remove necessary imports

* style: fix F841

* style: fix F401

* style: fix F811

* style: fix E402

* style: fix E721

* style: fix E722

* style: fix E722

* style: fix F541

* style: ruff format

* style: all passed

* style: add ruff in deps

* style: more ignores in ruff.toml

* style: add pre-commit
This commit is contained in:
Junyan Qin (Chin)
2025-04-29 17:24:07 +08:00
committed by GitHub
parent 09e70d70e9
commit 209f16af76
240 changed files with 5307 additions and 4689 deletions

View File

@@ -3,33 +3,38 @@ from __future__ import annotations
from .. import stage, entities
from ...core import entities as core_entities
from . import truncator
from .truncators import round
from ...utils import importutil
from . import truncators
importutil.import_modules_in_pkg(truncators)
@stage.stage_class("ConversationMessageTruncator")
@stage.stage_class('ConversationMessageTruncator')
class ConversationMessageTruncator(stage.PipelineStage):
"""会话消息截断器
用于截断会话消息链,以适应平台消息长度限制。
"""
trun: truncator.Truncator
async def initialize(self, pipeline_config: dict):
use_method = "round"
use_method = 'round'
for trun in truncator.preregistered_truncators:
if trun.name == use_method:
self.trun = trun(self.ap)
break
else:
raise ValueError(f"未知的截断器: {use_method}")
raise ValueError(f'未知的截断器: {use_method}')
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理
"""
async def process(
self, query: core_entities.Query, stage_inst_name: str
) -> entities.StageProcessResult:
"""处理"""
query = await self.trun.truncate(query)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
result_type=entities.ResultType.CONTINUE, new_query=query
)

View File

@@ -10,7 +10,7 @@ preregistered_truncators: list[typing.Type[Truncator]] = []
def truncator_class(
name: str
name: str,
) -> typing.Callable[[typing.Type[Truncator]], typing.Type[Truncator]]:
"""截断器类装饰器
@@ -20,6 +20,7 @@ def truncator_class(
Returns:
typing.Callable[[typing.Type[Truncator]], typing.Type[Truncator]]: 装饰器
"""
def decorator(cls: typing.Type[Truncator]) -> typing.Type[Truncator]:
assert issubclass(cls, Truncator)
@@ -33,13 +34,12 @@ def truncator_class(
class Truncator(abc.ABC):
"""消息截断器基类
"""
"""消息截断器基类"""
name: str
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap

View File

@@ -4,14 +4,12 @@ from .. import truncator
from ....core import entities as core_entities
@truncator.truncator_class("round")
@truncator.truncator_class('round')
class RoundTruncator(truncator.Truncator):
"""前文回合数阶段器
"""
"""前文回合数阶段器"""
async def truncate(self, query: core_entities.Query) -> core_entities.Query:
"""截断
"""
"""截断"""
max_round = query.pipeline_config['ai']['local-agent']['max-round']
temp_messages = []
@@ -26,7 +24,7 @@ class RoundTruncator(truncator.Truncator):
current_round += 1
else:
break
query.messages = temp_messages[::-1]
return query