From db2e3660145eeb0da9c3697b46f4d0650502255b Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sun, 26 Nov 2023 22:46:27 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E6=96=87=E4=BB=B6=E7=AE=A1=E7=90=86=E5=99=A8=E5=B9=B6=E9=80=82?= =?UTF-8?q?=E9=85=8Dmain.py=E4=B8=AD=E7=9A=84=E5=BC=95=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 37 +++++++++++---------- pkg/config/__init__.py | 0 pkg/config/impls/pymodule.py | 62 ++++++++++++++++++++++++++++++++++++ pkg/config/manager.py | 23 +++++++++++++ pkg/config/model.py | 27 ++++++++++++++++ pkg/utils/context.py | 15 +++++++++ 6 files changed, 145 insertions(+), 19 deletions(-) create mode 100644 pkg/config/__init__.py create mode 100644 pkg/config/impls/pymodule.py create mode 100644 pkg/config/manager.py create mode 100644 pkg/config/model.py diff --git a/main.py b/main.py index c855bfbd..55fea88b 100644 --- a/main.py +++ b/main.py @@ -205,11 +205,11 @@ async def start_process(first_time_init=False): # 检查tips模块 complete_tips() - config = pkg.utils.context.get_config() + cfg = pkg.utils.context.get_config_manager().data # 更新openai库到最新版本 - if not hasattr(config, 'upgrade_dependencies') or config.upgrade_dependencies: + if 'upgrade_dependencies' not in cfg or cfg['upgrade_dependencies']: print("正在更新依赖库,请等待...") - if not hasattr(config, 'upgrade_dependencies'): + if 'upgrade_dependencies' not in cfg: print("这个操作不是必须的,如果不想更新,请在config.py中添加upgrade_dependencies=False") else: print("这个操作不是必须的,如果不想更新,请在config.py中将upgrade_dependencies设置为False") @@ -226,11 +226,11 @@ async def start_process(first_time_init=False): pkg.utils.context.context['logger_handler'] = sh # 检查是否设置了管理员 - if not (hasattr(config, 'admin_qq') and config.admin_qq != 0): + if cfg['admin_qq'] == 0: # logging.warning("未设置管理员QQ,管理员权限指令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段") while True: try: - config.admin_qq = int(input("未设置管理员QQ,管理员权限指令及运行告警将无法使用,请输入管理员QQ号: ")) + cfg['admin_qq'] = int(input("未设置管理员QQ,管理员权限指令及运行告警将无法使用,请输入管理员QQ号: ")) # 写入到文件 # 读取文件 @@ -238,7 +238,7 @@ async def start_process(first_time_init=False): with open("config.py", "r", encoding="utf-8") as f: config_file_str = f.read() # 替换 - config_file_str = config_file_str.replace("admin_qq = 0", "admin_qq = " + str(config.admin_qq)) + config_file_str = config_file_str.replace("admin_qq = 0", "admin_qq = " + str(cfg['admin_qq'])) # 写入 with open("config.py", "w", encoding="utf-8") as f: f.write(config_file_str) @@ -267,23 +267,23 @@ async def start_process(first_time_init=False): # 配置OpenAI proxy import openai openai.proxies = None # 先重置,因为重载后可能需要清除proxy - if "http_proxy" in config.openai_config and config.openai_config["http_proxy"] is not None: + if "http_proxy" in cfg['openai_config'] and cfg['openai_config']["http_proxy"] is not None: openai.proxies = { - "http": config.openai_config["http_proxy"], - "https": config.openai_config["http_proxy"] + "http": cfg['openai_config']["http_proxy"], + "https": cfg['openai_config']["http_proxy"] } # 配置openai api_base - if "reverse_proxy" in config.openai_config and config.openai_config["reverse_proxy"] is not None: - logging.debug("设置反向代理: "+config.openai_config['reverse_proxy']) - openai.base_url = config.openai_config["reverse_proxy"] + if "reverse_proxy" in cfg['openai_config'] and cfg['openai_config']["reverse_proxy"] is not None: + logging.debug("设置反向代理: "+cfg['openai_config']['reverse_proxy']) + openai.base_url = cfg['openai_config']["reverse_proxy"] # 主启动流程 database = pkg.database.manager.DatabaseManager() database.initialize_database() - openai_interact = pkg.openai.manager.OpenAIInteract(config.openai_config['api_key']) + openai_interact = pkg.openai.manager.OpenAIInteract(cfg['openai_config']['api_key']) # 加载所有未超时的session pkg.openai.session.load_sessions() @@ -372,13 +372,12 @@ async def start_process(first_time_init=False): if first_time_init: if not known_exception_caught: - import config - if config.msg_source_adapter == "yirimirai": - logging.info("QQ: {}, MAH: {}".format(config.mirai_http_api_config['qq'], config.mirai_http_api_config['host']+":"+str(config.mirai_http_api_config['port']))) + if cfg['msg_source_adapter'] == "yirimirai": + logging.info("QQ: {}, MAH: {}".format(cfg['mirai_http_api_config']['qq'], cfg['mirai_http_api_config']['host']+":"+str(cfg['mirai_http_api_config']['port']))) logging.critical('程序启动完成,如长时间未显示 "成功登录到账号xxxxx" ,并且不回复消息,解决办法(请勿到群里问): ' 'https://github.com/RockChinQ/QChatGPT/issues/37') - elif config.msg_source_adapter == 'nakuru': - logging.info("host: {}, port: {}, http_port: {}".format(config.nakuru_config['host'], config.nakuru_config['port'], config.nakuru_config['http_port'])) + elif cfg['msg_source_adapter'] == 'nakuru': + logging.info("host: {}, port: {}, http_port: {}".format(cfg['nakuru_config']['host'], cfg['nakuru_config']['port'], cfg['nakuru_config']['http_port'])) logging.critical('程序启动完成,如长时间未显示 "Protocol: connected" ,并且不回复消息,请检查config.py中的nakuru_config是否正确') else: sys.exit(1) @@ -386,7 +385,7 @@ async def start_process(first_time_init=False): logging.info('热重载完成') # 发送赞赏码 - if config.encourage_sponsor_at_start \ + if cfg['encourage_sponsor_at_start'] \ and pkg.utils.context.get_openai_manager().audit_mgr.get_total_text_length() >= 2048: logging.info("发送赞赏码") diff --git a/pkg/config/__init__.py b/pkg/config/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/config/impls/pymodule.py b/pkg/config/impls/pymodule.py new file mode 100644 index 00000000..691082de --- /dev/null +++ b/pkg/config/impls/pymodule.py @@ -0,0 +1,62 @@ +import os +import shutil +import importlib +import logging + +from .. import model as file_model + + +class PythonModuleConfigFile(file_model.ConfigFile): + """Python模块配置文件""" + + config_file_name: str = None + """配置文件名""" + + template_file_name: str = None + """模板文件名""" + + def __init__(self, config_file_name: str, template_file_name: str) -> None: + self.config_file_name = config_file_name + self.template_file_name = template_file_name + + def exists(self) -> bool: + return os.path.exists(self.config_file_name) + + async def create(self): + shutil.copyfile(self.template_file_name, self.config_file_name) + + async def load(self) -> dict: + module_name = os.path.splitext(os.path.basename(self.config_file_name))[0] + module = importlib.import_module(module_name) + + cfg = {} + + allowed_types = (int, float, str, bool, list, dict) + + for key in dir(module): + if key.startswith('__'): + continue + + if not isinstance(getattr(module, key), allowed_types): + continue + + cfg[key] = getattr(module, key) + + # 从模板模块文件中进行补全 + module_name = os.path.splitext(os.path.basename(self.template_file_name))[0] + module = importlib.import_module(module_name) + + for key in dir(module): + if key.startswith('__'): + continue + + if not isinstance(getattr(module, key), allowed_types): + continue + + if key not in cfg: + cfg[key] = getattr(module, key) + + return cfg + + async def save(self, data: dict): + logging.warning('Python模块配置文件不支持保存') diff --git a/pkg/config/manager.py b/pkg/config/manager.py new file mode 100644 index 00000000..53a6b099 --- /dev/null +++ b/pkg/config/manager.py @@ -0,0 +1,23 @@ +from . import model as file_model +from ..utils import context + + +class ConfigManager: + """配置文件管理器""" + + file: file_model.ConfigFile = None + """配置文件实例""" + + data: dict = None + """配置数据""" + + def __init__(self, cfg_file: file_model.ConfigFile) -> None: + self.file = cfg_file + self.data = {} + context.set_config_manager(self) + + async def load_config(self): + self.data = await self.file.load() + + async def dump_config(self): + await self.file.save(self.data) diff --git a/pkg/config/model.py b/pkg/config/model.py new file mode 100644 index 00000000..e72371ff --- /dev/null +++ b/pkg/config/model.py @@ -0,0 +1,27 @@ +import abc + + +class ConfigFile(metaclass=abc.ABCMeta): + """配置文件抽象类""" + + config_file_name: str = None + """配置文件名""" + + template_file_name: str = None + """模板文件名""" + + @abc.abstractmethod + def exists(self) -> bool: + pass + + @abc.abstractmethod + async def create(self): + pass + + @abc.abstractmethod + async def load(self) -> dict: + pass + + @abc.abstractmethod + async def save(self, data: dict): + pass diff --git a/pkg/utils/context.py b/pkg/utils/context.py index b208dac8..e26c702b 100644 --- a/pkg/utils/context.py +++ b/pkg/utils/context.py @@ -6,6 +6,7 @@ from . import threadctl from ..database import manager as db_mgr from ..openai import manager as openai_mgr from ..qqbot import manager as qqbot_mgr +from ..config import manager as config_mgr from ..plugin import host as plugin_host @@ -14,6 +15,7 @@ context = { 'database.manager.DatabaseManager': None, 'openai.manager.OpenAIInteract': None, 'qqbot.manager.QQBotManager': None, + 'config.manager.ConfigManager': None, }, 'pool_ctl': None, 'logger_handler': None, @@ -75,6 +77,19 @@ def get_qqbot_manager() -> qqbot_mgr.QQBotManager: return t +def set_config_manager(inst: config_mgr.ConfigManager): + context_lock.acquire() + context['inst']['config.manager.ConfigManager'] = inst + context_lock.release() + + +def get_config_manager() -> config_mgr.ConfigManager: + context_lock.acquire() + t = context['inst']['config.manager.ConfigManager'] + context_lock.release() + return t + + def set_plugin_host(inst: plugin_host.PluginHost): context_lock.acquire() context['plugin_host'] = inst