From 77076f3bddd6afbf1236ec9578cdb854718e2fe7 Mon Sep 17 00:00:00 2001 From: LINSTCL Date: Wed, 8 Mar 2023 15:21:37 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E7=BA=BF=E7=A8=8B=E6=8E=A7?= =?UTF-8?q?=E5=88=B6=E7=B1=BB=EF=BC=8C=E4=BF=AE=E6=94=B9main=E7=BB=93?= =?UTF-8?q?=E6=9E=84=EF=BC=8C=E4=BF=AE=E6=94=B9=E5=90=AF=E5=8A=A8=E6=B5=81?= =?UTF-8?q?=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config-template.py | 4 +- main.py | 104 ++++++++++++++++++++++++----------------- pkg/qqbot/manager.py | 25 +++++----- pkg/utils/__init__.py | 1 + pkg/utils/context.py | 53 ++++++++++++++++++--- pkg/utils/reloader.py | 21 ++++++--- pkg/utils/threadctl.py | 93 ++++++++++++++++++++++++++++++++++++ 7 files changed, 231 insertions(+), 70 deletions(-) create mode 100644 pkg/utils/threadctl.py diff --git a/config-template.py b/config-template.py index fbb3eea0..7951b19d 100644 --- a/config-template.py +++ b/config-template.py @@ -208,7 +208,9 @@ alter_tip_message = '出错了,请稍后再试' # 机器人线程池大小 # 该参数决定机器人可以同时处理几个人的消息,超出线程池数量的请求会被阻塞,不会被丢弃 # 如果你不清楚该参数的意义,请不要更改 -pool_num = 10 +sys_pool_num = 8 +admin_pool_num = 2 +user_pool_num = 3 # 每个会话的过期时间,单位为秒 # 默认值20分钟 diff --git a/main.py b/main.py index bfdd7a80..da5d1831 100644 --- a/main.py +++ b/main.py @@ -23,7 +23,7 @@ import colorlog import requests import websockets.exceptions from urllib3.exceptions import InsecureRequestWarning - +import pkg.utils.context sys.path.append(".") @@ -74,11 +74,8 @@ def init_runtime_log_file(): def reset_logging(): global log_file_name - assert os.path.exists('config.py') - config = importlib.import_module('config') - - import pkg.utils.context + config = pkg.utils.context.get_config() if pkg.utils.context.context['logger_handler'] is not None: logging.getLogger().removeHandler(pkg.utils.context.context['logger_handler']) @@ -106,12 +103,13 @@ def reset_logging(): return sh -def main(first_time_init=False): +def start(first_time_init=False): """启动流程,reload之后会被执行""" global known_exception_caught + import pkg.utils.context - import config + config = pkg.utils.context.get_config() # 更新openai库到最新版本 if not hasattr(config, 'upgrade_dependencies') or config.upgrade_dependencies: print("正在更新依赖库,请等待...") @@ -126,31 +124,10 @@ def main(first_time_init=False): known_exception_caught = False try: - # 导入config.py - assert os.path.exists('config.py') - - config = importlib.import_module('config') - init_runtime_log_file() sh = reset_logging() - # 配置完整性校验 - is_integrity = True - config_template = importlib.import_module('config-template') - for key in dir(config_template): - if not key.startswith("__") and not hasattr(config, key): - setattr(config, key, getattr(config_template, key)) - logging.warning("[{}]不存在".format(key)) - is_integrity = False - if not is_integrity: - logging.warning("配置文件不完整,请依据config-template.py检查config.py") - logging.warning("以上配置已被设为默认值,将在5秒后继续启动... ") - time.sleep(5) - - import pkg.utils.context - pkg.utils.context.set_config(config) - # 检查是否设置了管理员 if not (hasattr(config, 'admin_qq') and config.admin_qq != 0): # logging.warning("未设置管理员QQ,管理员权限指令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段") @@ -197,7 +174,7 @@ def main(first_time_init=False): # 初始化qq机器人 qqbot = pkg.qqbot.manager.QQBotManager(mirai_http_api_config=config.mirai_http_api_config, timeout=config.process_message_timeout, retry=config.retry_times, - first_time_init=first_time_init, pool_num=config.pool_num) + first_time_init=first_time_init) # 加载插件 import pkg.plugin.host @@ -252,10 +229,10 @@ def main(first_time_init=False): known_exception_caught = True raise e - qq_bot_thread = threading.Thread(target=run_bot_wrapper, args=(), daemon=True) - qq_bot_thread.start() + pkg.utils.context.get_thread_ctl().submit_sys_task( + run_bot_wrapper + ) finally: - time.sleep(12) if first_time_init: if not known_exception_caught: logging.info('程序启动完成,如长时间未显示 ”成功登录到账号xxxxx“ ,并且不回复消息,请查看 ' @@ -294,11 +271,8 @@ def main(first_time_init=False): except Exception as e: logging.warning("检查更新失败:{}".format(e)) - return qqbot - def stop(): - import pkg.utils.context import pkg.qqbot.manager import pkg.openai.session try: @@ -316,14 +290,30 @@ def stop(): if not isinstance(e, KeyboardInterrupt): raise e - -if __name__ == '__main__': - # 检查是否有config.py,如果没有就把config-template.py复制一份,并退出程序 +# 临时函数,用于加载config和上下文,未来统一放在config类 +def load_config(): + #存在性校验 if not os.path.exists('config.py'): shutil.copy('config-template.py', 'config.py') print('请先在config.py中填写配置') sys.exit(0) + #完整性校验 + is_integrity = True + config_template = importlib.import_module('config-template') + config = importlib.import_module('config') + for key in dir(config_template): + if not key.startswith("__") and not hasattr(config, key): + setattr(config, key, getattr(config_template, key)) + logging.warning("[{}]不存在".format(key)) + is_integrity = False + if not is_integrity: + logging.warning("配置文件不完整,请依据config-template.py检查config.py") + logging.warning("以上配置已被设为默认值,将在5秒后继续启动... ") + time.sleep(5) + #context配置 + pkg.utils.context.set_config(config) +def check_file(): # 检查是否有banlist.py,如果没有就把banlist-template.py复制一份 if not os.path.exists('banlist.py'): shutil.copy('banlist-template.py', 'banlist.py') @@ -342,6 +332,24 @@ if __name__ == '__main__': if not os.path.exists(path): os.mkdir(path) +def main(): + # 加载配置 + load_config() + config = pkg.utils.context.get_config() + + # 初始化相关文件 + check_file() + + # 配置线程池 + from pkg.utils import ThreadCtl + thread_ctl = ThreadCtl( + sys_pool_num = config.sys_pool_num, + admin_pool_num = config.admin_pool_num, + user_pool_num = config.user_pool_num + ) + pkg.utils.context.set_thread_ctl(thread_ctl) + + # 控制台指令处理 if len(sys.argv) > 1 and sys.argv[1] == 'init_db': init_db() sys.exit(0) @@ -352,19 +360,27 @@ if __name__ == '__main__': updater.update_all(cli=True) sys.exit(0) + # 不知道干啥的 # import pkg.utils.configmgr # # pkg.utils.configmgr.set_config_and_reload("quote_origin", False) requests.packages.urllib3.disable_warnings(InsecureRequestWarning) - qqbot = main(True) + pkg.utils.context.get_thread_ctl().submit_sys_task( + start, + True + ) - import pkg.utils.context + # 主线程循环 while True: try: - time.sleep(10) - except KeyboardInterrupt: + time.sleep(0xFF) + except: stop() - - print("程序退出") + pkg.utils.context.get_thread_ctl().shutdown() + print("退出") sys.exit(0) + +if __name__ == '__main__': + main() + diff --git a/pkg/qqbot/manager.py b/pkg/qqbot/manager.py index 5d817eee..300edca8 100644 --- a/pkg/qqbot/manager.py +++ b/pkg/qqbot/manager.py @@ -2,7 +2,6 @@ import asyncio import json import os import threading -from concurrent.futures import ThreadPoolExecutor import mirai.models.bus from mirai import At, GroupMessage, MessageEvent, Mirai, StrangerMessage, WebSocketAdapter, HTTPAdapter, \ @@ -66,9 +65,6 @@ def random_responding(): class QQBotManager: retry = 3 - #线程池控制 - pool = None - bot: Mirai = None reply_filter = None @@ -78,14 +74,10 @@ class QQBotManager: ban_person = [] ban_group = [] - def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3, pool_num: int = 10, first_time_init=True): + def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3, first_time_init=True): self.timeout = timeout self.retry = retry - self.pool_num = pool_num - self.pool = ThreadPoolExecutor(max_workers=self.pool_num) - logging.debug("Registered thread pool Size:{}".format(pool_num)) - # 加载禁用列表 if os.path.exists("banlist.py"): import banlist @@ -138,7 +130,10 @@ class QQBotManager: self.on_person_message(event) - self.go(friend_message_handler, event) + pkg.utils.context.get_thread_ctl().submit_user_task( + friend_message_handler, + event + ) @self.bot.on(StrangerMessage) async def on_stranger_message(event: StrangerMessage): @@ -158,7 +153,10 @@ class QQBotManager: self.on_person_message(event) - self.go(stranger_message_handler, event) + pkg.utils.context.get_thread_ctl().submit_user_task( + stranger_message_handler, + event + ) @self.bot.on(GroupMessage) async def on_group_message(event: GroupMessage): @@ -178,7 +176,10 @@ class QQBotManager: self.on_group_message(event) - self.go(group_message_handler, event) + pkg.utils.context.get_thread_ctl().submit_user_task( + group_message_handler, + event + ) def unsubscribe_all(): """取消所有订阅 diff --git a/pkg/utils/__init__.py b/pkg/utils/__init__.py index e69de29b..5b1c9803 100644 --- a/pkg/utils/__init__.py +++ b/pkg/utils/__init__.py @@ -0,0 +1 @@ +from .threadctl import ThreadCtl \ No newline at end of file diff --git a/pkg/utils/context.py b/pkg/utils/context.py index 30f81770..854ab085 100644 --- a/pkg/utils/context.py +++ b/pkg/utils/context.py @@ -1,50 +1,91 @@ +import threading + context = { 'inst': { 'database.manager.DatabaseManager': None, 'openai.manager.OpenAIInteract': None, 'qqbot.manager.QQBotManager': None, }, + 'pool_ctl': None, 'logger_handler': None, 'config': None, 'plugin_host': None, } +context_lock = threading.Lock() - +### context耦合度非常高,需要大改 ### def set_config(inst): + context_lock.acquire() context['config'] = inst + context_lock.release() def get_config(): - return context['config'] + context_lock.acquire() + t = context['config'] + context_lock.release() + return t def set_database_manager(inst): + context_lock.acquire() context['inst']['database.manager.DatabaseManager'] = inst + context_lock.release() def get_database_manager(): - return context['inst']['database.manager.DatabaseManager'] + context_lock.acquire() + t = context['inst']['database.manager.DatabaseManager'] + context_lock.release() + return t def set_openai_manager(inst): + context_lock.acquire() context['inst']['openai.manager.OpenAIInteract'] = inst + context_lock.release() def get_openai_manager(): - return context['inst']['openai.manager.OpenAIInteract'] + context_lock.acquire() + t = context['inst']['openai.manager.OpenAIInteract'] + context_lock.release() + return t def set_qqbot_manager(inst): + context_lock.acquire() context['inst']['qqbot.manager.QQBotManager'] = inst + context_lock.release() def get_qqbot_manager(): - return context['inst']['qqbot.manager.QQBotManager'] + context_lock.acquire() + t = context['inst']['qqbot.manager.QQBotManager'] + context_lock.release() + return t def set_plugin_host(inst): + context_lock.acquire() context['plugin_host'] = inst + context_lock.release() def get_plugin_host(): - return context['plugin_host'] + context_lock.acquire() + t = context['plugin_host'] + context_lock.release() + return t + +def set_thread_ctl(inst): + context_lock.acquire() + context['pool_ctl'] = inst + context_lock.release() + +from pkg.utils import ThreadCtl +def get_thread_ctl() -> ThreadCtl: + context_lock.acquire() + t = context['pool_ctl'] + context_lock.release() + return t \ No newline at end of file diff --git a/pkg/utils/reloader.py b/pkg/utils/reloader.py index d701b067..3ab0edd9 100644 --- a/pkg/utils/reloader.py +++ b/pkg/utils/reloader.py @@ -3,7 +3,7 @@ import threading import importlib import pkgutil -import pkg.utils.context +import pkg.utils.context as context import pkg.plugin.host @@ -22,20 +22,20 @@ def walk(module, prefix='', path_prefix=''): def reload_all(notify=True): # 解除bot的事件注册 import pkg - pkg.utils.context.get_qqbot_manager().unsubscribe_all() + context.get_qqbot_manager().unsubscribe_all() # 执行关闭流程 logging.info("执行程序关闭流程") import main main.stop() # 重载所有模块 - pkg.utils.context.context['exceeded_keys'] = pkg.utils.context.get_openai_manager().key_mgr.exceeded - context = pkg.utils.context.context + context.context['exceeded_keys'] = context.get_openai_manager().key_mgr.exceeded + this_context = context.context walk(pkg) importlib.reload(__import__('config')) importlib.reload(__import__('main')) importlib.reload(__import__('banlist')) - pkg.utils.context.context = context + context.context = this_context # 重载插件 import plugins @@ -43,8 +43,15 @@ def reload_all(notify=True): # 执行启动流程 logging.info("执行程序启动流程") - threading.Thread(target=main.main, args=(False,), daemon=False).start() + context.get_thread_ctl().reload( + admin_pool_num=context.get_config().admin_pool_num, + user_pool_num=context.get_config().user_pool_num + ) + context.get_thread_ctl().submit_sys_task( + main.start, + False + ) logging.info('程序启动完成') if notify: - pkg.utils.context.get_qqbot_manager().notify_admin("重载完成") + context.get_qqbot_manager().notify_admin("重载完成") diff --git a/pkg/utils/threadctl.py b/pkg/utils/threadctl.py new file mode 100644 index 00000000..d5cfb601 --- /dev/null +++ b/pkg/utils/threadctl.py @@ -0,0 +1,93 @@ +from concurrent.futures import ThreadPoolExecutor, Future +import threading, time + +class Pool(): + ''' + 线程池结构 + ''' + pool_num:int = None + ctl:ThreadPoolExecutor = None + task_list:list = None + task_list_lock:threading.Lock = None + monitor_type = True + + def __init__(self, pool_num): + self.pool_num = pool_num + self.ctl = ThreadPoolExecutor(max_workers = self.pool_num) + self.task_list = [] + self.task_list_lock = threading.Lock() + + def __thread_monitor__(self): + while self.monitor_type: + for t in self.task_list: + if not t.done(): + continue + try: + self.task_list.pop(self.task_list.index(t)) + except: + continue + time.sleep(1) + +class ThreadCtl(): + def __init__(self, sys_pool_num, admin_pool_num, user_pool_num): + ''' + 线程池控制类 + sys_pool_num:分配系统使用的线程池数量(>=5) + admin_pool_num:用于处理管理员消息的线程池数量(>=1) + user_pool_num:分配用于处理用户消息的线程池的数量(>=1) + ''' + if sys_pool_num < 5: + raise Exception("Too few system threads(sys_pool_num needs >= 8, but received {})".format(sys_pool_num)) + if admin_pool_num < 1: + raise Exception("Too few admin threads(admin_pool_num needs >= 1, but received {})".format(admin_pool_num)) + if user_pool_num < 1: + raise Exception("Too few user threads(user_pool_num needs >= 1, but received {})".format(admin_pool_num)) + self.__sys_pool__ = Pool(sys_pool_num) + self.__admin_pool__ = Pool(admin_pool_num) + self.__user_pool__ = Pool(user_pool_num) + self.submit_sys_task(self.__sys_pool__.__thread_monitor__) + self.submit_sys_task(self.__admin_pool__.__thread_monitor__) + self.submit_sys_task(self.__user_pool__.__thread_monitor__) + + def __submit__(self, pool:Pool, fn, /, *args, **kwargs ): + t = pool.ctl.submit(fn, *args, **kwargs) + pool.task_list_lock.acquire() + pool.task_list.append(t) + pool.task_list_lock.release() + return t + + def submit_sys_task(self, fn, /, *args, **kwargs): + return self.__submit__( + self.__sys_pool__, + fn, *args, **kwargs + ) + + def submit_admin_task(self, fn, /, *args, **kwargs): + return self.__submit__( + self.__admin_pool__, + fn, *args, **kwargs + ) + + def submit_user_task(self, fn, /, *args, **kwargs): + return self.__submit__( + self.__user_pool__, + fn, *args, **kwargs + ) + + def shutdown(self): + self.__user_pool__.ctl.shutdown(cancel_futures=True) + self.__user_pool__.monitor_type = False + self.__admin_pool__.ctl.shutdown(cancel_futures=True) + self.__admin_pool__.monitor_type = False + self.__sys_pool__.monitor_type = False + self.__sys_pool__.ctl.shutdown(wait=True, cancel_futures=False) + + def reload(self, admin_pool_num, user_pool_num): + self.__user_pool__.ctl.shutdown(cancel_futures=True) + self.__user_pool__.monitor_type = False + self.__admin_pool__.ctl.shutdown(cancel_futures=True) + self.__admin_pool__.monitor_type = False + self.__admin_pool__ = Pool(admin_pool_num) + self.__user_pool__ = Pool(user_pool_num) + self.submit_sys_task(self.__admin_pool__.__thread_monitor__) + self.submit_sys_task(self.__user_pool__.__thread_monitor__)