feat: 恢复ratelimit

This commit is contained in:
RockChinQ
2024-02-01 18:38:20 +08:00
parent 0dec10ddf2
commit f340a44abf
8 changed files with 58 additions and 41 deletions

View File

@@ -9,13 +9,7 @@ import traceback
from mirai import At, GroupMessage, MessageEvent, StrangerMessage, \
FriendMessage, Image, MessageChain, Plain
import mirai
import func_timeout
from ..provider import session as openai_session
import tips as tips_custom
from ..platform import adapter as msadapter
from .ratelim import ratelim
from ..core import app, entities as core_entities
from ..plugin import events
@@ -31,15 +25,11 @@ class PlatformManager:
# modern
ap: app.Application = None
ratelimiter: ratelim.RateLimiter = None
def __init__(self, ap: app.Application = None):
self.ap = ap
self.ratelimiter = ratelim.RateLimiter(ap)
async def initialize(self):
await self.ratelimiter.initialize()
config = self.ap.cfg_mgr.data

View File

@@ -1,24 +0,0 @@
from __future__ import annotations
import abc
from ...core import app
class ReteLimitAlgo(metaclass=abc.ABCMeta):
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def require_access(self, launcher_type: str, launcher_id: int) -> bool:
raise NotImplementedError
@abc.abstractmethod
async def release_access(self, launcher_type: str, launcher_id: int):
raise NotImplementedError

View File

@@ -1,85 +0,0 @@
# 固定窗口算法
from __future__ import annotations
import asyncio
import time
from .. import algo
class SessionContainer:
wait_lock: asyncio.Lock
records: dict[int, int]
"""访问记录key为每分钟的起始时间戳value为访问次数"""
def __init__(self):
self.wait_lock = asyncio.Lock()
self.records = {}
class FixedWindowAlgo(algo.ReteLimitAlgo):
containers_lock: asyncio.Lock
"""访问记录容器锁"""
containers: dict[str, SessionContainer]
"""访问记录容器key为launcher_type launcher_id"""
async def initialize(self):
self.containers_lock = asyncio.Lock()
self.containers = {}
async def require_access(self, launcher_type: str, launcher_id: int) -> bool:
# 加锁,找容器
container: SessionContainer = None
session_name = f'{launcher_type}_{launcher_id}'
async with self.containers_lock:
container = self.containers.get(session_name)
if container is None:
container = SessionContainer()
self.containers[session_name] = container
# 等待锁
async with container.wait_lock:
# 获取当前时间戳
now = int(time.time())
# 获取当前分钟的起始时间戳
now = now - now % 60
# 获取当前分钟的访问次数
count = container.records.get(now, 0)
limitation = self.ap.cfg_mgr.data['rate_limitation']['default']
if session_name in self.ap.cfg_mgr.data['rate_limitation']:
limitation = self.ap.cfg_mgr.data['rate_limitation'][session_name]
# 如果访问次数超过了限制
if count >= limitation:
if self.ap.cfg_mgr.data['rate_limit_strategy'] == 'drop':
return False
elif self.ap.cfg_mgr.data['rate_limit_strategy'] == 'wait':
# 等待下一分钟
await asyncio.sleep(60 - time.time() % 60)
now = int(time.time())
now = now - now % 60
if now not in container.records:
container.records = {}
container.records[now] = 1
else:
# 访问次数加一
container.records[now] = count + 1
# 返回True
return True
async def release_access(self, launcher_type: str, launcher_id: int):
pass

View File

@@ -1,31 +0,0 @@
from __future__ import annotations
from . import algo
from .algos import fixedwin
from ...core import app
class RateLimiter:
"""限速器
"""
ap: app.Application
algo: algo.ReteLimitAlgo
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
self.algo = fixedwin.FixedWindowAlgo(self.ap)
await self.algo.initialize()
async def require(self, launcher_type: str, launcher_id: int) -> bool:
"""请求访问
"""
return await self.algo.require_access(launcher_type, launcher_id)
async def release(self, launcher_type: str, launcher_id: int):
"""释放访问
"""
return await self.algo.release_access(launcher_type, launcher_id)