mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
feat: 持久保存bot对象以成功重启
This commit is contained in:
45
main.py
45
main.py
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
@@ -7,6 +8,8 @@ import time
|
||||
import logging
|
||||
import colorlog
|
||||
|
||||
from mirai.bot import MiraiRunner
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.append(".")
|
||||
@@ -27,11 +30,15 @@ def init_db():
|
||||
database.initialize_database()
|
||||
|
||||
|
||||
def main():
|
||||
def main(first_time_init=False):
|
||||
# 导入config.py
|
||||
assert os.path.exists('config.py')
|
||||
import config
|
||||
|
||||
import pkg.utils.context
|
||||
if pkg.utils.context.context['logger_handler'] is not None:
|
||||
logging.getLogger().removeHandler(pkg.utils.context.context['logger_handler'])
|
||||
|
||||
logging.basicConfig(level=config.logging_level, # 设置日志输出格式
|
||||
filename='qchatgpt.log', # log日志输出的文件位置和文件名
|
||||
format="[%(asctime)s.%(msecs)03d] %(filename)s (%(lineno)d) - [%(levelname)s] : %(message)s",
|
||||
@@ -53,8 +60,8 @@ def main():
|
||||
import pkg.database.manager
|
||||
import pkg.openai.session
|
||||
import pkg.qqbot.manager
|
||||
import pkg.utils.context
|
||||
|
||||
pkg.utils.context.context['logger_handler'] = sh
|
||||
# 主启动流程
|
||||
database = pkg.database.manager.DatabaseManager()
|
||||
|
||||
@@ -67,7 +74,8 @@ def main():
|
||||
|
||||
# 初始化qq机器人
|
||||
qqbot = pkg.qqbot.manager.QQBotManager(mirai_http_api_config=config.mirai_http_api_config,
|
||||
timeout=config.process_message_timeout, retry=config.retry_times)
|
||||
timeout=config.process_message_timeout, retry=config.retry_times,
|
||||
first_time_init=first_time_init)
|
||||
|
||||
qq_bot_thread = threading.Thread(target=qqbot.bot.run, args=(), daemon=True)
|
||||
qq_bot_thread.start()
|
||||
@@ -76,9 +84,33 @@ def main():
|
||||
|
||||
while True:
|
||||
try:
|
||||
time.sleep(86400)
|
||||
time.sleep(10000)
|
||||
if qqbot != pkg.utils.context.get_qqbot_manager(): # 已经reload了
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
stop()
|
||||
|
||||
print("程序退出")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def stop():
|
||||
import pkg.utils.context
|
||||
import pkg.qqbot.manager
|
||||
import pkg.openai.session
|
||||
try:
|
||||
qqbot_inst = pkg.utils.context.get_qqbot_manager()
|
||||
assert isinstance(qqbot_inst, pkg.qqbot.manager.QQBotManager)
|
||||
|
||||
# try:
|
||||
# asyncio.run(qqbot_inst.bot.shutdown())
|
||||
# except ValueError:
|
||||
# pass
|
||||
#
|
||||
# import mirai.utils
|
||||
# MiraiRunner.__class__._instance = None
|
||||
# mirai.utils.Singleton._instance = None
|
||||
|
||||
pkg.utils.context.get_openai_manager().key_mgr.dump_fee()
|
||||
for session in pkg.openai.session.sessions:
|
||||
logging.info('持久化session: %s', session)
|
||||
@@ -86,11 +118,10 @@ def main():
|
||||
except Exception as e:
|
||||
if not isinstance(e, KeyboardInterrupt):
|
||||
raise e
|
||||
print("程序退出")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('程序启动')
|
||||
# 检查是否有config.py,如果没有就把config-template.py复制一份,并退出程序
|
||||
if not os.path.exists('config.py'):
|
||||
shutil.copy('config-template.py', 'config.py')
|
||||
@@ -110,4 +141,4 @@ if __name__ == '__main__':
|
||||
print("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77")
|
||||
sys.exit(0)
|
||||
|
||||
main()
|
||||
main(True)
|
||||
|
||||
@@ -30,7 +30,7 @@ class OpenAIInteract:
|
||||
|
||||
# 请求OpenAI Completion
|
||||
def request_completion(self, prompt, stop):
|
||||
print("request")
|
||||
# print("request")
|
||||
response = openai.Completion.create(
|
||||
prompt=prompt,
|
||||
stop=stop,
|
||||
|
||||
@@ -57,7 +57,7 @@ class QQBotManager:
|
||||
|
||||
reply_filter = None
|
||||
|
||||
def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3):
|
||||
def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3, first_time_init=True):
|
||||
|
||||
self.timeout = timeout
|
||||
self.retry = retry
|
||||
@@ -70,6 +70,28 @@ class QQBotManager:
|
||||
else:
|
||||
self.reply_filter = pkg.qqbot.filter.ReplyFilter([])
|
||||
|
||||
|
||||
if first_time_init:
|
||||
self.first_time_init(mirai_http_api_config)
|
||||
else:
|
||||
self.bot = pkg.utils.context.get_qqbot_manager().bot
|
||||
|
||||
pkg.utils.context.set_qqbot_manager(self)
|
||||
|
||||
@self.bot.on(FriendMessage)
|
||||
async def on_friend_message(event: FriendMessage):
|
||||
go(self.on_person_message, (event,))
|
||||
|
||||
@self.bot.on(StrangerMessage)
|
||||
async def on_stranger_message(event: StrangerMessage):
|
||||
go(self.on_person_message, (event,))
|
||||
|
||||
@self.bot.on(GroupMessage)
|
||||
async def on_group_message(event: GroupMessage):
|
||||
go(self.on_group_message, (event,))
|
||||
|
||||
def first_time_init(self, mirai_http_api_config: dict):
|
||||
|
||||
if 'adapter' not in mirai_http_api_config or mirai_http_api_config['adapter'] == "WebSocketAdapter":
|
||||
bot = Mirai(
|
||||
qq=mirai_http_api_config['qq'],
|
||||
@@ -92,22 +114,9 @@ class QQBotManager:
|
||||
else:
|
||||
raise Exception("未知的适配器类型")
|
||||
|
||||
@bot.on(FriendMessage)
|
||||
async def on_friend_message(event: FriendMessage):
|
||||
go(self.on_person_message, (event,))
|
||||
|
||||
@bot.on(StrangerMessage)
|
||||
async def on_stranger_message(event: StrangerMessage):
|
||||
go(self.on_person_message, (event,))
|
||||
|
||||
@bot.on(GroupMessage)
|
||||
async def on_group_message(event: GroupMessage):
|
||||
go(self.on_group_message, (event,))
|
||||
|
||||
self.bot = bot
|
||||
|
||||
pkg.utils.context.set_qqbot_manager(self)
|
||||
|
||||
def send(self, event, msg, check_quote=True):
|
||||
asyncio.run(
|
||||
self.bot.send(event, msg, quote=True if hasattr(config,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# 此模块提供了消息处理的具体逻辑的接口
|
||||
import asyncio
|
||||
import datetime
|
||||
import threading
|
||||
|
||||
import pkg.qqbot.manager as manager
|
||||
from func_timeout import func_set_timeout
|
||||
@@ -8,7 +9,6 @@ import logging
|
||||
import openai
|
||||
|
||||
from mirai import Image, MessageChain
|
||||
from mirai.models.message import Quote
|
||||
|
||||
import config
|
||||
|
||||
@@ -162,8 +162,9 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
|
||||
reply.append(" ".join(params))
|
||||
elif cmd == 'reload' and launcher_type == 'person' and launcher_id == config.admin_qq:
|
||||
try:
|
||||
pkg.utils.reloader.reload_all()
|
||||
reply = ["[bot]已重新加载所有模块"]
|
||||
# pkg.utils.reloader.reload_all()
|
||||
threading.Thread(target=pkg.utils.reloader.reload_all, daemon=True).start()
|
||||
# reply = ["[bot]已重新加载所有模块"]
|
||||
except Exception as e:
|
||||
logging.error("reload failed:{}".format(e))
|
||||
reply = ["[bot]重载失败:{}".format(e)]
|
||||
|
||||
@@ -3,7 +3,8 @@ context = {
|
||||
'database.manager.DatabaseManager': None,
|
||||
'openai.manager.OpenAIInteract': None,
|
||||
'qqbot.manager.QQBotManager': None,
|
||||
}
|
||||
},
|
||||
'logger_handler': None,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
|
||||
import colorlog
|
||||
|
||||
import pkg
|
||||
import importlib
|
||||
import pkgutil
|
||||
import pkg.utils.context
|
||||
from main import log_colors_config
|
||||
|
||||
|
||||
def walk(module, prefix=''):
|
||||
@@ -16,7 +21,20 @@ def walk(module, prefix=''):
|
||||
|
||||
|
||||
def reload_all():
|
||||
# 执行关闭流程
|
||||
logging.info("执行程序关闭流程")
|
||||
import main
|
||||
main.stop()
|
||||
import pkg
|
||||
|
||||
context = pkg.utils.context.context
|
||||
walk(pkg)
|
||||
importlib.reload(__import__('config'))
|
||||
importlib.reload(__import__('main'))
|
||||
pkg.utils.context.context = context
|
||||
|
||||
# 执行启动流程
|
||||
logging.info("执行程序启动流程")
|
||||
main.main()
|
||||
|
||||
logging.info('程序启动完成')
|
||||
|
||||
Reference in New Issue
Block a user