feat: 重载后恢复诸个单例对象

This commit is contained in:
Rock Chin
2023-01-01 23:18:32 +08:00
parent 7e83ba3f77
commit 82e3ef6497
9 changed files with 79 additions and 100 deletions

View File

@@ -53,6 +53,7 @@ def main():
import pkg.database.manager
import pkg.openai.session
import pkg.qqbot.manager
import pkg.utils.context
# 主启动流程
database = pkg.database.manager.DatabaseManager()
@@ -78,7 +79,7 @@ def main():
time.sleep(86400)
except KeyboardInterrupt:
try:
pkg.openai.manager.get_inst().key_mgr.dump_fee()
pkg.utils.context.get_openai_manager().key_mgr.dump_fee()
for session in pkg.openai.session.sessions:
logging.info('持久化session: %s', session)
pkg.openai.session.sessions[session].persistence()

View File

@@ -6,8 +6,7 @@ from sqlite3 import Cursor
import sqlite3
import config
inst = None
import pkg.utils.context
# 数据库管理
@@ -20,8 +19,7 @@ class DatabaseManager:
self.reconnect()
global inst
inst = self
pkg.utils.context.set_database_manager(self)
# 连接到数据库文件
def reconnect(self):
@@ -312,6 +310,3 @@ class DatabaseManager:
fee[key_md5] = fee_count
return fee
def get_inst() -> DatabaseManager:
global inst
return inst

View File

@@ -4,6 +4,7 @@ import logging
import pkg.database.manager
import pkg.qqbot.manager
import pkg.utils.context
import config
@@ -62,43 +63,6 @@ class KeysManager:
def add(self, key_name, key):
self.api_key[key_name] = key
# def get_usage(self, api_key):
# md5 = hashlib.md5(api_key.encode('utf-8')).hexdigest()
# if md5 not in self.usage:
# self.usage[md5] = 0
# return self.usage[md5]
# 报告使用
# 返回是否需要将openai的api-key切换
# def report_usage(self, new_content: str) -> bool:
# md5 = hashlib.md5(self.using_key.encode('utf-8')).hexdigest()
# if md5 not in self.usage:
# self.usage[md5] = 0
#
# # 经测算得出的理论与实际的偏差比例
# salt_rate = 0.91
#
# self.usage[md5] += ( (len(new_content.encode('utf-8')) - len(new_content)) / 2 + len(new_content) )*salt_rate
#
# self.usage[md5] = int(self.usage[md5])
#
# if self.usage[md5] >= self.api_key_usage_threshold:
# switch_result, key_name = self.auto_switch()
#
# # 检查是否切换到新的
# if switch_result:
# if key_name not in self.alerted:
# # 通知管理员
# pkg.qqbot.manager.get_inst().notify_admin("api-key已切换到:" + key_name)
# self.alerted.append(key_name)
# return True
# else:
# if key_name not in self.alerted:
# # 通知管理员
# pkg.qqbot.manager.get_inst().notify_admin("api-key已用完无未使用的api-key可供切换")
# self.alerted.append(key_name)
# return False
# 设置当前使用的api-key使用量超限
# 这是在尝试调用api时发生超限异常时调用的
def set_current_exceeded(self):
@@ -107,14 +71,6 @@ class KeysManager:
self.fee[md5] = self.api_key_fee_threshold
self.dump_fee()
# def dump_usage(self):
# pkg.database.manager.get_inst().dump_api_key_usage(api_keys=self.api_key, usage=self.usage)
# def load_usage(self):
# self.usage = pkg.database.manager.get_inst().load_api_key_usage()
# logging.debug("load usage:" + str(self.usage))
# print("load usage:" + str(self.usage))
def get_fee(self, api_key):
md5 = hashlib.md5(api_key.encode('utf-8')).hexdigest()
if md5 not in self.fee:
@@ -135,19 +91,19 @@ class KeysManager:
if switch_result:
if key_name not in self.alerted:
# 通知管理员
pkg.qqbot.manager.get_inst().notify_admin("api-key已切换到:" + key_name)
pkg.utils.context.get_qqbot_manager().notify_admin("api-key已切换到:" + key_name)
self.alerted.append(key_name)
return True
else:
if key_name not in self.alerted:
# 通知管理员
pkg.qqbot.manager.get_inst().notify_admin("api-key已用完无未使用的api-key可供切换")
pkg.utils.context.get_qqbot_manager().notify_admin("api-key已用完无未使用的api-key可供切换")
self.alerted.append(key_name)
return False
def dump_fee(self):
pkg.database.manager.get_inst().dump_api_key_fee(api_keys=self.api_key, fee=self.fee)
pkg.utils.context.get_database_manager().dump_api_key_fee(api_keys=self.api_key, fee=self.fee)
def load_fee(self):
self.fee = pkg.database.manager.get_inst().load_api_key_fee()
logging.info("load fee:" + str(self.fee))
self.fee = pkg.utils.context.get_database_manager().load_api_key_fee()
logging.info("load fee:" + str(self.fee))

View File

@@ -6,8 +6,7 @@ import config
import pkg.openai.keymgr
import pkg.openai.pricing as pricing
inst = None
import pkg.utils.context
# 为其他模块提供与OpenAI交互的接口
@@ -27,11 +26,11 @@ class OpenAIInteract:
openai.api_key = self.key_mgr.get_using_key()
global inst
inst = self
pkg.utils.context.set_openai_manager(self)
# 请求OpenAI Completion
def request_completion(self, prompt, stop):
print("request")
response = openai.Completion.create(
prompt=prompt,
stop=stop,
@@ -41,7 +40,6 @@ class OpenAIInteract:
switched = self.key_mgr.report_fee(pricing.language_base_price(config.completion_api_params['model'],
prompt + response['choices'][0]['text']))
if switched:
openai.api_key = self.key_mgr.get_using_key()
@@ -64,7 +62,3 @@ class OpenAIInteract:
return response
def get_inst() -> OpenAIInteract:
global inst
return inst

View File

@@ -5,6 +5,7 @@ import time
import config
import pkg.openai.manager
import pkg.database.manager
import pkg.utils.context
# 运行时保存的所有session
sessions = {}
@@ -19,7 +20,7 @@ class SessionOfflineStatus:
def load_sessions():
global sessions
db_inst = pkg.database.manager.get_inst()
db_inst = pkg.utils.context.get_database_manager()
session_data = db_inst.load_valid_sessions()
@@ -147,10 +148,11 @@ class Session:
max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024
# 向API请求补全
response = pkg.openai.manager.get_inst().request_completion(self.cut_out(self.prompt + self.user_name + ':' +
text + '\n' + self.bot_name + ':',
max_rounds, max_length),
self.user_name + ':')
response = pkg.utils.context.get_openai_manager().request_completion(
self.cut_out(self.prompt + self.user_name + ':' +
text + '\n' + self.bot_name + ':',
max_rounds, max_length),
self.user_name + ':')
self.prompt += self.user_name + ':' + text + '\n' + self.bot_name + ':'
# print(response)
@@ -202,7 +204,7 @@ class Session:
if self.prompt == get_default_prompt():
return
db_inst = pkg.database.manager.get_inst()
db_inst = pkg.utils.context.get_database_manager()
name_spt = self.name.split('_')
@@ -217,10 +219,10 @@ class Session:
if self.prompt != get_default_prompt():
self.persistence()
if explicit:
pkg.database.manager.get_inst().explicit_close_session(self.name, self.create_timestamp)
pkg.utils.context.get_database_manager().explicit_close_session(self.name, self.create_timestamp)
if expired:
pkg.database.manager.get_inst().set_session_expired(self.name, self.create_timestamp)
pkg.utils.context.get_database_manager().set_session_expired(self.name, self.create_timestamp)
self.prompt = get_default_prompt()
self.create_timestamp = int(time.time())
self.last_interact_timestamp = int(time.time())
@@ -233,11 +235,11 @@ class Session:
# 将本session的数据库状态设置为on_going
def set_ongoing(self):
pkg.database.manager.get_inst().set_session_ongoing(self.name, self.create_timestamp)
pkg.utils.context.get_database_manager().set_session_ongoing(self.name, self.create_timestamp)
# 切换到上一个session
def last_session(self):
last_one = pkg.database.manager.get_inst().last_session(self.name, self.last_interact_timestamp)
last_one = pkg.utils.context.get_database_manager().last_session(self.name, self.last_interact_timestamp)
if last_one is None:
return None
else:
@@ -252,7 +254,7 @@ class Session:
# 切换到下一个session
def next_session(self):
next_one = pkg.database.manager.get_inst().next_session(self.name, self.last_interact_timestamp)
next_one = pkg.utils.context.get_database_manager().next_session(self.name, self.last_interact_timestamp)
if next_one is None:
return None
else:
@@ -266,8 +268,8 @@ class Session:
return self
def list_history(self, capacity: int = 10, page: int = 0):
return pkg.database.manager.get_inst().list_history(self.name, capacity, page,
get_default_prompt())
return pkg.utils.context.get_database_manager().list_history(self.name, capacity, page,
get_default_prompt())
def draw_image(self, prompt: str):
return pkg.openai.manager.get_inst().request_image(prompt)
return pkg.utils.context.get_openai_manager().request_image(prompt)

View File

@@ -17,8 +17,7 @@ import logging
import pkg.qqbot.filter
import pkg.qqbot.process as processor
inst = None
import pkg.utils.context
# 并行运行
@@ -107,8 +106,8 @@ class QQBotManager:
self.bot = bot
global inst
inst = self
pkg.utils.context.set_qqbot_manager(self)
def send(self, event, msg, check_quote=True):
asyncio.run(
@@ -198,6 +197,3 @@ class QQBotManager:
threading.Thread(target=asyncio.run, args=(send_task,)).start()
def get_inst() -> QQBotManager:
global inst
return inst

View File

@@ -16,6 +16,7 @@ import pkg.openai.session
import pkg.openai.manager
import pkg.utils.reloader
import pkg.utils.updater
import pkg.utils.context
processing = []
@@ -25,7 +26,7 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
sender_id: int) -> MessageChain:
global processing
mgr = pkg.qqbot.manager.get_inst()
mgr = pkg.utils.context.get_qqbot_manager()
reply = []
session_name = "{}_{}".format(launcher_type, launcher_id)
@@ -125,22 +126,22 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
reply = [reply_str]
elif cmd == 'usage':
api_keys = pkg.openai.manager.get_inst().key_mgr.api_key
api_keys = pkg.utils.context.get_openai_manager().key_mgr.api_key
reply_str = "[bot]api-key使用情况:(阈值:{})\n\n".format(
pkg.openai.manager.get_inst().key_mgr.api_key_fee_threshold)
pkg.utils.context.get_openai_manager().key_mgr.api_key_fee_threshold)
using_key_name = ""
for api_key in api_keys:
reply_str += "{}:\n - {}美元 {}%\n".format(api_key,
round(
pkg.openai.manager.get_inst().key_mgr.get_fee(
pkg.utils.context.get_openai_manager().key_mgr.get_fee(
api_keys[api_key]), 6),
round(
pkg.openai.manager.get_inst().key_mgr.get_fee(
pkg.utils.context.get_openai_manager().key_mgr.get_fee(
api_keys[
api_key]) / pkg.openai.manager.get_inst().key_mgr.api_key_fee_threshold * 100,
api_key]) / pkg.utils.context.get_openai_manager().key_mgr.api_key_fee_threshold * 100,
3))
if api_keys[api_key] == pkg.openai.manager.get_inst().key_mgr.using_key:
if api_keys[api_key] == pkg.utils.context.get_openai_manager().key_mgr.using_key:
using_key_name = api_key
reply_str += "\n当前使用:{}".format(using_key_name)
@@ -191,17 +192,17 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
reply = ["[bot]err:调用API失败请重试或联系作者或等待修复"]
except openai.error.RateLimitError as e:
# 尝试切换api-key
current_tokens_amt = pkg.openai.manager.get_inst().key_mgr.get_fee(
pkg.openai.manager.get_inst().key_mgr.get_using_key())
pkg.openai.manager.get_inst().key_mgr.set_current_exceeded()
switched, name = pkg.openai.manager.get_inst().key_mgr.auto_switch()
current_tokens_amt = pkg.utils.context.get_openai_manager().key_mgr.get_fee(
pkg.utils.context.get_openai_manager().key_mgr.get_using_key())
pkg.utils.context.get_openai_manager().key_mgr.set_current_exceeded()
switched, name = pkg.utils.context.get_openai_manager().key_mgr.auto_switch()
if not switched:
mgr.notify_admin("API调用额度超限({}),请向OpenAI账户充值或在config.py中更换api_key".format(
current_tokens_amt))
reply = ["[bot]err:API调用额度超额请联系作者或等待修复"]
else:
openai.api_key = pkg.openai.manager.get_inst().key_mgr.get_using_key()
openai.api_key = pkg.utils.context.get_openai_manager().key_mgr.get_using_key()
mgr.notify_admin("API调用额度超限({}),已切换到{}".format(current_tokens_amt, name))
reply = ["[bot]err:API调用额度超额已自动切换请重新发送消息"]
except openai.error.InvalidRequestError as e:

31
pkg/utils/context.py Normal file
View File

@@ -0,0 +1,31 @@
context = {
'inst': {
'database.manager.DatabaseManager': None,
'openai.manager.OpenAIInteract': None,
'qqbot.manager.QQBotManager': None,
}
}
def set_database_manager(inst):
context['inst']['database.manager.DatabaseManager'] = inst
def get_database_manager():
return context['inst']['database.manager.DatabaseManager']
def set_openai_manager(inst):
context['inst']['openai.manager.OpenAIInteract'] = inst
def get_openai_manager():
return context['inst']['openai.manager.OpenAIInteract']
def set_qqbot_manager(inst):
context['inst']['qqbot.manager.QQBotManager'] = inst
def get_qqbot_manager():
return context['inst']['qqbot.manager.QQBotManager']

View File

@@ -3,6 +3,7 @@ import logging
import pkg
import importlib
import pkgutil
import pkg.utils.context
def walk(module, prefix=''):
@@ -15,5 +16,7 @@ def walk(module, prefix=''):
def reload_all():
context = pkg.utils.context.context
walk(pkg)
importlib.reload(__import__('config'))
pkg.utils.context.context = context