feat: 恢复ratelimit

This commit is contained in:
RockChinQ
2024-02-01 18:38:20 +08:00
parent 0dec10ddf2
commit f340a44abf
8 changed files with 58 additions and 41 deletions

View File

@@ -0,0 +1,55 @@
from __future__ import annotations
import typing
from .. import entities, stagemgr, stage
from . import algo
from .algos import fixedwin
from ...core import entities as core_entities
@stage.stage_class("RequireRateLimitOccupancy")
@stage.stage_class("ReleaseRateLimitOccupancy")
class RateLimit(stage.PipelineStage):
algo: algo.ReteLimitAlgo
async def initialize(self):
self.algo = fixedwin.FixedWindowAlgo(self.ap)
await self.algo.initialize()
async def process(
self,
query: core_entities.Query,
stage_inst_name: str,
) -> typing.Union[
entities.StageProcessResult,
typing.AsyncGenerator[entities.StageProcessResult, None],
]:
"""处理
"""
if stage_inst_name == "RequireRateLimitOccupancy":
if await self.algo.require_access(
query.launcher_type.value,
query.launcher_id,
):
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query,
)
else:
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
console_notice=f"根据限速规则忽略 {query.launcher_type.value}:{query.launcher_id} 消息",
user_notice=self.ap.tips_mgr.data['rate_limit_drop_tip']
)
elif stage_inst_name == "ReleaseRateLimitOccupancy":
await self.algo.release_access(
query.launcher_type,
query.launcher_id,
)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query,
)

View File

@@ -12,6 +12,7 @@ from .longtext import longtext
from .respback import respback
from .wrapper import wrapper
from .preproc import preproc
from .ratelimit import ratelimit
stage_order = [
@@ -19,7 +20,9 @@ stage_order = [
"BanSessionCheckStage",
"PreContentFilterStage",
"PreProcessor",
"RequireRateLimitOccupancy",
"MessageProcessor",
"ReleaseRateLimitOccupancy",
"PostContentFilterStage",
"ResponseWrapper",
"LongTextProcessStage",

View File

@@ -9,13 +9,7 @@ import traceback
from mirai import At, GroupMessage, MessageEvent, StrangerMessage, \
FriendMessage, Image, MessageChain, Plain
import mirai
import func_timeout
from ..provider import session as openai_session
import tips as tips_custom
from ..platform import adapter as msadapter
from .ratelim import ratelim
from ..core import app, entities as core_entities
from ..plugin import events
@@ -31,15 +25,11 @@ class PlatformManager:
# modern
ap: app.Application = None
ratelimiter: ratelim.RateLimiter = None
def __init__(self, ap: app.Application = None):
self.ap = ap
self.ratelimiter = ratelim.RateLimiter(ap)
async def initialize(self):
await self.ratelimiter.initialize()
config = self.ap.cfg_mgr.data

View File

@@ -1,31 +0,0 @@
from __future__ import annotations
from . import algo
from .algos import fixedwin
from ...core import app
class RateLimiter:
"""限速器
"""
ap: app.Application
algo: algo.ReteLimitAlgo
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
self.algo = fixedwin.FixedWindowAlgo(self.ap)
await self.algo.initialize()
async def require(self, launcher_type: str, launcher_id: int) -> bool:
"""请求访问
"""
return await self.algo.require_access(launcher_type, launcher_id)
async def release(self, launcher_type: str, launcher_id: int):
"""释放访问
"""
return await self.algo.release_access(launcher_type, launcher_id)