From 16d7cc7d188c22336ca89f3c0fd14aa6899956bd Mon Sep 17 00:00:00 2001 From: Rock Chin <1010553892@qq.com> Date: Mon, 2 Jan 2023 00:35:36 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=8C=81=E4=B9=85=E4=BF=9D=E5=AD=98bot?= =?UTF-8?q?=E5=AF=B9=E8=B1=A1=E4=BB=A5=E6=88=90=E5=8A=9F=E9=87=8D=E5=90=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 57 +++++++++++++++++++++++++++++++++---------- pkg/openai/manager.py | 2 +- pkg/qqbot/manager.py | 37 +++++++++++++++++----------- pkg/qqbot/process.py | 7 +++--- pkg/utils/context.py | 3 ++- pkg/utils/reloader.py | 18 ++++++++++++++ 6 files changed, 92 insertions(+), 32 deletions(-) diff --git a/main.py b/main.py index 55bb6166..d99f420a 100644 --- a/main.py +++ b/main.py @@ -1,3 +1,4 @@ +import asyncio import os import shutil import sys @@ -7,6 +8,8 @@ import time import logging import colorlog +from mirai.bot import MiraiRunner + import sys sys.path.append(".") @@ -27,11 +30,15 @@ def init_db(): database.initialize_database() -def main(): +def main(first_time_init=False): # 导入config.py assert os.path.exists('config.py') import config + import pkg.utils.context + if pkg.utils.context.context['logger_handler'] is not None: + logging.getLogger().removeHandler(pkg.utils.context.context['logger_handler']) + logging.basicConfig(level=config.logging_level, # 设置日志输出格式 filename='qchatgpt.log', # log日志输出的文件位置和文件名 format="[%(asctime)s.%(msecs)03d] %(filename)s (%(lineno)d) - [%(levelname)s] : %(message)s", @@ -53,8 +60,8 @@ def main(): import pkg.database.manager import pkg.openai.session import pkg.qqbot.manager - import pkg.utils.context + pkg.utils.context.context['logger_handler'] = sh # 主启动流程 database = pkg.database.manager.DatabaseManager() @@ -67,7 +74,8 @@ def main(): # 初始化qq机器人 qqbot = pkg.qqbot.manager.QQBotManager(mirai_http_api_config=config.mirai_http_api_config, - timeout=config.process_message_timeout, retry=config.retry_times) + timeout=config.process_message_timeout, retry=config.retry_times, + first_time_init=first_time_init) qq_bot_thread = threading.Thread(target=qqbot.bot.run, args=(), daemon=True) qq_bot_thread.start() @@ -76,21 +84,44 @@ def main(): while True: try: - time.sleep(86400) + time.sleep(10000) + if qqbot != pkg.utils.context.get_qqbot_manager(): # 已经reload了 + break except KeyboardInterrupt: - try: - pkg.utils.context.get_openai_manager().key_mgr.dump_fee() - for session in pkg.openai.session.sessions: - logging.info('持久化session: %s', session) - pkg.openai.session.sessions[session].persistence() - except Exception as e: - if not isinstance(e, KeyboardInterrupt): - raise e + stop() + print("程序退出") sys.exit(0) +def stop(): + import pkg.utils.context + import pkg.qqbot.manager + import pkg.openai.session + try: + qqbot_inst = pkg.utils.context.get_qqbot_manager() + assert isinstance(qqbot_inst, pkg.qqbot.manager.QQBotManager) + + # try: + # asyncio.run(qqbot_inst.bot.shutdown()) + # except ValueError: + # pass + # + # import mirai.utils + # MiraiRunner.__class__._instance = None + # mirai.utils.Singleton._instance = None + + pkg.utils.context.get_openai_manager().key_mgr.dump_fee() + for session in pkg.openai.session.sessions: + logging.info('持久化session: %s', session) + pkg.openai.session.sessions[session].persistence() + except Exception as e: + if not isinstance(e, KeyboardInterrupt): + raise e + + if __name__ == '__main__': + print('程序启动') # 检查是否有config.py,如果没有就把config-template.py复制一份,并退出程序 if not os.path.exists('config.py'): shutil.copy('config-template.py', 'config.py') @@ -110,4 +141,4 @@ if __name__ == '__main__': print("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77") sys.exit(0) - main() + main(True) diff --git a/pkg/openai/manager.py b/pkg/openai/manager.py index fa77fb4c..e59146d6 100644 --- a/pkg/openai/manager.py +++ b/pkg/openai/manager.py @@ -30,7 +30,7 @@ class OpenAIInteract: # 请求OpenAI Completion def request_completion(self, prompt, stop): - print("request") + # print("request") response = openai.Completion.create( prompt=prompt, stop=stop, diff --git a/pkg/qqbot/manager.py b/pkg/qqbot/manager.py index 7a121385..dcd1321d 100644 --- a/pkg/qqbot/manager.py +++ b/pkg/qqbot/manager.py @@ -57,7 +57,7 @@ class QQBotManager: reply_filter = None - def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3): + def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3, first_time_init=True): self.timeout = timeout self.retry = retry @@ -70,6 +70,28 @@ class QQBotManager: else: self.reply_filter = pkg.qqbot.filter.ReplyFilter([]) + + if first_time_init: + self.first_time_init(mirai_http_api_config) + else: + self.bot = pkg.utils.context.get_qqbot_manager().bot + + pkg.utils.context.set_qqbot_manager(self) + + @self.bot.on(FriendMessage) + async def on_friend_message(event: FriendMessage): + go(self.on_person_message, (event,)) + + @self.bot.on(StrangerMessage) + async def on_stranger_message(event: StrangerMessage): + go(self.on_person_message, (event,)) + + @self.bot.on(GroupMessage) + async def on_group_message(event: GroupMessage): + go(self.on_group_message, (event,)) + + def first_time_init(self, mirai_http_api_config: dict): + if 'adapter' not in mirai_http_api_config or mirai_http_api_config['adapter'] == "WebSocketAdapter": bot = Mirai( qq=mirai_http_api_config['qq'], @@ -92,22 +114,9 @@ class QQBotManager: else: raise Exception("未知的适配器类型") - @bot.on(FriendMessage) - async def on_friend_message(event: FriendMessage): - go(self.on_person_message, (event,)) - - @bot.on(StrangerMessage) - async def on_stranger_message(event: StrangerMessage): - go(self.on_person_message, (event,)) - - @bot.on(GroupMessage) - async def on_group_message(event: GroupMessage): - go(self.on_group_message, (event,)) self.bot = bot - pkg.utils.context.set_qqbot_manager(self) - def send(self, event, msg, check_quote=True): asyncio.run( self.bot.send(event, msg, quote=True if hasattr(config, diff --git a/pkg/qqbot/process.py b/pkg/qqbot/process.py index 596a10ad..0145fcec 100644 --- a/pkg/qqbot/process.py +++ b/pkg/qqbot/process.py @@ -1,6 +1,7 @@ # 此模块提供了消息处理的具体逻辑的接口 import asyncio import datetime +import threading import pkg.qqbot.manager as manager from func_timeout import func_set_timeout @@ -8,7 +9,6 @@ import logging import openai from mirai import Image, MessageChain -from mirai.models.message import Quote import config @@ -162,8 +162,9 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes reply.append(" ".join(params)) elif cmd == 'reload' and launcher_type == 'person' and launcher_id == config.admin_qq: try: - pkg.utils.reloader.reload_all() - reply = ["[bot]已重新加载所有模块"] + # pkg.utils.reloader.reload_all() + threading.Thread(target=pkg.utils.reloader.reload_all, daemon=True).start() + # reply = ["[bot]已重新加载所有模块"] except Exception as e: logging.error("reload failed:{}".format(e)) reply = ["[bot]重载失败:{}".format(e)] diff --git a/pkg/utils/context.py b/pkg/utils/context.py index be2849a8..2576609b 100644 --- a/pkg/utils/context.py +++ b/pkg/utils/context.py @@ -3,7 +3,8 @@ context = { 'database.manager.DatabaseManager': None, 'openai.manager.OpenAIInteract': None, 'qqbot.manager.QQBotManager': None, - } + }, + 'logger_handler': None, } diff --git a/pkg/utils/reloader.py b/pkg/utils/reloader.py index 7b7f1a35..67cc51a7 100644 --- a/pkg/utils/reloader.py +++ b/pkg/utils/reloader.py @@ -1,9 +1,14 @@ import logging +import os +import threading + +import colorlog import pkg import importlib import pkgutil import pkg.utils.context +from main import log_colors_config def walk(module, prefix=''): @@ -16,7 +21,20 @@ def walk(module, prefix=''): def reload_all(): + # 执行关闭流程 + logging.info("执行程序关闭流程") + import main + main.stop() + import pkg + context = pkg.utils.context.context walk(pkg) importlib.reload(__import__('config')) + importlib.reload(__import__('main')) pkg.utils.context.context = context + + # 执行启动流程 + logging.info("执行程序启动流程") + main.main() + + logging.info('程序启动完成')