Files
LangBot/pkg/pipeline/ratelimit/ratelimit.py
T
Junyan Qin (Chin) 209f16af76 style: introduce ruff as linter and formatter (#1356)
* style: remove necessary imports

* style: fix F841

* style: fix F401

* style: fix F811

* style: fix E402

* style: fix E721

* style: fix E722

* style: fix E722

* style: fix F541

* style: ruff format

* style: all passed

* style: add ruff in deps

* style: more ignores in ruff.toml

* style: add pre-commit
2025-04-29 17:24:07 +08:00

76 lines
2.3 KiB
Python

from __future__ import annotations
import typing
from .. import entities, stage
from . import algo
from ...core import entities as core_entities
from ...utils import importutil
from . import algos
importutil.import_modules_in_pkg(algos)
@stage.stage_class('RequireRateLimitOccupancy')
@stage.stage_class('ReleaseRateLimitOccupancy')
class RateLimit(stage.PipelineStage):
"""限速器控制阶段
不改写query,只检查是否需要限速。
"""
algo: algo.ReteLimitAlgo
async def initialize(self, pipeline_config: dict):
algo_name = 'fixwin'
algo_class = None
for algo_cls in algo.preregistered_algos:
if algo_cls.name == algo_name:
algo_class = algo_cls
break
else:
raise ValueError(f'未知的限速算法: {algo_name}')
self.algo = algo_class(self.ap)
await self.algo.initialize()
async def process(
self,
query: core_entities.Query,
stage_inst_name: str,
) -> typing.Union[
entities.StageProcessResult,
typing.AsyncGenerator[entities.StageProcessResult, None],
]:
"""处理"""
if stage_inst_name == 'RequireRateLimitOccupancy':
if await self.algo.require_access(
query,
query.launcher_type.value,
query.launcher_id,
):
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query,
)
else:
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
console_notice=f'根据限速规则忽略 {query.launcher_type.value}:{query.launcher_id} 消息',
user_notice='请求数超过限速器设定值,已丢弃本消息。',
)
elif stage_inst_name == 'ReleaseRateLimitOccupancy':
await self.algo.release_access(
query,
query.launcher_type.value,
query.launcher_id,
)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query,
)