From abc19e78b8b9c4fe24aafd23dd3eae0146dda187 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sun, 11 Feb 2024 23:35:05 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=91=BD=E4=BB=A4=E8=A1=8C=E9=80=80?= =?UTF-8?q?=E5=87=BA=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/core/app.py | 14 +++++++++++++- pkg/pipeline/resprule/rules/random.py | 3 ++- pkg/pipeline/resprule/rules/regexp.py | 3 ++- pkg/platform/manager.py | 12 ++++-------- requirements.txt | 1 + 5 files changed, 22 insertions(+), 11 deletions(-) diff --git a/pkg/core/app.py b/pkg/core/app.py index 595b01a8..ee901ba5 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -4,6 +4,8 @@ import logging import asyncio import traceback +import aioconsole + from ..platform import manager as im_mgr from ..provider.session import sessionmgr as llm_session_mgr from ..provider.requester import modelmgr as llm_model_mgr @@ -72,11 +74,21 @@ class Application: tasks = [] try: + + tasks = [ asyncio.create_task(self.im_mgr.run()), asyncio.create_task(self.ctrl.run()) ] - await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + + async def interrupt(tasks): + await asyncio.sleep(1.5) + while await aioconsole.ainput("使用 exit 退出程序 > ") != 'exit': + pass + for task in tasks: + task.cancel() + + await interrupt(tasks) except asyncio.CancelledError: pass diff --git a/pkg/pipeline/resprule/rules/random.py b/pkg/pipeline/resprule/rules/random.py index 2e845ab0..185e03ec 100644 --- a/pkg/pipeline/resprule/rules/random.py +++ b/pkg/pipeline/resprule/rules/random.py @@ -13,7 +13,8 @@ class RandomRespRule(rule_model.GroupRespondRule): self, message_text: str, message_chain: mirai.MessageChain, - rule_dict: dict + rule_dict: dict, + query: core_entities.Query ) -> entities.RuleJudgeResult: random_rate = rule_dict['random'] diff --git a/pkg/pipeline/resprule/rules/regexp.py b/pkg/pipeline/resprule/rules/regexp.py index 18a3ce09..4e39d432 100644 --- a/pkg/pipeline/resprule/rules/regexp.py +++ b/pkg/pipeline/resprule/rules/regexp.py @@ -13,7 +13,8 @@ class RegExpRule(rule_model.GroupRespondRule): self, message_text: str, message_chain: mirai.MessageChain, - rule_dict: dict + rule_dict: dict, + query: core_entities.Query ) -> entities.RuleJudgeResult: regexps = rule_dict['regexp'] diff --git a/pkg/platform/manager.py b/pkg/platform/manager.py index 432d39ed..f4f423db 100644 --- a/pkg/platform/manager.py +++ b/pkg/platform/manager.py @@ -78,13 +78,6 @@ class PlatformManager: adapter=adapter ) - # nakuru不区分好友和陌生人,故仅为yirimirai注册陌生人事件 - # if self.ap.platform_cfg.data['platform-adapter'] == 'yiri-mirai': - # self.adapter.register_listener( - # StrangerMessage, - # on_stranger_message - # ) - async def on_group_message(event: GroupMessage, adapter: msadapter.MessageSourceAdapter): event_ctx = await self.ap.plugin_mgr.emit_event( @@ -187,7 +180,10 @@ class PlatformManager: tasks = [] for adapter in self.adapters: tasks.append(adapter.run_async()) - await asyncio.gather(*tasks) + + for task in tasks: + asyncio.create_task(task) + except Exception as e: self.ap.logger.error('平台适配器运行出错: ' + str(e)) self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") diff --git a/requirements.txt b/requirements.txt index de78dcec..649ed9b1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,4 @@ tiktoken PyYaml aiohttp pydantic +aioconsole \ No newline at end of file