refactor: 独立ratelimiter包

This commit is contained in:
RockChinQ
2024-01-25 22:35:15 +08:00
parent f4ead5ec5c
commit b43882aad0
9 changed files with 147 additions and 102 deletions

View File

@@ -22,6 +22,7 @@ from .resprule import resprule
from .bansess import bansess
from .cntfilter import cntfilter
from .longtext import longtext
from .ratelim import ratelim
from ..boot import app
@@ -44,6 +45,7 @@ class QQBotManager:
cntfilter_mgr: cntfilter.ContentFilterManager = None
longtext_pcs: longtext.LongTextProcessor = None
resprule_chkr: resprule.GroupRespondRuleChecker = None
ratelimiter: ratelim.RateLimiter = None
def __init__(self, first_time_init=True, ap: app.Application = None):
config = context.get_config_manager().data
@@ -53,6 +55,7 @@ class QQBotManager:
self.cntfilter_mgr = cntfilter.ContentFilterManager(ap)
self.longtext_pcs = longtext.LongTextProcessor(ap)
self.resprule_chkr = resprule.GroupRespondRuleChecker(ap)
self.ratelimiter = ratelim.RateLimiter(ap)
self.timeout = config['process_message_timeout']
self.retry = config['retry_times']
@@ -62,6 +65,7 @@ class QQBotManager:
await self.cntfilter_mgr.initialize()
await self.longtext_pcs.initialize()
await self.resprule_chkr.initialize()
await self.ratelimiter.initialize()
config = context.get_config_manager().data

View File

@@ -7,7 +7,6 @@ import traceback
import mirai
import logging
from ..qqbot import ratelimit
from ..qqbot import command, message
from ..openai import session as openai_session
from ..utils import context
@@ -103,12 +102,10 @@ async def process_message(launcher_type: str, launcher_id: int, text_message: st
else: # 消息
msg_type = 'message'
# 限速丢弃检查
# print(ratelimit.__crt_minute_usage__[session_name])
if config['rate_limit_strategy'] == "drop":
if ratelimit.is_reach_limit(session_name):
logging.info("根据限速策略丢弃[{}]消息: {}".format(session_name, text_message))
if not await mgr.ratelimiter.require(launcher_type, launcher_id):
logging.info("根据限速策略丢弃[{}]消息: {}".format(session_name, text_message))
return mirai.MessageChain(["[bot]"+tips_custom.rate_limit_drop_tip]) if tips_custom.rate_limit_drop_tip != "" else []
return mirai.MessageChain(["[bot]"+tips_custom.rate_limit_drop_tip]) if tips_custom.rate_limit_drop_tip != "" else []
before = time.time()
# 触发插件事件
@@ -133,12 +130,6 @@ async def process_message(launcher_type: str, launcher_id: int, text_message: st
reply = message.process_normal_message(text_message,
mgr, config, launcher_type, launcher_id, sender_id)
# 限速等待时间
if config['rate_limit_strategy'] == "wait":
time.sleep(ratelimit.get_rest_wait_time(session_name, time.time() - before))
ratelimit.add_usage(session_name)
if reply is not None and len(reply) > 0 and (type(reply[0]) == str or type(reply[0]) == mirai.Plain):
if type(reply[0]) == mirai.Plain:
reply[0] = reply[0].text

View File

24
pkg/qqbot/ratelim/algo.py Normal file
View File

@@ -0,0 +1,24 @@
from __future__ import annotations
import abc
from ...boot import app
class ReteLimitAlgo(metaclass=abc.ABCMeta):
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def require_access(self, launcher_type: str, launcher_id: int) -> bool:
raise NotImplementedError
@abc.abstractmethod
async def release_access(self, launcher_type: str, launcher_id: int):
raise NotImplementedError

View File

View File

@@ -0,0 +1,85 @@
# 固定窗口算法
from __future__ import annotations
import asyncio
import time
from .. import algo
class SessionContainer:
wait_lock: asyncio.Lock
records: dict[int, int]
"""访问记录key为每分钟的起始时间戳value为访问次数"""
def __init__(self):
self.wait_lock = asyncio.Lock()
self.records = {}
class FixedWindowAlgo(algo.ReteLimitAlgo):
containers_lock: asyncio.Lock
"""访问记录容器锁"""
containers: dict[str, SessionContainer]
"""访问记录容器key为launcher_type launcher_id"""
async def initialize(self):
self.containers_lock = asyncio.Lock()
self.containers = {}
async def require_access(self, launcher_type: str, launcher_id: int) -> bool:
# 加锁,找容器
container: SessionContainer = None
session_name = f'{launcher_type}_{launcher_id}'
async with self.containers_lock:
container = self.containers.get(session_name)
if container is None:
container = SessionContainer()
self.containers[session_name] = container
# 等待锁
async with container.wait_lock:
# 获取当前时间戳
now = int(time.time())
# 获取当前分钟的起始时间戳
now = now - now % 60
# 获取当前分钟的访问次数
count = container.records.get(now, 0)
limitation = self.ap.cfg_mgr.data['rate_limitation']['default']
if session_name in self.ap.cfg_mgr.data['rate_limitation']:
limitation = self.ap.cfg_mgr.data['rate_limitation'][session_name]
# 如果访问次数超过了限制
if count >= limitation:
if self.ap.cfg_mgr.data['rate_limit_strategy'] == 'drop':
return False
elif self.ap.cfg_mgr.data['rate_limit_strategy'] == 'wait':
# 等待下一分钟
await asyncio.sleep(60 - time.time() % 60)
now = int(time.time())
now = now - now % 60
if now not in container.records:
container.records = {}
container.records[now] = 1
else:
# 访问次数加一
container.records[now] = count + 1
# 返回True
return True
async def release_access(self, launcher_type: str, launcher_id: int):
pass

View File

@@ -0,0 +1,31 @@
from __future__ import annotations
from . import algo
from .algos import fixedwin
from ...boot import app
class RateLimiter:
"""限速器
"""
ap: app.Application
algo: algo.ReteLimitAlgo
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
self.algo = fixedwin.FixedWindowAlgo(self.ap)
await self.algo.initialize()
async def require(self, launcher_type: str, launcher_id: int) -> bool:
"""请求访问
"""
return await self.algo.require_access(launcher_type, launcher_id)
async def release(self, launcher_type: str, launcher_id: int):
"""释放访问
"""
return await self.algo.release_access(launcher_type, launcher_id)

View File

@@ -1,89 +0,0 @@
# 限速相关模块
import time
import logging
import threading
from ..utils import context
__crt_minute_usage__ = {}
"""当前分钟每个会话的对话次数"""
__timer_thr__: threading.Thread = None
def get_limitation(session_name: str) -> int:
"""获取会话的限制次数"""
config = context.get_config_manager().data
if session_name in config['rate_limitation']:
return config['rate_limitation'][session_name]
else:
return config['rate_limitation']["default"]
def add_usage(session_name: str):
"""增加会话的对话次数"""
global __crt_minute_usage__
if session_name in __crt_minute_usage__:
__crt_minute_usage__[session_name] += 1
else:
__crt_minute_usage__[session_name] = 1
def start_timer():
"""启动定时器"""
global __timer_thr__
__timer_thr__ = threading.Thread(target=run_timer, daemon=True)
__timer_thr__.start()
def run_timer():
"""启动定时器,每分钟清空一次对话次数"""
global __crt_minute_usage__
global __timer_thr__
# 等待直到整分钟
time.sleep(60 - time.time() % 60)
while True:
if __timer_thr__ != threading.current_thread():
break
logging.debug("清空当前分钟的对话次数")
__crt_minute_usage__ = {}
time.sleep(60)
def get_usage(session_name: str) -> int:
"""获取会话的对话次数"""
global __crt_minute_usage__
if session_name in __crt_minute_usage__:
return __crt_minute_usage__[session_name]
else:
return 0
def get_rest_wait_time(session_name: str, spent: float) -> float:
"""获取会话此回合的剩余等待时间"""
global __crt_minute_usage__
min_seconds_per_round = 60.0 / get_limitation(session_name)
if session_name in __crt_minute_usage__:
return max(0, min_seconds_per_round - spent)
else:
return 0
def is_reach_limit(session_name: str) -> bool:
"""判断会话是否超过限制"""
global __crt_minute_usage__
if session_name in __crt_minute_usage__:
return __crt_minute_usage__[session_name] >= get_limitation(session_name)
else:
return False
start_timer()

View File

@@ -7,4 +7,3 @@ class RuleJudgeResult(pydantic.BaseModel):
matching: bool = False
replacement: mirai.MessageChain = None