diff --git a/config-template.py b/config-template.py index 12a1ea89..d6e86c68 100644 --- a/config-template.py +++ b/config-template.py @@ -238,10 +238,18 @@ hide_exce_info_to_user = False # 设置为空字符串时,不发送提示信息 alter_tip_message = '出错了,请稍后再试' -# 机器人线程池大小 +# 线程池相关配置 # 该参数决定机器人可以同时处理几个人的消息,超出线程池数量的请求会被阻塞,不会被丢弃 # 如果你不清楚该参数的意义,请不要更改 -pool_num = 10 +# 程序运行本身线程池,无代码层面修改请勿更改 +sys_pool_num = 8 + +# 执行管理员请求和指令的线程池并行线程数量,一般和管理员数量相等 +admin_pool_num = 2 + +# 执行用户请求和指令的线程池并行线程数量 +# 如需要更高的并发,可以增大该值 +user_pool_num = 6 # 每个会话的过期时间,单位为秒 # 默认值20分钟 diff --git a/main.py b/main.py index f571725e..16130fe6 100644 --- a/main.py +++ b/main.py @@ -24,7 +24,7 @@ import colorlog import requests import websockets.exceptions from urllib3.exceptions import InsecureRequestWarning - +import pkg.utils.context sys.path.append(".") @@ -75,11 +75,8 @@ def init_runtime_log_file(): def reset_logging(): global log_file_name - assert os.path.exists('config.py') - config = importlib.import_module('config') - - import pkg.utils.context + import config if pkg.utils.context.context['logger_handler'] is not None: logging.getLogger().removeHandler(pkg.utils.context.context['logger_handler']) @@ -107,12 +104,46 @@ def reset_logging(): return sh -def main(first_time_init=False): +# 临时函数,用于加载config和上下文,未来统一放在config类 +def load_config(): + # 完整性校验 + is_integrity = True + config_template = importlib.import_module('config-template') + config = importlib.import_module('config') + for key in dir(config_template): + if not key.startswith("__") and not hasattr(config, key): + setattr(config, key, getattr(config_template, key)) + logging.warning("[{}]不存在".format(key)) + is_integrity = False + + if not is_integrity: + logging.warning("配置文件不完整,请依据config-template.py检查config.py") + + # 检查override.json覆盖 + if os.path.exists("override.json"): + override_json = json.load(open("override.json", "r", encoding="utf-8")) + for key in override_json: + if hasattr(config, key): + setattr(config, key, override_json[key]) + logging.info("覆写配置[{}]为[{}]".format(key, override_json[key])) + else: + logging.error("无法覆写配置[{}]为[{}],该配置不存在,请检查override.json是否正确".format(key, override_json[key])) + + if not is_integrity: + logging.warning("以上配置已被设为默认值,将在5秒后继续启动... ") + time.sleep(5) + + # 存进上下文 + pkg.utils.context.set_config(config) + + +def start(first_time_init=False): """启动流程,reload之后会被执行""" global known_exception_caught + import pkg.utils.context - import config + config = pkg.utils.context.get_config() # 更新openai库到最新版本 if not hasattr(config, 'upgrade_dependencies') or config.upgrade_dependencies: print("正在更新依赖库,请等待...") @@ -127,43 +158,9 @@ def main(first_time_init=False): known_exception_caught = False try: - # 导入config.py - assert os.path.exists('config.py') - - config = importlib.import_module('config') - - init_runtime_log_file() sh = reset_logging() - - # 配置完整性校验 - is_integrity = True - config_template = importlib.import_module('config-template') - for key in dir(config_template): - if not key.startswith("__") and not hasattr(config, key): - setattr(config, key, getattr(config_template, key)) - logging.warning("[{}]不存在".format(key)) - is_integrity = False - - if not is_integrity: - logging.warning("配置文件不完整,请依据config-template.py检查config.py") - logging.warning("以上配置已被设为默认值,将在5秒后继续启动... ") - - # 检查override.json覆盖 - if os.path.exists("override.json"): - override_json = json.load(open("override.json", "r", encoding="utf-8")) - for key in override_json: - if hasattr(config, key): - setattr(config, key, override_json[key]) - logging.info("覆写配置[{}]为[{}]".format(key, override_json[key])) - else: - logging.error("无法覆写配置[{}]为[{}],该配置不存在,请检查override.json是否正确".format(key, override_json[key])) - - if not is_integrity: - time.sleep(5) - - import pkg.utils.context - pkg.utils.context.set_config(config) + pkg.utils.context.context['logger_handler'] = sh # 检查是否设置了管理员 if not (hasattr(config, 'admin_qq') and config.admin_qq != 0): @@ -198,7 +195,6 @@ def main(first_time_init=False): pkg.openai.dprompt.read_prompt_from_file() pkg.openai.dprompt.read_scenario_from_file() - pkg.utils.context.context['logger_handler'] = sh # 主启动流程 database = pkg.database.manager.DatabaseManager() @@ -212,7 +208,7 @@ def main(first_time_init=False): # 初始化qq机器人 qqbot = pkg.qqbot.manager.QQBotManager(mirai_http_api_config=config.mirai_http_api_config, timeout=config.process_message_timeout, retry=config.retry_times, - first_time_init=first_time_init, pool_num=config.pool_num) + first_time_init=first_time_init) # 加载插件 import pkg.plugin.host @@ -266,9 +262,15 @@ def main(first_time_init=False): "捕捉到未知异常:{}, 请前往 https://github.com/RockChinQ/QChatGPT/issues 查找或提issue".format(e)) known_exception_caught = True raise e - - qq_bot_thread = threading.Thread(target=run_bot_wrapper, args=(), daemon=True) - qq_bot_thread.start() + finally: + time.sleep(12) + threading.Thread( + target=run_bot_wrapper + ).start() + # 机器人暂时不能放在线程池中 + # pkg.utils.context.get_thread_ctl().submit_sys_task( + # run_bot_wrapper + # ) finally: # 判断若是Windows,输出选择模式可能会暂停程序的警告 if os.name == 'nt': @@ -276,6 +278,7 @@ def main(first_time_init=False): logging.info("您正在使用Windows系统,若命令行窗口处于“选择”模式,程序可能会被暂停,此时请右键点击窗口空白区域使其取消选择模式。") time.sleep(12) + if first_time_init: if not known_exception_caught: logging.info('程序启动完成,如长时间未显示 ”成功登录到账号xxxxx“ ,并且不回复消息,请查看 ' @@ -324,9 +327,7 @@ def main(first_time_init=False): return qqbot - def stop(): - import pkg.utils.context import pkg.qqbot.manager import pkg.openai.session try: @@ -345,8 +346,8 @@ def stop(): raise e -if __name__ == '__main__': - # 检查是否有config.py,如果没有就把config-template.py复制一份,并退出程序 +def check_file(): + # 配置文件存在性校验 if not os.path.exists('config.py'): shutil.copy('config-template.py', 'config.py') print('请先在config.py中填写配置') @@ -374,6 +375,30 @@ if __name__ == '__main__': if not os.path.exists(path): os.mkdir(path) + +def main(): + # 初始化相关文件 + check_file() + + # 初始化logging + init_runtime_log_file() + pkg.utils.context.context['logger_handler'] = reset_logging() + + # 加载配置 + load_config() + config = pkg.utils.context.get_config() + + # 配置线程池 + from pkg.utils import ThreadCtl + thread_ctl = ThreadCtl( + sys_pool_num=config.sys_pool_num, + admin_pool_num=config.admin_pool_num, + user_pool_num=config.user_pool_num + ) + # 存进上下文 + pkg.utils.context.set_thread_ctl(thread_ctl) + + # 启动指令处理 if len(sys.argv) > 1 and sys.argv[1] == 'init_db': init_db() sys.exit(0) @@ -384,16 +409,29 @@ if __name__ == '__main__': updater.update_all(cli=True) sys.exit(0) + # 关闭urllib的http警告 requests.packages.urllib3.disable_warnings(InsecureRequestWarning) - qqbot = main(True) + pkg.utils.context.get_thread_ctl().submit_sys_task( + start, + True + ) - import pkg.utils.context + # 主线程循环 while True: try: - time.sleep(10) - except KeyboardInterrupt: + time.sleep(0xFF) + except: stop() + pkg.utils.context.get_thread_ctl().shutdown() + import platform + if platform.system() == 'Windows': + cmd = "taskkill /F /PID {}".format(os.getpid()) + elif platform.system() in ['Linux', 'Darwin']: + cmd = "kill -9 {}".format(os.getpid()) + os.system(cmd) + + +if __name__ == '__main__': + main() - print("程序退出") - sys.exit(0) diff --git a/pkg/qqbot/manager.py b/pkg/qqbot/manager.py index 5d817eee..300edca8 100644 --- a/pkg/qqbot/manager.py +++ b/pkg/qqbot/manager.py @@ -2,7 +2,6 @@ import asyncio import json import os import threading -from concurrent.futures import ThreadPoolExecutor import mirai.models.bus from mirai import At, GroupMessage, MessageEvent, Mirai, StrangerMessage, WebSocketAdapter, HTTPAdapter, \ @@ -66,9 +65,6 @@ def random_responding(): class QQBotManager: retry = 3 - #线程池控制 - pool = None - bot: Mirai = None reply_filter = None @@ -78,14 +74,10 @@ class QQBotManager: ban_person = [] ban_group = [] - def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3, pool_num: int = 10, first_time_init=True): + def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3, first_time_init=True): self.timeout = timeout self.retry = retry - self.pool_num = pool_num - self.pool = ThreadPoolExecutor(max_workers=self.pool_num) - logging.debug("Registered thread pool Size:{}".format(pool_num)) - # 加载禁用列表 if os.path.exists("banlist.py"): import banlist @@ -138,7 +130,10 @@ class QQBotManager: self.on_person_message(event) - self.go(friend_message_handler, event) + pkg.utils.context.get_thread_ctl().submit_user_task( + friend_message_handler, + event + ) @self.bot.on(StrangerMessage) async def on_stranger_message(event: StrangerMessage): @@ -158,7 +153,10 @@ class QQBotManager: self.on_person_message(event) - self.go(stranger_message_handler, event) + pkg.utils.context.get_thread_ctl().submit_user_task( + stranger_message_handler, + event + ) @self.bot.on(GroupMessage) async def on_group_message(event: GroupMessage): @@ -178,7 +176,10 @@ class QQBotManager: self.on_group_message(event) - self.go(group_message_handler, event) + pkg.utils.context.get_thread_ctl().submit_user_task( + group_message_handler, + event + ) def unsubscribe_all(): """取消所有订阅 diff --git a/pkg/utils/__init__.py b/pkg/utils/__init__.py index e69de29b..5b1c9803 100644 --- a/pkg/utils/__init__.py +++ b/pkg/utils/__init__.py @@ -0,0 +1 @@ +from .threadctl import ThreadCtl \ No newline at end of file diff --git a/pkg/utils/context.py b/pkg/utils/context.py index 30f81770..2f8dee44 100644 --- a/pkg/utils/context.py +++ b/pkg/utils/context.py @@ -1,50 +1,94 @@ +import threading +from pkg.utils import ThreadCtl + + context = { 'inst': { 'database.manager.DatabaseManager': None, 'openai.manager.OpenAIInteract': None, 'qqbot.manager.QQBotManager': None, }, + 'pool_ctl': None, 'logger_handler': None, 'config': None, 'plugin_host': None, } +context_lock = threading.Lock() - +### context耦合度非常高,需要大改 ### def set_config(inst): + context_lock.acquire() context['config'] = inst + context_lock.release() def get_config(): - return context['config'] + context_lock.acquire() + t = context['config'] + context_lock.release() + return t def set_database_manager(inst): + context_lock.acquire() context['inst']['database.manager.DatabaseManager'] = inst + context_lock.release() def get_database_manager(): - return context['inst']['database.manager.DatabaseManager'] + context_lock.acquire() + t = context['inst']['database.manager.DatabaseManager'] + context_lock.release() + return t def set_openai_manager(inst): + context_lock.acquire() context['inst']['openai.manager.OpenAIInteract'] = inst + context_lock.release() def get_openai_manager(): - return context['inst']['openai.manager.OpenAIInteract'] + context_lock.acquire() + t = context['inst']['openai.manager.OpenAIInteract'] + context_lock.release() + return t def set_qqbot_manager(inst): + context_lock.acquire() context['inst']['qqbot.manager.QQBotManager'] = inst + context_lock.release() def get_qqbot_manager(): - return context['inst']['qqbot.manager.QQBotManager'] + context_lock.acquire() + t = context['inst']['qqbot.manager.QQBotManager'] + context_lock.release() + return t def set_plugin_host(inst): + context_lock.acquire() context['plugin_host'] = inst + context_lock.release() def get_plugin_host(): - return context['plugin_host'] + context_lock.acquire() + t = context['plugin_host'] + context_lock.release() + return t + + +def set_thread_ctl(inst): + context_lock.acquire() + context['pool_ctl'] = inst + context_lock.release() + + +def get_thread_ctl() -> ThreadCtl: + context_lock.acquire() + t: ThreadCtl = context['pool_ctl'] + context_lock.release() + return t diff --git a/pkg/utils/reloader.py b/pkg/utils/reloader.py index d701b067..0449a40c 100644 --- a/pkg/utils/reloader.py +++ b/pkg/utils/reloader.py @@ -3,7 +3,7 @@ import threading import importlib import pkgutil -import pkg.utils.context +import pkg.utils.context as context import pkg.plugin.host @@ -22,20 +22,21 @@ def walk(module, prefix='', path_prefix=''): def reload_all(notify=True): # 解除bot的事件注册 import pkg - pkg.utils.context.get_qqbot_manager().unsubscribe_all() + context.get_qqbot_manager().unsubscribe_all() # 执行关闭流程 logging.info("执行程序关闭流程") import main main.stop() # 重载所有模块 - pkg.utils.context.context['exceeded_keys'] = pkg.utils.context.get_openai_manager().key_mgr.exceeded - context = pkg.utils.context.context + context.context['exceeded_keys'] = context.get_openai_manager().key_mgr.exceeded + this_context = context.context walk(pkg) + importlib.reload(__import__("config-template")) importlib.reload(__import__('config')) importlib.reload(__import__('main')) importlib.reload(__import__('banlist')) - pkg.utils.context.context = context + context.context = this_context # 重载插件 import plugins @@ -43,8 +44,16 @@ def reload_all(notify=True): # 执行启动流程 logging.info("执行程序启动流程") - threading.Thread(target=main.main, args=(False,), daemon=False).start() + main.load_config() + context.get_thread_ctl().reload( + admin_pool_num=context.get_config().admin_pool_num, + user_pool_num=context.get_config().user_pool_num + ) + context.get_thread_ctl().submit_sys_task( + main.start, + False + ) logging.info('程序启动完成') if notify: - pkg.utils.context.get_qqbot_manager().notify_admin("重载完成") + context.get_qqbot_manager().notify_admin("重载完成") diff --git a/pkg/utils/threadctl.py b/pkg/utils/threadctl.py new file mode 100644 index 00000000..4cf35a9a --- /dev/null +++ b/pkg/utils/threadctl.py @@ -0,0 +1,96 @@ +import threading +import time +from concurrent.futures import ThreadPoolExecutor + + +class Pool: + ''' + 线程池结构 + ''' + pool_num:int = None + ctl:ThreadPoolExecutor = None + task_list:list = None + task_list_lock:threading.Lock = None + monitor_type = True + + def __init__(self, pool_num): + self.pool_num = pool_num + self.ctl = ThreadPoolExecutor(max_workers = self.pool_num) + self.task_list = [] + self.task_list_lock = threading.Lock() + + def __thread_monitor__(self): + while self.monitor_type: + for t in self.task_list: + if not t.done(): + continue + try: + self.task_list.pop(self.task_list.index(t)) + except: + continue + time.sleep(1) + + +class ThreadCtl: + def __init__(self, sys_pool_num, admin_pool_num, user_pool_num): + ''' + 线程池控制类 + sys_pool_num:分配系统使用的线程池数量(>=8) + admin_pool_num:用于处理管理员消息的线程池数量(>=1) + user_pool_num:分配用于处理用户消息的线程池的数量(>=1) + ''' + if sys_pool_num < 5: + raise Exception("Too few system threads(sys_pool_num needs >= 8, but received {})".format(sys_pool_num)) + if admin_pool_num < 1: + raise Exception("Too few admin threads(admin_pool_num needs >= 1, but received {})".format(admin_pool_num)) + if user_pool_num < 1: + raise Exception("Too few user threads(user_pool_num needs >= 1, but received {})".format(admin_pool_num)) + self.__sys_pool__ = Pool(sys_pool_num) + self.__admin_pool__ = Pool(admin_pool_num) + self.__user_pool__ = Pool(user_pool_num) + self.submit_sys_task(self.__sys_pool__.__thread_monitor__) + self.submit_sys_task(self.__admin_pool__.__thread_monitor__) + self.submit_sys_task(self.__user_pool__.__thread_monitor__) + + def __submit__(self, pool: Pool, fn, /, *args, **kwargs ): + t = pool.ctl.submit(fn, *args, **kwargs) + pool.task_list_lock.acquire() + pool.task_list.append(t) + pool.task_list_lock.release() + return t + + def submit_sys_task(self, fn, /, *args, **kwargs): + return self.__submit__( + self.__sys_pool__, + fn, *args, **kwargs + ) + + def submit_admin_task(self, fn, /, *args, **kwargs): + return self.__submit__( + self.__admin_pool__, + fn, *args, **kwargs + ) + + def submit_user_task(self, fn, /, *args, **kwargs): + return self.__submit__( + self.__user_pool__, + fn, *args, **kwargs + ) + + def shutdown(self): + self.__user_pool__.ctl.shutdown(cancel_futures=True) + self.__user_pool__.monitor_type = False + self.__admin_pool__.ctl.shutdown(cancel_futures=True) + self.__admin_pool__.monitor_type = False + self.__sys_pool__.monitor_type = False + self.__sys_pool__.ctl.shutdown(wait=True, cancel_futures=False) + + def reload(self, admin_pool_num, user_pool_num): + self.__user_pool__.ctl.shutdown(cancel_futures=True) + self.__user_pool__.monitor_type = False + self.__admin_pool__.ctl.shutdown(cancel_futures=True) + self.__admin_pool__.monitor_type = False + self.__admin_pool__ = Pool(admin_pool_num) + self.__user_pool__ = Pool(user_pool_num) + self.submit_sys_task(self.__admin_pool__.__thread_monitor__) + self.submit_sys_task(self.__user_pool__.__thread_monitor__)