From 95ad911a6c1637f61b8a46607c15ff23c434f502 Mon Sep 17 00:00:00 2001 From: Rock Chin <1010553892@qq.com> Date: Wed, 4 Jan 2023 17:09:57 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E8=BF=90=E8=A1=8C=E6=97=B6=E5=8E=9F?= =?UTF-8?q?=E5=8A=A8=E6=80=81=E5=BC=95=E5=85=A5config=E7=9A=84=E5=9C=B0?= =?UTF-8?q?=E6=96=B9=E7=8E=B0=E5=9C=A8=E5=9D=87=E4=BD=BF=E7=94=A8=E5=88=9D?= =?UTF-8?q?=E5=A7=8B=E5=8C=96=E6=97=B6=E5=AF=BC=E5=85=A5=E7=9A=84config?= =?UTF-8?q?=E5=AF=B9=E8=B1=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 44 +++++++++++++++++++++++++++++------------ pkg/database/manager.py | 3 +-- pkg/openai/keymgr.py | 4 +++- pkg/openai/manager.py | 4 ++-- pkg/openai/session.py | 7 ++++++- pkg/qqbot/manager.py | 5 ++++- pkg/qqbot/process.py | 3 +++ pkg/utils/context.py | 9 +++++++++ 8 files changed, 59 insertions(+), 20 deletions(-) diff --git a/main.py b/main.py index e7427270..24460e1e 100644 --- a/main.py +++ b/main.py @@ -1,3 +1,4 @@ +import importlib import os import shutil import threading @@ -33,8 +34,10 @@ def init_db(): database.initialize_database() + known_exception_caught = False + def main(first_time_init=False): global known_exception_caught @@ -43,12 +46,11 @@ def main(first_time_init=False): # 导入config.py assert os.path.exists('config.py') - # 检查是否设置了管理员 - import config - if not (hasattr(config, 'admin_qq') and config.admin_qq != 0): - logging.warning("未设置管理员QQ,管理员权限指令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段") + config = importlib.import_module('config') import pkg.utils.context + pkg.utils.context.set_config(config) + if pkg.utils.context.context['logger_handler'] is not None: logging.getLogger().removeHandler(pkg.utils.context.context['logger_handler']) @@ -69,6 +71,10 @@ def main(first_time_init=False): )) logging.getLogger().addHandler(sh) + # 检查是否设置了管理员 + if not (hasattr(config, 'admin_qq') and config.admin_qq != 0): + logging.warning("未设置管理员QQ,管理员权限指令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段") + import pkg.openai.manager import pkg.database.manager import pkg.openai.session @@ -98,35 +104,42 @@ def main(first_time_init=False): qqbot.bot.run() except TypeError as e: if str(e).__contains__("argument 'debug'"): - logging.error("连接bot失败:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/82".format(e)) + logging.error( + "连接bot失败:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/82".format(e)) known_exception_caught = True elif str(e).__contains__("As of 3.10, the *loop*"): - logging.error("Websockets版本过低:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/5".format(e)) + logging.error( + "Websockets版本过低:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/5".format(e)) known_exception_caught = True except websockets.exceptions.InvalidStatus as e: - logging.error("mirai-api-http端口无法使用:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/22".format(e)) + logging.error( + "mirai-api-http端口无法使用:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/22".format( + e)) known_exception_caught = True except mirai.exceptions.NetworkError as e: logging.error("连接mirai-api-http失败:{}, 请检查是否已按照文档启动mirai".format(e)) known_exception_caught = True except Exception as e: - if str(e).__contains__("HTTP 404"): - logging.error("mirai-api-http端口无法使用:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/22".format(e)) + if str(e).__contains__("404"): + logging.error( + "mirai-api-http端口无法使用:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/22".format( + e)) known_exception_caught = True else: - logging.error("捕捉到未知异常:{}, 请前往 https://github.com/RockChinQ/issues 查找或提issue".format(e)) + logging.error( + "捕捉到未知异常:{}, 请前往 https://github.com/RockChinQ/issues 查找或提issue".format(e)) known_exception_caught = True raise e qq_bot_thread = threading.Thread(target=run_bot_wrapper, args=(), daemon=True) qq_bot_thread.start() finally: - time.sleep(10) + time.sleep(12) if first_time_init: if not known_exception_caught: - logging.info('程序启动完成,如长时间未显示 ”成功登录到账号xxxxx“ ,并且不回复消息,请查看 ' - 'https://github.com/RockChinQ/QChatGPT/issues/37') + logging.info('程序启动完成,如长时间未显示 ”成功登录到账号xxxxx“ ,并且不回复消息,请查看 ' + 'https://github.com/RockChinQ/QChatGPT/issues/37') else: sys.exit(1) else: @@ -177,10 +190,15 @@ if __name__ == '__main__': elif len(sys.argv) > 1 and sys.argv[1] == 'update': try: from dulwich import porcelain + repo = porcelain.open_repo('.') porcelain.pull(repo) except ModuleNotFoundError: print("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77") sys.exit(0) + # import pkg.utils.configmgr + # + # pkg.utils.configmgr.set_config_and_reload("quote_origin", False) + main(True) diff --git a/pkg/database/manager.py b/pkg/database/manager.py index 6c24ad20..725970b5 100644 --- a/pkg/database/manager.py +++ b/pkg/database/manager.py @@ -6,7 +6,6 @@ from sqlite3 import Cursor import sqlite3 -import config import pkg.utils.context @@ -25,7 +24,6 @@ class DatabaseManager: # 连接到数据库文件 def reconnect(self): self.conn = sqlite3.connect('database.db', check_same_thread=False) - # self.conn.isolation_level = None self.cursor = self.conn.cursor() def close(self): @@ -127,6 +125,7 @@ class DatabaseManager: # 从数据库加载还没过期的session数据 def load_valid_sessions(self) -> dict: # 从数据库中加载所有还没过期的session + config = pkg.utils.context.get_config() self.execute(""" select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` from `sessions` where `last_interact_timestamp` > {} diff --git a/pkg/openai/keymgr.py b/pkg/openai/keymgr.py index 1e8ebae2..9c1e9d2f 100644 --- a/pkg/openai/keymgr.py +++ b/pkg/openai/keymgr.py @@ -5,7 +5,6 @@ import logging import pkg.database.manager import pkg.qqbot.manager import pkg.utils.context -import config class KeysManager: @@ -34,6 +33,8 @@ class KeysManager: def __init__(self, api_key): # if hasattr(config, 'api_key_usage_threshold'): # self.api_key_usage_threshold = config.api_key_usage_threshold + + config = pkg.utils.context.get_config() if hasattr(config, 'api_key_fee_threshold'): self.api_key_fee_threshold = config.api_key_fee_threshold self.load_fee() @@ -108,6 +109,7 @@ class KeysManager: self.fee[md5] += fee + config = pkg.utils.context.get_config() if self.fee[md5] >= self.api_key_fee_threshold and \ hasattr(config, 'auto_switch_api_key') and config.auto_switch_api_key: switch_result, key_name = self.auto_switch() diff --git a/pkg/openai/manager.py b/pkg/openai/manager.py index 607b73b6..5208fce5 100644 --- a/pkg/openai/manager.py +++ b/pkg/openai/manager.py @@ -2,8 +2,6 @@ import logging import openai -import config - import pkg.openai.keymgr import pkg.openai.pricing as pricing import pkg.utils.context @@ -34,6 +32,7 @@ class OpenAIInteract: # 请求OpenAI Completion def request_completion(self, prompt, stop): + config = pkg.utils.context.get_config() response = openai.Completion.create( prompt=prompt, stop=stop, @@ -53,6 +52,7 @@ class OpenAIInteract: def request_image(self, prompt): + config = pkg.utils.context.get_config() params = config.image_api_params if hasattr(config, "image_api_params") else self.default_image_api_params response = openai.Image.create( diff --git a/pkg/openai/session.py b/pkg/openai/session.py index 631d4131..4bb5f970 100644 --- a/pkg/openai/session.py +++ b/pkg/openai/session.py @@ -2,7 +2,6 @@ import logging import threading import time -import config import pkg.openai.manager import pkg.database.manager import pkg.utils.context @@ -54,6 +53,7 @@ def dump_session(session_name: str): # 从配置文件获取会话预设信息 def get_default_prompt(): + import config user_name = config.user_name if hasattr(config, 'user_name') and config.user_name != '' else 'You' bot_name = config.bot_name if hasattr(config, 'bot_name') and config.bot_name != '' else 'Bot' return user_name + ":{}\n".format(config.default_prompt if hasattr(config, 'default_prompt') \ @@ -85,6 +85,8 @@ class Session: prompt = get_default_prompt() + import config + user_name = config.user_name if hasattr(config, 'user_name') and config.user_name != '' else 'You' bot_name = config.bot_name if hasattr(config, 'bot_name') and config.bot_name != '' else 'Bot' @@ -130,6 +132,8 @@ class Session: # 不是此session已更换,退出 if self.create_timestamp != create_timestamp or self not in sessions.values(): return + + config = pkg.utils.context.get_config() if int(time.time()) - self.last_interact_timestamp > config.session_expire_time: logging.info('session {} 已过期'.format(self.name)) self.reset(expired=True, schedule_new=False) @@ -144,6 +148,7 @@ class Session: self.last_interact_timestamp = int(time.time()) # max_rounds = config.prompt_submit_round_amount if hasattr(config, 'prompt_submit_round_amount') else 7 + config = pkg.utils.context.get_config() max_rounds = 1000 # 不再限制回合数 max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024 diff --git a/pkg/qqbot/manager.py b/pkg/qqbot/manager.py index 506fdb26..1df36739 100644 --- a/pkg/qqbot/manager.py +++ b/pkg/qqbot/manager.py @@ -7,7 +7,6 @@ import mirai.models.bus from mirai import At, GroupMessage, MessageEvent, Mirai, Plain, StrangerMessage, WebSocketAdapter, HTTPAdapter, \ FriendMessage, Image -import config import pkg.openai.session import pkg.openai.manager from func_timeout import FunctionTimedOut @@ -26,6 +25,7 @@ def go(func, args=()): # 检查消息是否符合泛响应匹配机制 def check_response_rule(text: str) -> (bool, str): + config = pkg.utils.context.get_config() if not hasattr(config, 'response_rules'): return False, '' @@ -60,6 +60,7 @@ class QQBotManager: self.timeout = timeout self.retry = retry + config = pkg.utils.context.get_config() if os.path.exists("sensitive.json") \ and config.sensitive_word_filter is not None \ and config.sensitive_word_filter: @@ -134,6 +135,7 @@ class QQBotManager: self.bot = bot def send(self, event, msg, check_quote=True): + config = pkg.utils.context.get_config() asyncio.run( self.bot.send(event, msg, quote=True if hasattr(config, "quote_origin") and config.quote_origin and check_quote else False)) @@ -216,6 +218,7 @@ class QQBotManager: # 通知系统管理员 def notify_admin(self, message: str): + config = pkg.utils.context.get_config() if hasattr(config, "admin_qq") and config.admin_qq != 0: logging.info("通知管理员:{}".format(message)) send_task = self.bot.send_friend_message(config.admin_qq, "[bot]{}".format(message)) diff --git a/pkg/qqbot/process.py b/pkg/qqbot/process.py index 0eb46282..736caaf4 100644 --- a/pkg/qqbot/process.py +++ b/pkg/qqbot/process.py @@ -9,6 +9,9 @@ import openai from mirai import Image, MessageChain +# 这里不使用动态引入config +# 因为在这里动态引入会卡死程序 +# 而此模块静态引用config与动态引入的表现一致 import config import pkg.openai.session diff --git a/pkg/utils/context.py b/pkg/utils/context.py index 449168eb..4d227f42 100644 --- a/pkg/utils/context.py +++ b/pkg/utils/context.py @@ -9,9 +9,18 @@ context = { 'qqbot.manager.QQBotManager': None, }, 'logger_handler': None, + 'config': None, } +def set_config(inst): + context['config'] = inst + + +def get_config(): + return context['config'] + + def set_database_manager(inst): context['inst']['database.manager.DatabaseManager'] = inst