diff --git a/README.md b/README.md index 58bc9161..dad555b3 100644 --- a/README.md +++ b/README.md @@ -72,6 +72,12 @@ - 现已支持OpenAI的对话`Completion API`和绘图`Image API` - 向机器人发送指令`!draw `即可使用绘图模型 +
+✅支持指令控制热重载、热更新 + + - 允许在运行期间修改`config.py`或其他代码后,以管理员账号向机器人发送指令`!reload`进行热重载,无需重启 + - 运行期间允许以管理员账号向机器人发送指令`!update`进行热更新,拉取远程最新代码并执行热重载 +
## 💻技术栈 diff --git a/config-template.py b/config-template.py index 13fa8ad9..06367cda 100644 --- a/config-template.py +++ b/config-template.py @@ -8,6 +8,8 @@ import logging # port: 运行mirai的主机端口 # verifyKey: mirai-api-http的verifyKey # qq: 机器人的QQ号 +# +# 注意: QQ机器人配置不支持热重载及热更新 mirai_http_api_config = { "adapter": "WebSocketAdapter", "host": "localhost", @@ -30,6 +32,9 @@ openai_config = { }, } +# 管理员QQ号,用于接收报错等通知及执行管理员级别指令,为0时关闭此功能 +admin_qq = 0 + # 情景预设(机器人人格) # 每个会话的预设信息,影响所有会话,无视指令重置 # 可以通过这个字段指定某些情况的回复,可直接用自然语言描述指令 @@ -38,9 +43,6 @@ openai_config = { # 可参考 https://github.com/PlexPt/awesome-chatgpt-prompts-zh default_prompt = "如果我之后想获取帮助,请你说“输入!help获取帮助”" -# 管理员QQ号,用于接收报错等通知,为0时不发送通知 -admin_qq = 0 - # 群内响应规则 # 符合此消息的群内消息即使不包含at机器人也会响应 # 支持消息前缀匹配及正则表达式匹配 diff --git a/main.py b/main.py index 23150f1d..330c1e9d 100644 --- a/main.py +++ b/main.py @@ -1,3 +1,4 @@ +import asyncio import os import shutil import sys @@ -7,6 +8,8 @@ import time import logging import colorlog +from mirai.bot import MiraiRunner + import sys sys.path.append(".") @@ -27,10 +30,18 @@ def init_db(): database.initialize_database() -def main(): +def main(first_time_init=False): # 导入config.py assert os.path.exists('config.py') + + # 检查是否设置了管理员 import config + if not (hasattr(config, 'admin_qq') and config.admin_qq != 0): + logging.warning("未设置管理员QQ,管理员权限指令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段") + + import pkg.utils.context + if pkg.utils.context.context['logger_handler'] is not None: + logging.getLogger().removeHandler(pkg.utils.context.context['logger_handler']) logging.basicConfig(level=config.logging_level, # 设置日志输出格式 filename='qchatgpt.log', # log日志输出的文件位置和文件名 @@ -54,6 +65,7 @@ def main(): import pkg.openai.session import pkg.qqbot.manager + pkg.utils.context.context['logger_handler'] = sh # 主启动流程 database = pkg.database.manager.DatabaseManager() @@ -66,29 +78,46 @@ def main(): # 初始化qq机器人 qqbot = pkg.qqbot.manager.QQBotManager(mirai_http_api_config=config.mirai_http_api_config, - timeout=config.process_message_timeout, retry=config.retry_times) + timeout=config.process_message_timeout, retry=config.retry_times, + first_time_init=first_time_init) - qq_bot_thread = threading.Thread(target=qqbot.bot.run, args=(), daemon=True) - qq_bot_thread.start() + if first_time_init: # 不是热重载之后的启动,则不启动新的bot线程 + qq_bot_thread = threading.Thread(target=qqbot.bot.run, args=(), daemon=True) + qq_bot_thread.start() - logging.info('程序启动完成') + time.sleep(2) + logging.info('程序启动完成,如长时间未显示 ”成功登录到账号xxxxx“ ,并且不回复消息,请查看 https://github.com/RockChinQ/QChatGPT/issues/37') while True: try: - time.sleep(86400) + time.sleep(10000) + if qqbot != pkg.utils.context.get_qqbot_manager(): # 已经reload了 + logging.info("以前的main流程由于reload退出") + break except KeyboardInterrupt: - try: - pkg.openai.manager.get_inst().key_mgr.dump_fee() - for session in pkg.openai.session.sessions: - logging.info('持久化session: %s', session) - pkg.openai.session.sessions[session].persistence() - except Exception as e: - if not isinstance(e, KeyboardInterrupt): - raise e + stop() + print("程序退出") sys.exit(0) +def stop(): + import pkg.utils.context + import pkg.qqbot.manager + import pkg.openai.session + try: + qqbot_inst = pkg.utils.context.get_qqbot_manager() + assert isinstance(qqbot_inst, pkg.qqbot.manager.QQBotManager) + + pkg.utils.context.get_openai_manager().key_mgr.dump_fee() + for session in pkg.openai.session.sessions: + logging.info('持久化session: %s', session) + pkg.openai.session.sessions[session].persistence() + except Exception as e: + if not isinstance(e, KeyboardInterrupt): + raise e + + if __name__ == '__main__': # 检查是否有config.py,如果没有就把config-template.py复制一份,并退出程序 if not os.path.exists('config.py'): @@ -109,4 +138,4 @@ if __name__ == '__main__': print("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77") sys.exit(0) - main() + main(True) diff --git a/pkg/database/manager.py b/pkg/database/manager.py index cbe243e9..5c9b700b 100644 --- a/pkg/database/manager.py +++ b/pkg/database/manager.py @@ -6,8 +6,7 @@ from sqlite3 import Cursor import sqlite3 import config - -inst = None +import pkg.utils.context # 数据库管理 @@ -20,8 +19,7 @@ class DatabaseManager: self.reconnect() - global inst - inst = self + pkg.utils.context.set_database_manager(self) # 连接到数据库文件 def reconnect(self): @@ -312,6 +310,3 @@ class DatabaseManager: fee[key_md5] = fee_count return fee -def get_inst() -> DatabaseManager: - global inst - return inst diff --git a/pkg/openai/keymgr.py b/pkg/openai/keymgr.py index e8992c7e..cf9898aa 100644 --- a/pkg/openai/keymgr.py +++ b/pkg/openai/keymgr.py @@ -4,6 +4,7 @@ import logging import pkg.database.manager import pkg.qqbot.manager +import pkg.utils.context import config @@ -62,43 +63,6 @@ class KeysManager: def add(self, key_name, key): self.api_key[key_name] = key - # def get_usage(self, api_key): - # md5 = hashlib.md5(api_key.encode('utf-8')).hexdigest() - # if md5 not in self.usage: - # self.usage[md5] = 0 - # return self.usage[md5] - - # 报告使用 - # 返回是否需要将openai的api-key切换 - # def report_usage(self, new_content: str) -> bool: - # md5 = hashlib.md5(self.using_key.encode('utf-8')).hexdigest() - # if md5 not in self.usage: - # self.usage[md5] = 0 - # - # # 经测算得出的理论与实际的偏差比例 - # salt_rate = 0.91 - # - # self.usage[md5] += ( (len(new_content.encode('utf-8')) - len(new_content)) / 2 + len(new_content) )*salt_rate - # - # self.usage[md5] = int(self.usage[md5]) - # - # if self.usage[md5] >= self.api_key_usage_threshold: - # switch_result, key_name = self.auto_switch() - # - # # 检查是否切换到新的 - # if switch_result: - # if key_name not in self.alerted: - # # 通知管理员 - # pkg.qqbot.manager.get_inst().notify_admin("api-key已切换到:" + key_name) - # self.alerted.append(key_name) - # return True - # else: - # if key_name not in self.alerted: - # # 通知管理员 - # pkg.qqbot.manager.get_inst().notify_admin("api-key已用完,无未使用的api-key可供切换") - # self.alerted.append(key_name) - # return False - # 设置当前使用的api-key使用量超限 # 这是在尝试调用api时发生超限异常时调用的 def set_current_exceeded(self): @@ -107,14 +71,6 @@ class KeysManager: self.fee[md5] = self.api_key_fee_threshold self.dump_fee() - # def dump_usage(self): - # pkg.database.manager.get_inst().dump_api_key_usage(api_keys=self.api_key, usage=self.usage) - - # def load_usage(self): - # self.usage = pkg.database.manager.get_inst().load_api_key_usage() - # logging.debug("load usage:" + str(self.usage)) - # print("load usage:" + str(self.usage)) - def get_fee(self, api_key): md5 = hashlib.md5(api_key.encode('utf-8')).hexdigest() if md5 not in self.fee: @@ -135,19 +91,19 @@ class KeysManager: if switch_result: if key_name not in self.alerted: # 通知管理员 - pkg.qqbot.manager.get_inst().notify_admin("api-key已切换到:" + key_name) + pkg.utils.context.get_qqbot_manager().notify_admin("api-key已切换到:" + key_name) self.alerted.append(key_name) return True else: if key_name not in self.alerted: # 通知管理员 - pkg.qqbot.manager.get_inst().notify_admin("api-key已用完,无未使用的api-key可供切换") + pkg.utils.context.get_qqbot_manager().notify_admin("api-key已用完,无未使用的api-key可供切换") self.alerted.append(key_name) return False def dump_fee(self): - pkg.database.manager.get_inst().dump_api_key_fee(api_keys=self.api_key, fee=self.fee) + pkg.utils.context.get_database_manager().dump_api_key_fee(api_keys=self.api_key, fee=self.fee) def load_fee(self): - self.fee = pkg.database.manager.get_inst().load_api_key_fee() - logging.info("load fee:" + str(self.fee)) \ No newline at end of file + self.fee = pkg.utils.context.get_database_manager().load_api_key_fee() + logging.info("load fee:" + str(self.fee)) diff --git a/pkg/openai/manager.py b/pkg/openai/manager.py index d65f0e54..e59146d6 100644 --- a/pkg/openai/manager.py +++ b/pkg/openai/manager.py @@ -6,8 +6,7 @@ import config import pkg.openai.keymgr import pkg.openai.pricing as pricing - -inst = None +import pkg.utils.context # 为其他模块提供与OpenAI交互的接口 @@ -27,11 +26,11 @@ class OpenAIInteract: openai.api_key = self.key_mgr.get_using_key() - global inst - inst = self + pkg.utils.context.set_openai_manager(self) # 请求OpenAI Completion def request_completion(self, prompt, stop): + # print("request") response = openai.Completion.create( prompt=prompt, stop=stop, @@ -41,7 +40,6 @@ class OpenAIInteract: switched = self.key_mgr.report_fee(pricing.language_base_price(config.completion_api_params['model'], prompt + response['choices'][0]['text'])) - if switched: openai.api_key = self.key_mgr.get_using_key() @@ -64,7 +62,3 @@ class OpenAIInteract: return response - -def get_inst() -> OpenAIInteract: - global inst - return inst diff --git a/pkg/openai/modelmgr.py b/pkg/openai/modelmgr.py new file mode 100644 index 00000000..c106824e --- /dev/null +++ b/pkg/openai/modelmgr.py @@ -0,0 +1,34 @@ +# 提供与模型交互的抽象接口 + +COMPLETION_MODELS = { + 'text-davinci-003' +} + +EDIT_MODELS = { + +} + +IMAGE_MODELS = { + +} + + +# ModelManager +# 由session包含 +class ModelMgr(object): + + using_completion_model = "" + using_edit_model = "" + using_image_model = "" + + def __init__(self): + pass + + def get_using_completion_model(self): + return self.using_completion_model + + def get_using_edit_model(self): + return self.using_edit_model + + def get_using_image_model(self): + return self.using_image_model diff --git a/pkg/openai/session.py b/pkg/openai/session.py index 9bc9e9e1..631d4131 100644 --- a/pkg/openai/session.py +++ b/pkg/openai/session.py @@ -5,6 +5,7 @@ import time import config import pkg.openai.manager import pkg.database.manager +import pkg.utils.context # 运行时保存的所有session sessions = {} @@ -19,7 +20,7 @@ class SessionOfflineStatus: def load_sessions(): global sessions - db_inst = pkg.database.manager.get_inst() + db_inst = pkg.utils.context.get_database_manager() session_data = db_inst.load_valid_sessions() @@ -147,10 +148,11 @@ class Session: max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024 # 向API请求补全 - response = pkg.openai.manager.get_inst().request_completion(self.cut_out(self.prompt + self.user_name + ':' + - text + '\n' + self.bot_name + ':', - max_rounds, max_length), - self.user_name + ':') + response = pkg.utils.context.get_openai_manager().request_completion( + self.cut_out(self.prompt + self.user_name + ':' + + text + '\n' + self.bot_name + ':', + max_rounds, max_length), + self.user_name + ':') self.prompt += self.user_name + ':' + text + '\n' + self.bot_name + ':' # print(response) @@ -202,7 +204,7 @@ class Session: if self.prompt == get_default_prompt(): return - db_inst = pkg.database.manager.get_inst() + db_inst = pkg.utils.context.get_database_manager() name_spt = self.name.split('_') @@ -217,10 +219,10 @@ class Session: if self.prompt != get_default_prompt(): self.persistence() if explicit: - pkg.database.manager.get_inst().explicit_close_session(self.name, self.create_timestamp) + pkg.utils.context.get_database_manager().explicit_close_session(self.name, self.create_timestamp) if expired: - pkg.database.manager.get_inst().set_session_expired(self.name, self.create_timestamp) + pkg.utils.context.get_database_manager().set_session_expired(self.name, self.create_timestamp) self.prompt = get_default_prompt() self.create_timestamp = int(time.time()) self.last_interact_timestamp = int(time.time()) @@ -233,11 +235,11 @@ class Session: # 将本session的数据库状态设置为on_going def set_ongoing(self): - pkg.database.manager.get_inst().set_session_ongoing(self.name, self.create_timestamp) + pkg.utils.context.get_database_manager().set_session_ongoing(self.name, self.create_timestamp) # 切换到上一个session def last_session(self): - last_one = pkg.database.manager.get_inst().last_session(self.name, self.last_interact_timestamp) + last_one = pkg.utils.context.get_database_manager().last_session(self.name, self.last_interact_timestamp) if last_one is None: return None else: @@ -252,7 +254,7 @@ class Session: # 切换到下一个session def next_session(self): - next_one = pkg.database.manager.get_inst().next_session(self.name, self.last_interact_timestamp) + next_one = pkg.utils.context.get_database_manager().next_session(self.name, self.last_interact_timestamp) if next_one is None: return None else: @@ -266,8 +268,8 @@ class Session: return self def list_history(self, capacity: int = 10, page: int = 0): - return pkg.database.manager.get_inst().list_history(self.name, capacity, page, - get_default_prompt()) + return pkg.utils.context.get_database_manager().list_history(self.name, capacity, page, + get_default_prompt()) def draw_image(self, prompt: str): - return pkg.openai.manager.get_inst().request_image(prompt) + return pkg.utils.context.get_openai_manager().request_image(prompt) diff --git a/pkg/qqbot/manager.py b/pkg/qqbot/manager.py index a301664f..b8d80454 100644 --- a/pkg/qqbot/manager.py +++ b/pkg/qqbot/manager.py @@ -3,10 +3,13 @@ import json import os import threading +import mirai.models.bus import openai.error from mirai import At, GroupMessage, MessageEvent, Mirai, Plain, StrangerMessage, WebSocketAdapter, HTTPAdapter, \ FriendMessage, Image +from mirai.models.bus import ModelEventBus + from mirai.models.message import Quote import config @@ -17,8 +20,7 @@ import logging import pkg.qqbot.filter import pkg.qqbot.process as processor - -inst = None +import pkg.utils.context # 并行运行 @@ -58,7 +60,7 @@ class QQBotManager: reply_filter = None - def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3): + def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3, first_time_init=True): self.timeout = timeout self.retry = retry @@ -71,6 +73,47 @@ class QQBotManager: else: self.reply_filter = pkg.qqbot.filter.ReplyFilter([]) + # 由于YiriMirai的bot对象是单例的,且shutdown方法暂时无法使用 + # 故只在第一次初始化时创建bot对象,重载之后使用原bot对象 + # 因此,bot的配置不支持热重载 + if first_time_init: + self.first_time_init(mirai_http_api_config) + else: + self.bot = pkg.utils.context.get_qqbot_manager().bot + + pkg.utils.context.set_qqbot_manager(self) + + # Caution: 注册新的事件处理器之后,请务必在unsubscribe_all中编写相应的取消订阅代码 + @self.bot.on(FriendMessage) + async def on_friend_message(event: FriendMessage): + go(self.on_person_message, (event,)) + + @self.bot.on(StrangerMessage) + async def on_stranger_message(event: StrangerMessage): + go(self.on_person_message, (event,)) + + @self.bot.on(GroupMessage) + async def on_group_message(event: GroupMessage): + go(self.on_group_message, (event,)) + + def unsubscribe_all(): + """取消所有订阅 + + 用于在热重载流程中卸载所有事件处理器 + """ + assert isinstance(self.bot, Mirai) + bus = self.bot.bus + assert isinstance(bus, mirai.models.bus.ModelEventBus) + + bus.unsubscribe(FriendMessage, on_friend_message) + bus.unsubscribe(StrangerMessage, on_stranger_message) + bus.unsubscribe(GroupMessage, on_group_message) + + self.unsubscribe_all = unsubscribe_all + + def first_time_init(self, mirai_http_api_config: dict): + """热重载后不再运行此函数""" + if 'adapter' not in mirai_http_api_config or mirai_http_api_config['adapter'] == "WebSocketAdapter": bot = Mirai( qq=mirai_http_api_config['qq'], @@ -93,23 +136,8 @@ class QQBotManager: else: raise Exception("未知的适配器类型") - @bot.on(FriendMessage) - async def on_friend_message(event: FriendMessage): - go(self.on_person_message, (event,)) - - @bot.on(StrangerMessage) - async def on_stranger_message(event: StrangerMessage): - go(self.on_person_message, (event,)) - - @bot.on(GroupMessage) - async def on_group_message(event: GroupMessage): - go(self.on_group_message, (event,)) - self.bot = bot - global inst - inst = self - def send(self, event, msg, check_quote=True): asyncio.run( self.bot.send(event, msg, quote=True if hasattr(config, @@ -117,7 +145,6 @@ class QQBotManager: # 私聊消息处理 def on_person_message(self, event: MessageEvent): - reply = '' if event.sender.id == self.bot.qq: @@ -167,11 +194,13 @@ class QQBotManager: event.sender.id) break except FunctionTimedOut: + pkg.openai.session.get_session('group_{}'.format(event.group.id)).release_response_lock() failed += 1 continue if failed == self.retry: - self.notify_admin("{} 请求超时".format("group_{}".format(event.sender.id))) + pkg.openai.session.get_session('group_{}'.format(event.group.id)).release_response_lock() + self.notify_admin("{} 请求超时".format("group_{}".format(event.group.id))) replys = ["[bot]err:请求超时"] return replys @@ -196,8 +225,3 @@ class QQBotManager: logging.info("通知管理员:{}".format(message)) send_task = self.bot.send_friend_message(config.admin_qq, "[bot]{}".format(message)) threading.Thread(target=asyncio.run, args=(send_task,)).start() - - -def get_inst() -> QQBotManager: - global inst - return inst diff --git a/pkg/qqbot/process.py b/pkg/qqbot/process.py index 16a4ced9..6d95ce45 100644 --- a/pkg/qqbot/process.py +++ b/pkg/qqbot/process.py @@ -1,6 +1,7 @@ # 此模块提供了消息处理的具体逻辑的接口 import asyncio import datetime +import threading import pkg.qqbot.manager as manager from func_timeout import func_set_timeout @@ -8,12 +9,14 @@ import logging import openai from mirai import Image, MessageChain -from mirai.models.message import Quote import config import pkg.openai.session import pkg.openai.manager +import pkg.utils.reloader +import pkg.utils.updater +import pkg.utils.context processing = [] @@ -23,7 +26,7 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes sender_id: int) -> MessageChain: global processing - mgr = pkg.qqbot.manager.get_inst() + mgr = pkg.utils.context.get_qqbot_manager() reply = [] session_name = "{}_{}".format(launcher_type, launcher_id) @@ -123,22 +126,22 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes reply = [reply_str] elif cmd == 'usage': - api_keys = pkg.openai.manager.get_inst().key_mgr.api_key + api_keys = pkg.utils.context.get_openai_manager().key_mgr.api_key reply_str = "[bot]api-key使用情况:(阈值:{})\n\n".format( - pkg.openai.manager.get_inst().key_mgr.api_key_fee_threshold) + pkg.utils.context.get_openai_manager().key_mgr.api_key_fee_threshold) using_key_name = "" for api_key in api_keys: reply_str += "{}:\n - {}美元 {}%\n".format(api_key, round( - pkg.openai.manager.get_inst().key_mgr.get_fee( + pkg.utils.context.get_openai_manager().key_mgr.get_fee( api_keys[api_key]), 6), round( - pkg.openai.manager.get_inst().key_mgr.get_fee( + pkg.utils.context.get_openai_manager().key_mgr.get_fee( api_keys[ - api_key]) / pkg.openai.manager.get_inst().key_mgr.api_key_fee_threshold * 100, + api_key]) / pkg.utils.context.get_openai_manager().key_mgr.api_key_fee_threshold * 100, 3)) - if api_keys[api_key] == pkg.openai.manager.get_inst().key_mgr.using_key: + if api_keys[api_key] == pkg.utils.context.get_openai_manager().key_mgr.using_key: using_key_name = api_key reply_str += "\n当前使用:{}".format(using_key_name) @@ -157,6 +160,23 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes if not (hasattr(config, 'include_image_description') and not config.include_image_description): reply.append(" ".join(params)) + elif cmd == 'reload' and launcher_type == 'person' and launcher_id == config.admin_qq: + def reload_task(): + pkg.utils.reloader.reload_all() + + threading.Thread(target=reload_task, daemon=True).start() + elif cmd == 'update' and launcher_type == 'person' and launcher_id == config.admin_qq: + def update_task(): + try: + pkg.utils.updater.update_all() + except Exception as e0: + pkg.utils.context.get_qqbot_manager().notify_admin("更新失败:{}".format(e0)) + return + pkg.utils.reloader.reload_all() + + threading.Thread(target=update_task, daemon=True).start() + else: + reply = ["[bot]err:未知的指令或权限不足: "+cmd] except Exception as e: mgr.notify_admin("{}指令执行失败:{}".format(session_name, e)) logging.exception(e) @@ -174,17 +194,17 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes reply = ["[bot]err:调用API失败,请重试或联系作者,或等待修复"] except openai.error.RateLimitError as e: # 尝试切换api-key - current_tokens_amt = pkg.openai.manager.get_inst().key_mgr.get_fee( - pkg.openai.manager.get_inst().key_mgr.get_using_key()) - pkg.openai.manager.get_inst().key_mgr.set_current_exceeded() - switched, name = pkg.openai.manager.get_inst().key_mgr.auto_switch() + current_tokens_amt = pkg.utils.context.get_openai_manager().key_mgr.get_fee( + pkg.utils.context.get_openai_manager().key_mgr.get_using_key()) + pkg.utils.context.get_openai_manager().key_mgr.set_current_exceeded() + switched, name = pkg.utils.context.get_openai_manager().key_mgr.auto_switch() if not switched: mgr.notify_admin("API调用额度超限({}),请向OpenAI账户充值或在config.py中更换api_key".format( current_tokens_amt)) reply = ["[bot]err:API调用额度超额,请联系作者,或等待修复"] else: - openai.api_key = pkg.openai.manager.get_inst().key_mgr.get_using_key() + openai.api_key = pkg.utils.context.get_openai_manager().key_mgr.get_using_key() mgr.notify_admin("API调用额度超限({}),已切换到{}".format(current_tokens_amt, name)) reply = ["[bot]err:API调用额度超额,已自动切换,请重新发送消息"] except openai.error.InvalidRequestError as e: diff --git a/pkg/utils/__init__.py b/pkg/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/utils/context.py b/pkg/utils/context.py new file mode 100644 index 00000000..2576609b --- /dev/null +++ b/pkg/utils/context.py @@ -0,0 +1,32 @@ +context = { + 'inst': { + 'database.manager.DatabaseManager': None, + 'openai.manager.OpenAIInteract': None, + 'qqbot.manager.QQBotManager': None, + }, + 'logger_handler': None, +} + + +def set_database_manager(inst): + context['inst']['database.manager.DatabaseManager'] = inst + + +def get_database_manager(): + return context['inst']['database.manager.DatabaseManager'] + + +def set_openai_manager(inst): + context['inst']['openai.manager.OpenAIInteract'] = inst + + +def get_openai_manager(): + return context['inst']['openai.manager.OpenAIInteract'] + + +def set_qqbot_manager(inst): + context['inst']['qqbot.manager.QQBotManager'] = inst + + +def get_qqbot_manager(): + return context['inst']['qqbot.manager.QQBotManager'] \ No newline at end of file diff --git a/pkg/utils/reloader.py b/pkg/utils/reloader.py new file mode 100644 index 00000000..b2903570 --- /dev/null +++ b/pkg/utils/reloader.py @@ -0,0 +1,45 @@ +import logging +import os +import threading + +import colorlog + +import pkg +import importlib +import pkgutil +import pkg.utils.context +from main import log_colors_config + + +def walk(module, prefix=''): + """遍历并重载所有模块""" + for item in pkgutil.iter_modules(module.__path__): + if item.ispkg: + walk(__import__(module.__name__ + '.' + item.name, fromlist=['']), prefix + item.name + '.') + else: + logging.info('reload module: {}'.format(prefix + item.name)) + importlib.reload(__import__(module.__name__ + '.' + item.name, fromlist=[''])) + + +def reload_all(): + # 解除bot的事件注册 + import pkg + pkg.utils.context.get_qqbot_manager().unsubscribe_all() + # 执行关闭流程 + logging.info("执行程序关闭流程") + import main + main.stop() + + # 重载所有模块 + context = pkg.utils.context.context + walk(pkg) + importlib.reload(__import__('config')) + importlib.reload(__import__('main')) + pkg.utils.context.context = context + + # 执行启动流程 + logging.info("执行程序启动流程") + threading.Thread(target=main.main, args=(False,), daemon=False).start() + + logging.info('程序启动完成') + pkg.utils.context.get_qqbot_manager().notify_admin("重载完成") diff --git a/pkg/utils/updater.py b/pkg/utils/updater.py new file mode 100644 index 00000000..90b14004 --- /dev/null +++ b/pkg/utils/updater.py @@ -0,0 +1,13 @@ +import dulwich.porcelain + + +def update_all(): + """使用dulwich更新源码""" + try: + from dulwich import porcelain + repo = porcelain.open_repo('.') + porcelain.pull(repo) + except ModuleNotFoundError: + raise Exception("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77") + except dulwich.porcelain.DivergedBranches: + raise Exception("分支不一致,自动更新仅支持master分支,请手动更新(https://github.com/RockChinQ/QChatGPT/issues/76)")