mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-11 00:06:04 +00:00
feat: 恢复ratelimit
This commit is contained in:
@@ -9,13 +9,7 @@ import traceback
|
||||
from mirai import At, GroupMessage, MessageEvent, StrangerMessage, \
|
||||
FriendMessage, Image, MessageChain, Plain
|
||||
import mirai
|
||||
import func_timeout
|
||||
|
||||
from ..provider import session as openai_session
|
||||
|
||||
import tips as tips_custom
|
||||
from ..platform import adapter as msadapter
|
||||
from .ratelim import ratelim
|
||||
|
||||
from ..core import app, entities as core_entities
|
||||
from ..plugin import events
|
||||
@@ -31,15 +25,11 @@ class PlatformManager:
|
||||
# modern
|
||||
ap: app.Application = None
|
||||
|
||||
ratelimiter: ratelim.RateLimiter = None
|
||||
|
||||
def __init__(self, ap: app.Application = None):
|
||||
|
||||
self.ap = ap
|
||||
self.ratelimiter = ratelim.RateLimiter(ap)
|
||||
|
||||
async def initialize(self):
|
||||
await self.ratelimiter.initialize()
|
||||
|
||||
config = self.ap.cfg_mgr.data
|
||||
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
|
||||
from ...core import app
|
||||
|
||||
|
||||
class ReteLimitAlgo(metaclass=abc.ABCMeta):
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def require_access(self, launcher_type: str, launcher_id: int) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def release_access(self, launcher_type: str, launcher_id: int):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,85 +0,0 @@
|
||||
# 固定窗口算法
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from .. import algo
|
||||
|
||||
|
||||
class SessionContainer:
|
||||
|
||||
wait_lock: asyncio.Lock
|
||||
|
||||
records: dict[int, int]
|
||||
"""访问记录,key为每分钟的起始时间戳,value为访问次数"""
|
||||
|
||||
def __init__(self):
|
||||
self.wait_lock = asyncio.Lock()
|
||||
self.records = {}
|
||||
|
||||
|
||||
class FixedWindowAlgo(algo.ReteLimitAlgo):
|
||||
|
||||
containers_lock: asyncio.Lock
|
||||
"""访问记录容器锁"""
|
||||
|
||||
containers: dict[str, SessionContainer]
|
||||
"""访问记录容器,key为launcher_type launcher_id"""
|
||||
|
||||
async def initialize(self):
|
||||
self.containers_lock = asyncio.Lock()
|
||||
self.containers = {}
|
||||
|
||||
async def require_access(self, launcher_type: str, launcher_id: int) -> bool:
|
||||
# 加锁,找容器
|
||||
container: SessionContainer = None
|
||||
|
||||
session_name = f'{launcher_type}_{launcher_id}'
|
||||
|
||||
async with self.containers_lock:
|
||||
container = self.containers.get(session_name)
|
||||
|
||||
if container is None:
|
||||
container = SessionContainer()
|
||||
self.containers[session_name] = container
|
||||
|
||||
# 等待锁
|
||||
async with container.wait_lock:
|
||||
# 获取当前时间戳
|
||||
now = int(time.time())
|
||||
|
||||
# 获取当前分钟的起始时间戳
|
||||
now = now - now % 60
|
||||
|
||||
# 获取当前分钟的访问次数
|
||||
count = container.records.get(now, 0)
|
||||
|
||||
limitation = self.ap.cfg_mgr.data['rate_limitation']['default']
|
||||
|
||||
if session_name in self.ap.cfg_mgr.data['rate_limitation']:
|
||||
limitation = self.ap.cfg_mgr.data['rate_limitation'][session_name]
|
||||
|
||||
# 如果访问次数超过了限制
|
||||
if count >= limitation:
|
||||
if self.ap.cfg_mgr.data['rate_limit_strategy'] == 'drop':
|
||||
return False
|
||||
elif self.ap.cfg_mgr.data['rate_limit_strategy'] == 'wait':
|
||||
# 等待下一分钟
|
||||
await asyncio.sleep(60 - time.time() % 60)
|
||||
|
||||
now = int(time.time())
|
||||
now = now - now % 60
|
||||
|
||||
if now not in container.records:
|
||||
container.records = {}
|
||||
container.records[now] = 1
|
||||
else:
|
||||
# 访问次数加一
|
||||
container.records[now] = count + 1
|
||||
|
||||
# 返回True
|
||||
return True
|
||||
|
||||
async def release_access(self, launcher_type: str, launcher_id: int):
|
||||
pass
|
||||
@@ -1,31 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from . import algo
|
||||
from .algos import fixedwin
|
||||
from ...core import app
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""限速器
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
algo: algo.ReteLimitAlgo
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
self.algo = fixedwin.FixedWindowAlgo(self.ap)
|
||||
await self.algo.initialize()
|
||||
|
||||
async def require(self, launcher_type: str, launcher_id: int) -> bool:
|
||||
"""请求访问
|
||||
"""
|
||||
return await self.algo.require_access(launcher_type, launcher_id)
|
||||
|
||||
async def release(self, launcher_type: str, launcher_id: int):
|
||||
"""释放访问
|
||||
"""
|
||||
return await self.algo.release_access(launcher_type, launcher_id)
|
||||
Reference in New Issue
Block a user