From 835275b47ffca001992c80b5ce7e4b146599df76 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sat, 23 Mar 2024 22:39:42 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E5=A4=9A=E5=A4=84=E5=AF=B9=20launcher?= =?UTF-8?q?=5Ftype=20=E6=9E=9A=E4=B8=BE=E7=9A=84=E4=B8=8D=E5=BD=93?= =?UTF-8?q?=E6=AF=94=E8=BE=83=20(#736)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/core/entities.py | 2 +- pkg/pipeline/bansess/bansess.py | 20 +++++++++++--------- pkg/pipeline/ratelimit/ratelimit.py | 2 +- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/pkg/core/entities.py b/pkg/core/entities.py index f0f3f151..2e7d0b1d 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -94,7 +94,7 @@ class Conversation(pydantic.BaseModel): class Session(pydantic.BaseModel): - """会话,一个 Session 对应一个 {launcher_type}_{launcher_id}""" + """会话,一个 Session 对应一个 {launcher_type.value}_{launcher_id}""" launcher_type: LauncherTypes launcher_id: int diff --git a/pkg/pipeline/bansess/bansess.py b/pkg/pipeline/bansess/bansess.py index 3add6f1f..95a7cffd 100644 --- a/pkg/pipeline/bansess/bansess.py +++ b/pkg/pipeline/bansess/bansess.py @@ -25,22 +25,24 @@ class BanSessionCheckStage(stage.PipelineStage): sess_list = self.ap.pipeline_cfg.data['access-control'][mode] - if (query.launcher_type == 'group' and 'group_*' in sess_list) \ - or (query.launcher_type == 'person' and 'person_*' in sess_list): + if (query.launcher_type.value == 'group' and 'group_*' in sess_list) \ + or (query.launcher_type.value == 'person' and 'person_*' in sess_list): found = True else: for sess in sess_list: - if sess == f"{query.launcher_type}_{query.launcher_id}": + if sess == f"{query.launcher_type.value}_{query.launcher_id}": found = True break + + ctn = False - result = False - - if mode == 'blacklist': - result = found + if mode == 'whitelist': + ctn = found + else: + ctn = not found return entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE if not result else entities.ResultType.INTERRUPT, + result_type=entities.ResultType.CONTINUE if ctn else entities.ResultType.INTERRUPT, new_query=query, - debug_notice=f'根据访问控制忽略消息: {query.launcher_type}_{query.launcher_id}' if result else '' + console_notice=f'根据访问控制忽略消息: {query.launcher_type.value}_{query.launcher_id}' if not ctn else '' ) diff --git a/pkg/pipeline/ratelimit/ratelimit.py b/pkg/pipeline/ratelimit/ratelimit.py index f43c8b06..2622247a 100644 --- a/pkg/pipeline/ratelimit/ratelimit.py +++ b/pkg/pipeline/ratelimit/ratelimit.py @@ -59,7 +59,7 @@ class RateLimit(stage.PipelineStage): ) elif stage_inst_name == "ReleaseRateLimitOccupancy": await self.algo.release_access( - query.launcher_type, + query.launcher_type.value, query.launcher_id, ) return entities.StageProcessResult(