From b43882aad0b879b15d460a65666bfb4c1fe7ab11 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Thu, 25 Jan 2024 22:35:15 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E7=8B=AC=E7=AB=8Bratelimiter?= =?UTF-8?q?=E5=8C=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/qqbot/manager.py | 4 ++ pkg/qqbot/process.py | 15 +---- pkg/qqbot/ratelim/__init__.py | 0 pkg/qqbot/ratelim/algo.py | 24 ++++++++ pkg/qqbot/ratelim/algos/__init__.py | 0 pkg/qqbot/ratelim/algos/fixedwin.py | 85 +++++++++++++++++++++++++++ pkg/qqbot/ratelim/ratelim.py | 31 ++++++++++ pkg/qqbot/ratelimit.py | 89 ----------------------------- pkg/qqbot/resprule/entities.py | 1 - 9 files changed, 147 insertions(+), 102 deletions(-) create mode 100644 pkg/qqbot/ratelim/__init__.py create mode 100644 pkg/qqbot/ratelim/algo.py create mode 100644 pkg/qqbot/ratelim/algos/__init__.py create mode 100644 pkg/qqbot/ratelim/algos/fixedwin.py create mode 100644 pkg/qqbot/ratelim/ratelim.py delete mode 100644 pkg/qqbot/ratelimit.py diff --git a/pkg/qqbot/manager.py b/pkg/qqbot/manager.py index a973ab6d..5239604f 100644 --- a/pkg/qqbot/manager.py +++ b/pkg/qqbot/manager.py @@ -22,6 +22,7 @@ from .resprule import resprule from .bansess import bansess from .cntfilter import cntfilter from .longtext import longtext +from .ratelim import ratelim from ..boot import app @@ -44,6 +45,7 @@ class QQBotManager: cntfilter_mgr: cntfilter.ContentFilterManager = None longtext_pcs: longtext.LongTextProcessor = None resprule_chkr: resprule.GroupRespondRuleChecker = None + ratelimiter: ratelim.RateLimiter = None def __init__(self, first_time_init=True, ap: app.Application = None): config = context.get_config_manager().data @@ -53,6 +55,7 @@ class QQBotManager: self.cntfilter_mgr = cntfilter.ContentFilterManager(ap) self.longtext_pcs = longtext.LongTextProcessor(ap) self.resprule_chkr = resprule.GroupRespondRuleChecker(ap) + self.ratelimiter = ratelim.RateLimiter(ap) self.timeout = config['process_message_timeout'] self.retry = config['retry_times'] @@ -62,6 +65,7 @@ class QQBotManager: await self.cntfilter_mgr.initialize() await self.longtext_pcs.initialize() await self.resprule_chkr.initialize() + await self.ratelimiter.initialize() config = context.get_config_manager().data diff --git a/pkg/qqbot/process.py b/pkg/qqbot/process.py index f6379c71..e1673583 100644 --- a/pkg/qqbot/process.py +++ b/pkg/qqbot/process.py @@ -7,7 +7,6 @@ import traceback import mirai import logging -from ..qqbot import ratelimit from ..qqbot import command, message from ..openai import session as openai_session from ..utils import context @@ -103,12 +102,10 @@ async def process_message(launcher_type: str, launcher_id: int, text_message: st else: # 消息 msg_type = 'message' # 限速丢弃检查 - # print(ratelimit.__crt_minute_usage__[session_name]) - if config['rate_limit_strategy'] == "drop": - if ratelimit.is_reach_limit(session_name): - logging.info("根据限速策略丢弃[{}]消息: {}".format(session_name, text_message)) + if not await mgr.ratelimiter.require(launcher_type, launcher_id): + logging.info("根据限速策略丢弃[{}]消息: {}".format(session_name, text_message)) - return mirai.MessageChain(["[bot]"+tips_custom.rate_limit_drop_tip]) if tips_custom.rate_limit_drop_tip != "" else [] + return mirai.MessageChain(["[bot]"+tips_custom.rate_limit_drop_tip]) if tips_custom.rate_limit_drop_tip != "" else [] before = time.time() # 触发插件事件 @@ -133,12 +130,6 @@ async def process_message(launcher_type: str, launcher_id: int, text_message: st reply = message.process_normal_message(text_message, mgr, config, launcher_type, launcher_id, sender_id) - # 限速等待时间 - if config['rate_limit_strategy'] == "wait": - time.sleep(ratelimit.get_rest_wait_time(session_name, time.time() - before)) - - ratelimit.add_usage(session_name) - if reply is not None and len(reply) > 0 and (type(reply[0]) == str or type(reply[0]) == mirai.Plain): if type(reply[0]) == mirai.Plain: reply[0] = reply[0].text diff --git a/pkg/qqbot/ratelim/__init__.py b/pkg/qqbot/ratelim/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/qqbot/ratelim/algo.py b/pkg/qqbot/ratelim/algo.py new file mode 100644 index 00000000..10bbdd3a --- /dev/null +++ b/pkg/qqbot/ratelim/algo.py @@ -0,0 +1,24 @@ +from __future__ import annotations +import abc + +from ...boot import app + + +class ReteLimitAlgo(metaclass=abc.ABCMeta): + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + pass + + @abc.abstractmethod + async def require_access(self, launcher_type: str, launcher_id: int) -> bool: + raise NotImplementedError + + @abc.abstractmethod + async def release_access(self, launcher_type: str, launcher_id: int): + raise NotImplementedError + \ No newline at end of file diff --git a/pkg/qqbot/ratelim/algos/__init__.py b/pkg/qqbot/ratelim/algos/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/qqbot/ratelim/algos/fixedwin.py b/pkg/qqbot/ratelim/algos/fixedwin.py new file mode 100644 index 00000000..4996fbaa --- /dev/null +++ b/pkg/qqbot/ratelim/algos/fixedwin.py @@ -0,0 +1,85 @@ +# 固定窗口算法 +from __future__ import annotations + +import asyncio +import time + +from .. import algo + + +class SessionContainer: + + wait_lock: asyncio.Lock + + records: dict[int, int] + """访问记录,key为每分钟的起始时间戳,value为访问次数""" + + def __init__(self): + self.wait_lock = asyncio.Lock() + self.records = {} + + +class FixedWindowAlgo(algo.ReteLimitAlgo): + + containers_lock: asyncio.Lock + """访问记录容器锁""" + + containers: dict[str, SessionContainer] + """访问记录容器,key为launcher_type launcher_id""" + + async def initialize(self): + self.containers_lock = asyncio.Lock() + self.containers = {} + + async def require_access(self, launcher_type: str, launcher_id: int) -> bool: + # 加锁,找容器 + container: SessionContainer = None + + session_name = f'{launcher_type}_{launcher_id}' + + async with self.containers_lock: + container = self.containers.get(session_name) + + if container is None: + container = SessionContainer() + self.containers[session_name] = container + + # 等待锁 + async with container.wait_lock: + # 获取当前时间戳 + now = int(time.time()) + + # 获取当前分钟的起始时间戳 + now = now - now % 60 + + # 获取当前分钟的访问次数 + count = container.records.get(now, 0) + + limitation = self.ap.cfg_mgr.data['rate_limitation']['default'] + + if session_name in self.ap.cfg_mgr.data['rate_limitation']: + limitation = self.ap.cfg_mgr.data['rate_limitation'][session_name] + + # 如果访问次数超过了限制 + if count >= limitation: + if self.ap.cfg_mgr.data['rate_limit_strategy'] == 'drop': + return False + elif self.ap.cfg_mgr.data['rate_limit_strategy'] == 'wait': + # 等待下一分钟 + await asyncio.sleep(60 - time.time() % 60) + + now = int(time.time()) + now = now - now % 60 + + if now not in container.records: + container.records = {} + container.records[now] = 1 + else: + # 访问次数加一 + container.records[now] = count + 1 + + # 返回True + return True + + async def release_access(self, launcher_type: str, launcher_id: int): + pass diff --git a/pkg/qqbot/ratelim/ratelim.py b/pkg/qqbot/ratelim/ratelim.py new file mode 100644 index 00000000..ab23d714 --- /dev/null +++ b/pkg/qqbot/ratelim/ratelim.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from . import algo +from .algos import fixedwin +from ...boot import app + + +class RateLimiter: + """限速器 + """ + + ap: app.Application + + algo: algo.ReteLimitAlgo + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + self.algo = fixedwin.FixedWindowAlgo(self.ap) + await self.algo.initialize() + + async def require(self, launcher_type: str, launcher_id: int) -> bool: + """请求访问 + """ + return await self.algo.require_access(launcher_type, launcher_id) + + async def release(self, launcher_type: str, launcher_id: int): + """释放访问 + """ + return await self.algo.release_access(launcher_type, launcher_id) \ No newline at end of file diff --git a/pkg/qqbot/ratelimit.py b/pkg/qqbot/ratelimit.py deleted file mode 100644 index 96d289ff..00000000 --- a/pkg/qqbot/ratelimit.py +++ /dev/null @@ -1,89 +0,0 @@ -# 限速相关模块 -import time -import logging -import threading - -from ..utils import context - - -__crt_minute_usage__ = {} -"""当前分钟每个会话的对话次数""" - - -__timer_thr__: threading.Thread = None - - -def get_limitation(session_name: str) -> int: - """获取会话的限制次数""" - config = context.get_config_manager().data - - if session_name in config['rate_limitation']: - return config['rate_limitation'][session_name] - else: - return config['rate_limitation']["default"] - - -def add_usage(session_name: str): - """增加会话的对话次数""" - global __crt_minute_usage__ - if session_name in __crt_minute_usage__: - __crt_minute_usage__[session_name] += 1 - else: - __crt_minute_usage__[session_name] = 1 - - -def start_timer(): - """启动定时器""" - global __timer_thr__ - __timer_thr__ = threading.Thread(target=run_timer, daemon=True) - __timer_thr__.start() - - -def run_timer(): - """启动定时器,每分钟清空一次对话次数""" - global __crt_minute_usage__ - global __timer_thr__ - - # 等待直到整分钟 - time.sleep(60 - time.time() % 60) - - while True: - if __timer_thr__ != threading.current_thread(): - break - - logging.debug("清空当前分钟的对话次数") - __crt_minute_usage__ = {} - time.sleep(60) - - -def get_usage(session_name: str) -> int: - """获取会话的对话次数""" - global __crt_minute_usage__ - if session_name in __crt_minute_usage__: - return __crt_minute_usage__[session_name] - else: - return 0 - - -def get_rest_wait_time(session_name: str, spent: float) -> float: - """获取会话此回合的剩余等待时间""" - global __crt_minute_usage__ - - min_seconds_per_round = 60.0 / get_limitation(session_name) - - if session_name in __crt_minute_usage__: - return max(0, min_seconds_per_round - spent) - else: - return 0 - - -def is_reach_limit(session_name: str) -> bool: - """判断会话是否超过限制""" - global __crt_minute_usage__ - - if session_name in __crt_minute_usage__: - return __crt_minute_usage__[session_name] >= get_limitation(session_name) - else: - return False - -start_timer() diff --git a/pkg/qqbot/resprule/entities.py b/pkg/qqbot/resprule/entities.py index 1cdd76f2..ffee3081 100644 --- a/pkg/qqbot/resprule/entities.py +++ b/pkg/qqbot/resprule/entities.py @@ -7,4 +7,3 @@ class RuleJudgeResult(pydantic.BaseModel): matching: bool = False replacement: mirai.MessageChain = None -