From 82e3ef6497e48644971df6973f8e9a0913a360d7 Mon Sep 17 00:00:00 2001 From: Rock Chin <1010553892@qq.com> Date: Sun, 1 Jan 2023 23:18:32 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E9=87=8D=E8=BD=BD=E5=90=8E=E6=81=A2?= =?UTF-8?q?=E5=A4=8D=E8=AF=B8=E4=B8=AA=E5=8D=95=E4=BE=8B=E5=AF=B9=E8=B1=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 3 ++- pkg/database/manager.py | 9 ++----- pkg/openai/keymgr.py | 56 +++++------------------------------------ pkg/openai/manager.py | 12 +++------ pkg/openai/session.py | 30 +++++++++++----------- pkg/qqbot/manager.py | 10 +++----- pkg/qqbot/process.py | 25 +++++++++--------- pkg/utils/context.py | 31 +++++++++++++++++++++++ pkg/utils/reloader.py | 3 +++ 9 files changed, 79 insertions(+), 100 deletions(-) create mode 100644 pkg/utils/context.py diff --git a/main.py b/main.py index 23150f1d..55bb6166 100644 --- a/main.py +++ b/main.py @@ -53,6 +53,7 @@ def main(): import pkg.database.manager import pkg.openai.session import pkg.qqbot.manager + import pkg.utils.context # 主启动流程 database = pkg.database.manager.DatabaseManager() @@ -78,7 +79,7 @@ def main(): time.sleep(86400) except KeyboardInterrupt: try: - pkg.openai.manager.get_inst().key_mgr.dump_fee() + pkg.utils.context.get_openai_manager().key_mgr.dump_fee() for session in pkg.openai.session.sessions: logging.info('持久化session: %s', session) pkg.openai.session.sessions[session].persistence() diff --git a/pkg/database/manager.py b/pkg/database/manager.py index cbe243e9..5c9b700b 100644 --- a/pkg/database/manager.py +++ b/pkg/database/manager.py @@ -6,8 +6,7 @@ from sqlite3 import Cursor import sqlite3 import config - -inst = None +import pkg.utils.context # 数据库管理 @@ -20,8 +19,7 @@ class DatabaseManager: self.reconnect() - global inst - inst = self + pkg.utils.context.set_database_manager(self) # 连接到数据库文件 def reconnect(self): @@ -312,6 +310,3 @@ class DatabaseManager: fee[key_md5] = fee_count return fee -def get_inst() -> DatabaseManager: - global inst - return inst diff --git a/pkg/openai/keymgr.py b/pkg/openai/keymgr.py index e8992c7e..cf9898aa 100644 --- a/pkg/openai/keymgr.py +++ b/pkg/openai/keymgr.py @@ -4,6 +4,7 @@ import logging import pkg.database.manager import pkg.qqbot.manager +import pkg.utils.context import config @@ -62,43 +63,6 @@ class KeysManager: def add(self, key_name, key): self.api_key[key_name] = key - # def get_usage(self, api_key): - # md5 = hashlib.md5(api_key.encode('utf-8')).hexdigest() - # if md5 not in self.usage: - # self.usage[md5] = 0 - # return self.usage[md5] - - # 报告使用 - # 返回是否需要将openai的api-key切换 - # def report_usage(self, new_content: str) -> bool: - # md5 = hashlib.md5(self.using_key.encode('utf-8')).hexdigest() - # if md5 not in self.usage: - # self.usage[md5] = 0 - # - # # 经测算得出的理论与实际的偏差比例 - # salt_rate = 0.91 - # - # self.usage[md5] += ( (len(new_content.encode('utf-8')) - len(new_content)) / 2 + len(new_content) )*salt_rate - # - # self.usage[md5] = int(self.usage[md5]) - # - # if self.usage[md5] >= self.api_key_usage_threshold: - # switch_result, key_name = self.auto_switch() - # - # # 检查是否切换到新的 - # if switch_result: - # if key_name not in self.alerted: - # # 通知管理员 - # pkg.qqbot.manager.get_inst().notify_admin("api-key已切换到:" + key_name) - # self.alerted.append(key_name) - # return True - # else: - # if key_name not in self.alerted: - # # 通知管理员 - # pkg.qqbot.manager.get_inst().notify_admin("api-key已用完,无未使用的api-key可供切换") - # self.alerted.append(key_name) - # return False - # 设置当前使用的api-key使用量超限 # 这是在尝试调用api时发生超限异常时调用的 def set_current_exceeded(self): @@ -107,14 +71,6 @@ class KeysManager: self.fee[md5] = self.api_key_fee_threshold self.dump_fee() - # def dump_usage(self): - # pkg.database.manager.get_inst().dump_api_key_usage(api_keys=self.api_key, usage=self.usage) - - # def load_usage(self): - # self.usage = pkg.database.manager.get_inst().load_api_key_usage() - # logging.debug("load usage:" + str(self.usage)) - # print("load usage:" + str(self.usage)) - def get_fee(self, api_key): md5 = hashlib.md5(api_key.encode('utf-8')).hexdigest() if md5 not in self.fee: @@ -135,19 +91,19 @@ class KeysManager: if switch_result: if key_name not in self.alerted: # 通知管理员 - pkg.qqbot.manager.get_inst().notify_admin("api-key已切换到:" + key_name) + pkg.utils.context.get_qqbot_manager().notify_admin("api-key已切换到:" + key_name) self.alerted.append(key_name) return True else: if key_name not in self.alerted: # 通知管理员 - pkg.qqbot.manager.get_inst().notify_admin("api-key已用完,无未使用的api-key可供切换") + pkg.utils.context.get_qqbot_manager().notify_admin("api-key已用完,无未使用的api-key可供切换") self.alerted.append(key_name) return False def dump_fee(self): - pkg.database.manager.get_inst().dump_api_key_fee(api_keys=self.api_key, fee=self.fee) + pkg.utils.context.get_database_manager().dump_api_key_fee(api_keys=self.api_key, fee=self.fee) def load_fee(self): - self.fee = pkg.database.manager.get_inst().load_api_key_fee() - logging.info("load fee:" + str(self.fee)) \ No newline at end of file + self.fee = pkg.utils.context.get_database_manager().load_api_key_fee() + logging.info("load fee:" + str(self.fee)) diff --git a/pkg/openai/manager.py b/pkg/openai/manager.py index d65f0e54..fa77fb4c 100644 --- a/pkg/openai/manager.py +++ b/pkg/openai/manager.py @@ -6,8 +6,7 @@ import config import pkg.openai.keymgr import pkg.openai.pricing as pricing - -inst = None +import pkg.utils.context # 为其他模块提供与OpenAI交互的接口 @@ -27,11 +26,11 @@ class OpenAIInteract: openai.api_key = self.key_mgr.get_using_key() - global inst - inst = self + pkg.utils.context.set_openai_manager(self) # 请求OpenAI Completion def request_completion(self, prompt, stop): + print("request") response = openai.Completion.create( prompt=prompt, stop=stop, @@ -41,7 +40,6 @@ class OpenAIInteract: switched = self.key_mgr.report_fee(pricing.language_base_price(config.completion_api_params['model'], prompt + response['choices'][0]['text'])) - if switched: openai.api_key = self.key_mgr.get_using_key() @@ -64,7 +62,3 @@ class OpenAIInteract: return response - -def get_inst() -> OpenAIInteract: - global inst - return inst diff --git a/pkg/openai/session.py b/pkg/openai/session.py index 9bc9e9e1..631d4131 100644 --- a/pkg/openai/session.py +++ b/pkg/openai/session.py @@ -5,6 +5,7 @@ import time import config import pkg.openai.manager import pkg.database.manager +import pkg.utils.context # 运行时保存的所有session sessions = {} @@ -19,7 +20,7 @@ class SessionOfflineStatus: def load_sessions(): global sessions - db_inst = pkg.database.manager.get_inst() + db_inst = pkg.utils.context.get_database_manager() session_data = db_inst.load_valid_sessions() @@ -147,10 +148,11 @@ class Session: max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024 # 向API请求补全 - response = pkg.openai.manager.get_inst().request_completion(self.cut_out(self.prompt + self.user_name + ':' + - text + '\n' + self.bot_name + ':', - max_rounds, max_length), - self.user_name + ':') + response = pkg.utils.context.get_openai_manager().request_completion( + self.cut_out(self.prompt + self.user_name + ':' + + text + '\n' + self.bot_name + ':', + max_rounds, max_length), + self.user_name + ':') self.prompt += self.user_name + ':' + text + '\n' + self.bot_name + ':' # print(response) @@ -202,7 +204,7 @@ class Session: if self.prompt == get_default_prompt(): return - db_inst = pkg.database.manager.get_inst() + db_inst = pkg.utils.context.get_database_manager() name_spt = self.name.split('_') @@ -217,10 +219,10 @@ class Session: if self.prompt != get_default_prompt(): self.persistence() if explicit: - pkg.database.manager.get_inst().explicit_close_session(self.name, self.create_timestamp) + pkg.utils.context.get_database_manager().explicit_close_session(self.name, self.create_timestamp) if expired: - pkg.database.manager.get_inst().set_session_expired(self.name, self.create_timestamp) + pkg.utils.context.get_database_manager().set_session_expired(self.name, self.create_timestamp) self.prompt = get_default_prompt() self.create_timestamp = int(time.time()) self.last_interact_timestamp = int(time.time()) @@ -233,11 +235,11 @@ class Session: # 将本session的数据库状态设置为on_going def set_ongoing(self): - pkg.database.manager.get_inst().set_session_ongoing(self.name, self.create_timestamp) + pkg.utils.context.get_database_manager().set_session_ongoing(self.name, self.create_timestamp) # 切换到上一个session def last_session(self): - last_one = pkg.database.manager.get_inst().last_session(self.name, self.last_interact_timestamp) + last_one = pkg.utils.context.get_database_manager().last_session(self.name, self.last_interact_timestamp) if last_one is None: return None else: @@ -252,7 +254,7 @@ class Session: # 切换到下一个session def next_session(self): - next_one = pkg.database.manager.get_inst().next_session(self.name, self.last_interact_timestamp) + next_one = pkg.utils.context.get_database_manager().next_session(self.name, self.last_interact_timestamp) if next_one is None: return None else: @@ -266,8 +268,8 @@ class Session: return self def list_history(self, capacity: int = 10, page: int = 0): - return pkg.database.manager.get_inst().list_history(self.name, capacity, page, - get_default_prompt()) + return pkg.utils.context.get_database_manager().list_history(self.name, capacity, page, + get_default_prompt()) def draw_image(self, prompt: str): - return pkg.openai.manager.get_inst().request_image(prompt) + return pkg.utils.context.get_openai_manager().request_image(prompt) diff --git a/pkg/qqbot/manager.py b/pkg/qqbot/manager.py index a301664f..39bb8265 100644 --- a/pkg/qqbot/manager.py +++ b/pkg/qqbot/manager.py @@ -17,8 +17,7 @@ import logging import pkg.qqbot.filter import pkg.qqbot.process as processor - -inst = None +import pkg.utils.context # 并行运行 @@ -107,8 +106,8 @@ class QQBotManager: self.bot = bot - global inst - inst = self + pkg.utils.context.set_qqbot_manager(self) + def send(self, event, msg, check_quote=True): asyncio.run( @@ -198,6 +197,3 @@ class QQBotManager: threading.Thread(target=asyncio.run, args=(send_task,)).start() -def get_inst() -> QQBotManager: - global inst - return inst diff --git a/pkg/qqbot/process.py b/pkg/qqbot/process.py index afbdba6c..596a10ad 100644 --- a/pkg/qqbot/process.py +++ b/pkg/qqbot/process.py @@ -16,6 +16,7 @@ import pkg.openai.session import pkg.openai.manager import pkg.utils.reloader import pkg.utils.updater +import pkg.utils.context processing = [] @@ -25,7 +26,7 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes sender_id: int) -> MessageChain: global processing - mgr = pkg.qqbot.manager.get_inst() + mgr = pkg.utils.context.get_qqbot_manager() reply = [] session_name = "{}_{}".format(launcher_type, launcher_id) @@ -125,22 +126,22 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes reply = [reply_str] elif cmd == 'usage': - api_keys = pkg.openai.manager.get_inst().key_mgr.api_key + api_keys = pkg.utils.context.get_openai_manager().key_mgr.api_key reply_str = "[bot]api-key使用情况:(阈值:{})\n\n".format( - pkg.openai.manager.get_inst().key_mgr.api_key_fee_threshold) + pkg.utils.context.get_openai_manager().key_mgr.api_key_fee_threshold) using_key_name = "" for api_key in api_keys: reply_str += "{}:\n - {}美元 {}%\n".format(api_key, round( - pkg.openai.manager.get_inst().key_mgr.get_fee( + pkg.utils.context.get_openai_manager().key_mgr.get_fee( api_keys[api_key]), 6), round( - pkg.openai.manager.get_inst().key_mgr.get_fee( + pkg.utils.context.get_openai_manager().key_mgr.get_fee( api_keys[ - api_key]) / pkg.openai.manager.get_inst().key_mgr.api_key_fee_threshold * 100, + api_key]) / pkg.utils.context.get_openai_manager().key_mgr.api_key_fee_threshold * 100, 3)) - if api_keys[api_key] == pkg.openai.manager.get_inst().key_mgr.using_key: + if api_keys[api_key] == pkg.utils.context.get_openai_manager().key_mgr.using_key: using_key_name = api_key reply_str += "\n当前使用:{}".format(using_key_name) @@ -191,17 +192,17 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes reply = ["[bot]err:调用API失败,请重试或联系作者,或等待修复"] except openai.error.RateLimitError as e: # 尝试切换api-key - current_tokens_amt = pkg.openai.manager.get_inst().key_mgr.get_fee( - pkg.openai.manager.get_inst().key_mgr.get_using_key()) - pkg.openai.manager.get_inst().key_mgr.set_current_exceeded() - switched, name = pkg.openai.manager.get_inst().key_mgr.auto_switch() + current_tokens_amt = pkg.utils.context.get_openai_manager().key_mgr.get_fee( + pkg.utils.context.get_openai_manager().key_mgr.get_using_key()) + pkg.utils.context.get_openai_manager().key_mgr.set_current_exceeded() + switched, name = pkg.utils.context.get_openai_manager().key_mgr.auto_switch() if not switched: mgr.notify_admin("API调用额度超限({}),请向OpenAI账户充值或在config.py中更换api_key".format( current_tokens_amt)) reply = ["[bot]err:API调用额度超额,请联系作者,或等待修复"] else: - openai.api_key = pkg.openai.manager.get_inst().key_mgr.get_using_key() + openai.api_key = pkg.utils.context.get_openai_manager().key_mgr.get_using_key() mgr.notify_admin("API调用额度超限({}),已切换到{}".format(current_tokens_amt, name)) reply = ["[bot]err:API调用额度超额,已自动切换,请重新发送消息"] except openai.error.InvalidRequestError as e: diff --git a/pkg/utils/context.py b/pkg/utils/context.py new file mode 100644 index 00000000..be2849a8 --- /dev/null +++ b/pkg/utils/context.py @@ -0,0 +1,31 @@ +context = { + 'inst': { + 'database.manager.DatabaseManager': None, + 'openai.manager.OpenAIInteract': None, + 'qqbot.manager.QQBotManager': None, + } +} + + +def set_database_manager(inst): + context['inst']['database.manager.DatabaseManager'] = inst + + +def get_database_manager(): + return context['inst']['database.manager.DatabaseManager'] + + +def set_openai_manager(inst): + context['inst']['openai.manager.OpenAIInteract'] = inst + + +def get_openai_manager(): + return context['inst']['openai.manager.OpenAIInteract'] + + +def set_qqbot_manager(inst): + context['inst']['qqbot.manager.QQBotManager'] = inst + + +def get_qqbot_manager(): + return context['inst']['qqbot.manager.QQBotManager'] \ No newline at end of file diff --git a/pkg/utils/reloader.py b/pkg/utils/reloader.py index 5cb3b16d..7b7f1a35 100644 --- a/pkg/utils/reloader.py +++ b/pkg/utils/reloader.py @@ -3,6 +3,7 @@ import logging import pkg import importlib import pkgutil +import pkg.utils.context def walk(module, prefix=''): @@ -15,5 +16,7 @@ def walk(module, prefix=''): def reload_all(): + context = pkg.utils.context.context walk(pkg) importlib.reload(__import__('config')) + pkg.utils.context.context = context