feat: 实现配置文件管理器并适配main.py中的引用

This commit is contained in:
RockChinQ
2023-11-26 22:46:27 +08:00
parent 26e4215054
commit db2e366014
6 changed files with 145 additions and 19 deletions

37
main.py
View File

@@ -205,11 +205,11 @@ async def start_process(first_time_init=False):
# 检查tips模块
complete_tips()
config = pkg.utils.context.get_config()
cfg = pkg.utils.context.get_config_manager().data
# 更新openai库到最新版本
if not hasattr(config, 'upgrade_dependencies') or config.upgrade_dependencies:
if 'upgrade_dependencies' not in cfg or cfg['upgrade_dependencies']:
print("正在更新依赖库,请等待...")
if not hasattr(config, 'upgrade_dependencies'):
if 'upgrade_dependencies' not in cfg:
print("这个操作不是必须的,如果不想更新,请在config.py中添加upgrade_dependencies=False")
else:
print("这个操作不是必须的,如果不想更新,请在config.py中将upgrade_dependencies设置为False")
@@ -226,11 +226,11 @@ async def start_process(first_time_init=False):
pkg.utils.context.context['logger_handler'] = sh
# 检查是否设置了管理员
if not (hasattr(config, 'admin_qq') and config.admin_qq != 0):
if cfg['admin_qq'] == 0:
# logging.warning("未设置管理员QQ,管理员权限指令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段")
while True:
try:
config.admin_qq = int(input("未设置管理员QQ,管理员权限指令及运行告警将无法使用,请输入管理员QQ号: "))
cfg['admin_qq'] = int(input("未设置管理员QQ,管理员权限指令及运行告警将无法使用,请输入管理员QQ号: "))
# 写入到文件
# 读取文件
@@ -238,7 +238,7 @@ async def start_process(first_time_init=False):
with open("config.py", "r", encoding="utf-8") as f:
config_file_str = f.read()
# 替换
config_file_str = config_file_str.replace("admin_qq = 0", "admin_qq = " + str(config.admin_qq))
config_file_str = config_file_str.replace("admin_qq = 0", "admin_qq = " + str(cfg['admin_qq']))
# 写入
with open("config.py", "w", encoding="utf-8") as f:
f.write(config_file_str)
@@ -267,23 +267,23 @@ async def start_process(first_time_init=False):
# 配置OpenAI proxy
import openai
openai.proxies = None # 先重置因为重载后可能需要清除proxy
if "http_proxy" in config.openai_config and config.openai_config["http_proxy"] is not None:
if "http_proxy" in cfg['openai_config'] and cfg['openai_config']["http_proxy"] is not None:
openai.proxies = {
"http": config.openai_config["http_proxy"],
"https": config.openai_config["http_proxy"]
"http": cfg['openai_config']["http_proxy"],
"https": cfg['openai_config']["http_proxy"]
}
# 配置openai api_base
if "reverse_proxy" in config.openai_config and config.openai_config["reverse_proxy"] is not None:
logging.debug("设置反向代理: "+config.openai_config['reverse_proxy'])
openai.base_url = config.openai_config["reverse_proxy"]
if "reverse_proxy" in cfg['openai_config'] and cfg['openai_config']["reverse_proxy"] is not None:
logging.debug("设置反向代理: "+cfg['openai_config']['reverse_proxy'])
openai.base_url = cfg['openai_config']["reverse_proxy"]
# 主启动流程
database = pkg.database.manager.DatabaseManager()
database.initialize_database()
openai_interact = pkg.openai.manager.OpenAIInteract(config.openai_config['api_key'])
openai_interact = pkg.openai.manager.OpenAIInteract(cfg['openai_config']['api_key'])
# 加载所有未超时的session
pkg.openai.session.load_sessions()
@@ -372,13 +372,12 @@ async def start_process(first_time_init=False):
if first_time_init:
if not known_exception_caught:
import config
if config.msg_source_adapter == "yirimirai":
logging.info("QQ: {}, MAH: {}".format(config.mirai_http_api_config['qq'], config.mirai_http_api_config['host']+":"+str(config.mirai_http_api_config['port'])))
if cfg['msg_source_adapter'] == "yirimirai":
logging.info("QQ: {}, MAH: {}".format(cfg['mirai_http_api_config']['qq'], cfg['mirai_http_api_config']['host']+":"+str(cfg['mirai_http_api_config']['port'])))
logging.critical('程序启动完成,如长时间未显示 "成功登录到账号xxxxx" ,并且不回复消息,解决办法(请勿到群里问): '
'https://github.com/RockChinQ/QChatGPT/issues/37')
elif config.msg_source_adapter == 'nakuru':
logging.info("host: {}, port: {}, http_port: {}".format(config.nakuru_config['host'], config.nakuru_config['port'], config.nakuru_config['http_port']))
elif cfg['msg_source_adapter'] == 'nakuru':
logging.info("host: {}, port: {}, http_port: {}".format(cfg['nakuru_config']['host'], cfg['nakuru_config']['port'], cfg['nakuru_config']['http_port']))
logging.critical('程序启动完成,如长时间未显示 "Protocol: connected" ,并且不回复消息,请检查config.py中的nakuru_config是否正确')
else:
sys.exit(1)
@@ -386,7 +385,7 @@ async def start_process(first_time_init=False):
logging.info('热重载完成')
# 发送赞赏码
if config.encourage_sponsor_at_start \
if cfg['encourage_sponsor_at_start'] \
and pkg.utils.context.get_openai_manager().audit_mgr.get_total_text_length() >= 2048:
logging.info("发送赞赏码")

0
pkg/config/__init__.py Normal file
View File

View File

@@ -0,0 +1,62 @@
import os
import shutil
import importlib
import logging
from .. import model as file_model
class PythonModuleConfigFile(file_model.ConfigFile):
"""Python模块配置文件"""
config_file_name: str = None
"""配置文件名"""
template_file_name: str = None
"""模板文件名"""
def __init__(self, config_file_name: str, template_file_name: str) -> None:
self.config_file_name = config_file_name
self.template_file_name = template_file_name
def exists(self) -> bool:
return os.path.exists(self.config_file_name)
async def create(self):
shutil.copyfile(self.template_file_name, self.config_file_name)
async def load(self) -> dict:
module_name = os.path.splitext(os.path.basename(self.config_file_name))[0]
module = importlib.import_module(module_name)
cfg = {}
allowed_types = (int, float, str, bool, list, dict)
for key in dir(module):
if key.startswith('__'):
continue
if not isinstance(getattr(module, key), allowed_types):
continue
cfg[key] = getattr(module, key)
# 从模板模块文件中进行补全
module_name = os.path.splitext(os.path.basename(self.template_file_name))[0]
module = importlib.import_module(module_name)
for key in dir(module):
if key.startswith('__'):
continue
if not isinstance(getattr(module, key), allowed_types):
continue
if key not in cfg:
cfg[key] = getattr(module, key)
return cfg
async def save(self, data: dict):
logging.warning('Python模块配置文件不支持保存')

23
pkg/config/manager.py Normal file
View File

@@ -0,0 +1,23 @@
from . import model as file_model
from ..utils import context
class ConfigManager:
"""配置文件管理器"""
file: file_model.ConfigFile = None
"""配置文件实例"""
data: dict = None
"""配置数据"""
def __init__(self, cfg_file: file_model.ConfigFile) -> None:
self.file = cfg_file
self.data = {}
context.set_config_manager(self)
async def load_config(self):
self.data = await self.file.load()
async def dump_config(self):
await self.file.save(self.data)

27
pkg/config/model.py Normal file
View File

@@ -0,0 +1,27 @@
import abc
class ConfigFile(metaclass=abc.ABCMeta):
"""配置文件抽象类"""
config_file_name: str = None
"""配置文件名"""
template_file_name: str = None
"""模板文件名"""
@abc.abstractmethod
def exists(self) -> bool:
pass
@abc.abstractmethod
async def create(self):
pass
@abc.abstractmethod
async def load(self) -> dict:
pass
@abc.abstractmethod
async def save(self, data: dict):
pass

View File

@@ -6,6 +6,7 @@ from . import threadctl
from ..database import manager as db_mgr
from ..openai import manager as openai_mgr
from ..qqbot import manager as qqbot_mgr
from ..config import manager as config_mgr
from ..plugin import host as plugin_host
@@ -14,6 +15,7 @@ context = {
'database.manager.DatabaseManager': None,
'openai.manager.OpenAIInteract': None,
'qqbot.manager.QQBotManager': None,
'config.manager.ConfigManager': None,
},
'pool_ctl': None,
'logger_handler': None,
@@ -75,6 +77,19 @@ def get_qqbot_manager() -> qqbot_mgr.QQBotManager:
return t
def set_config_manager(inst: config_mgr.ConfigManager):
context_lock.acquire()
context['inst']['config.manager.ConfigManager'] = inst
context_lock.release()
def get_config_manager() -> config_mgr.ConfigManager:
context_lock.acquire()
t = context['inst']['config.manager.ConfigManager']
context_lock.release()
return t
def set_plugin_host(inst: plugin_host.PluginHost):
context_lock.acquire()
context['plugin_host'] = inst