mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-26 23:44:19 +00:00
feat: 重载后恢复诸个单例对象
This commit is contained in:
+6
-50
@@ -4,6 +4,7 @@ import logging
|
||||
|
||||
import pkg.database.manager
|
||||
import pkg.qqbot.manager
|
||||
import pkg.utils.context
|
||||
import config
|
||||
|
||||
|
||||
@@ -62,43 +63,6 @@ class KeysManager:
|
||||
def add(self, key_name, key):
|
||||
self.api_key[key_name] = key
|
||||
|
||||
# def get_usage(self, api_key):
|
||||
# md5 = hashlib.md5(api_key.encode('utf-8')).hexdigest()
|
||||
# if md5 not in self.usage:
|
||||
# self.usage[md5] = 0
|
||||
# return self.usage[md5]
|
||||
|
||||
# 报告使用
|
||||
# 返回是否需要将openai的api-key切换
|
||||
# def report_usage(self, new_content: str) -> bool:
|
||||
# md5 = hashlib.md5(self.using_key.encode('utf-8')).hexdigest()
|
||||
# if md5 not in self.usage:
|
||||
# self.usage[md5] = 0
|
||||
#
|
||||
# # 经测算得出的理论与实际的偏差比例
|
||||
# salt_rate = 0.91
|
||||
#
|
||||
# self.usage[md5] += ( (len(new_content.encode('utf-8')) - len(new_content)) / 2 + len(new_content) )*salt_rate
|
||||
#
|
||||
# self.usage[md5] = int(self.usage[md5])
|
||||
#
|
||||
# if self.usage[md5] >= self.api_key_usage_threshold:
|
||||
# switch_result, key_name = self.auto_switch()
|
||||
#
|
||||
# # 检查是否切换到新的
|
||||
# if switch_result:
|
||||
# if key_name not in self.alerted:
|
||||
# # 通知管理员
|
||||
# pkg.qqbot.manager.get_inst().notify_admin("api-key已切换到:" + key_name)
|
||||
# self.alerted.append(key_name)
|
||||
# return True
|
||||
# else:
|
||||
# if key_name not in self.alerted:
|
||||
# # 通知管理员
|
||||
# pkg.qqbot.manager.get_inst().notify_admin("api-key已用完,无未使用的api-key可供切换")
|
||||
# self.alerted.append(key_name)
|
||||
# return False
|
||||
|
||||
# 设置当前使用的api-key使用量超限
|
||||
# 这是在尝试调用api时发生超限异常时调用的
|
||||
def set_current_exceeded(self):
|
||||
@@ -107,14 +71,6 @@ class KeysManager:
|
||||
self.fee[md5] = self.api_key_fee_threshold
|
||||
self.dump_fee()
|
||||
|
||||
# def dump_usage(self):
|
||||
# pkg.database.manager.get_inst().dump_api_key_usage(api_keys=self.api_key, usage=self.usage)
|
||||
|
||||
# def load_usage(self):
|
||||
# self.usage = pkg.database.manager.get_inst().load_api_key_usage()
|
||||
# logging.debug("load usage:" + str(self.usage))
|
||||
# print("load usage:" + str(self.usage))
|
||||
|
||||
def get_fee(self, api_key):
|
||||
md5 = hashlib.md5(api_key.encode('utf-8')).hexdigest()
|
||||
if md5 not in self.fee:
|
||||
@@ -135,19 +91,19 @@ class KeysManager:
|
||||
if switch_result:
|
||||
if key_name not in self.alerted:
|
||||
# 通知管理员
|
||||
pkg.qqbot.manager.get_inst().notify_admin("api-key已切换到:" + key_name)
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin("api-key已切换到:" + key_name)
|
||||
self.alerted.append(key_name)
|
||||
return True
|
||||
else:
|
||||
if key_name not in self.alerted:
|
||||
# 通知管理员
|
||||
pkg.qqbot.manager.get_inst().notify_admin("api-key已用完,无未使用的api-key可供切换")
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin("api-key已用完,无未使用的api-key可供切换")
|
||||
self.alerted.append(key_name)
|
||||
return False
|
||||
|
||||
def dump_fee(self):
|
||||
pkg.database.manager.get_inst().dump_api_key_fee(api_keys=self.api_key, fee=self.fee)
|
||||
pkg.utils.context.get_database_manager().dump_api_key_fee(api_keys=self.api_key, fee=self.fee)
|
||||
|
||||
def load_fee(self):
|
||||
self.fee = pkg.database.manager.get_inst().load_api_key_fee()
|
||||
logging.info("load fee:" + str(self.fee))
|
||||
self.fee = pkg.utils.context.get_database_manager().load_api_key_fee()
|
||||
logging.info("load fee:" + str(self.fee))
|
||||
|
||||
@@ -6,8 +6,7 @@ import config
|
||||
|
||||
import pkg.openai.keymgr
|
||||
import pkg.openai.pricing as pricing
|
||||
|
||||
inst = None
|
||||
import pkg.utils.context
|
||||
|
||||
|
||||
# 为其他模块提供与OpenAI交互的接口
|
||||
@@ -27,11 +26,11 @@ class OpenAIInteract:
|
||||
|
||||
openai.api_key = self.key_mgr.get_using_key()
|
||||
|
||||
global inst
|
||||
inst = self
|
||||
pkg.utils.context.set_openai_manager(self)
|
||||
|
||||
# 请求OpenAI Completion
|
||||
def request_completion(self, prompt, stop):
|
||||
print("request")
|
||||
response = openai.Completion.create(
|
||||
prompt=prompt,
|
||||
stop=stop,
|
||||
@@ -41,7 +40,6 @@ class OpenAIInteract:
|
||||
|
||||
switched = self.key_mgr.report_fee(pricing.language_base_price(config.completion_api_params['model'],
|
||||
prompt + response['choices'][0]['text']))
|
||||
|
||||
if switched:
|
||||
openai.api_key = self.key_mgr.get_using_key()
|
||||
|
||||
@@ -64,7 +62,3 @@ class OpenAIInteract:
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def get_inst() -> OpenAIInteract:
|
||||
global inst
|
||||
return inst
|
||||
|
||||
+16
-14
@@ -5,6 +5,7 @@ import time
|
||||
import config
|
||||
import pkg.openai.manager
|
||||
import pkg.database.manager
|
||||
import pkg.utils.context
|
||||
|
||||
# 运行时保存的所有session
|
||||
sessions = {}
|
||||
@@ -19,7 +20,7 @@ class SessionOfflineStatus:
|
||||
def load_sessions():
|
||||
global sessions
|
||||
|
||||
db_inst = pkg.database.manager.get_inst()
|
||||
db_inst = pkg.utils.context.get_database_manager()
|
||||
|
||||
session_data = db_inst.load_valid_sessions()
|
||||
|
||||
@@ -147,10 +148,11 @@ class Session:
|
||||
max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024
|
||||
|
||||
# 向API请求补全
|
||||
response = pkg.openai.manager.get_inst().request_completion(self.cut_out(self.prompt + self.user_name + ':' +
|
||||
text + '\n' + self.bot_name + ':',
|
||||
max_rounds, max_length),
|
||||
self.user_name + ':')
|
||||
response = pkg.utils.context.get_openai_manager().request_completion(
|
||||
self.cut_out(self.prompt + self.user_name + ':' +
|
||||
text + '\n' + self.bot_name + ':',
|
||||
max_rounds, max_length),
|
||||
self.user_name + ':')
|
||||
|
||||
self.prompt += self.user_name + ':' + text + '\n' + self.bot_name + ':'
|
||||
# print(response)
|
||||
@@ -202,7 +204,7 @@ class Session:
|
||||
if self.prompt == get_default_prompt():
|
||||
return
|
||||
|
||||
db_inst = pkg.database.manager.get_inst()
|
||||
db_inst = pkg.utils.context.get_database_manager()
|
||||
|
||||
name_spt = self.name.split('_')
|
||||
|
||||
@@ -217,10 +219,10 @@ class Session:
|
||||
if self.prompt != get_default_prompt():
|
||||
self.persistence()
|
||||
if explicit:
|
||||
pkg.database.manager.get_inst().explicit_close_session(self.name, self.create_timestamp)
|
||||
pkg.utils.context.get_database_manager().explicit_close_session(self.name, self.create_timestamp)
|
||||
|
||||
if expired:
|
||||
pkg.database.manager.get_inst().set_session_expired(self.name, self.create_timestamp)
|
||||
pkg.utils.context.get_database_manager().set_session_expired(self.name, self.create_timestamp)
|
||||
self.prompt = get_default_prompt()
|
||||
self.create_timestamp = int(time.time())
|
||||
self.last_interact_timestamp = int(time.time())
|
||||
@@ -233,11 +235,11 @@ class Session:
|
||||
|
||||
# 将本session的数据库状态设置为on_going
|
||||
def set_ongoing(self):
|
||||
pkg.database.manager.get_inst().set_session_ongoing(self.name, self.create_timestamp)
|
||||
pkg.utils.context.get_database_manager().set_session_ongoing(self.name, self.create_timestamp)
|
||||
|
||||
# 切换到上一个session
|
||||
def last_session(self):
|
||||
last_one = pkg.database.manager.get_inst().last_session(self.name, self.last_interact_timestamp)
|
||||
last_one = pkg.utils.context.get_database_manager().last_session(self.name, self.last_interact_timestamp)
|
||||
if last_one is None:
|
||||
return None
|
||||
else:
|
||||
@@ -252,7 +254,7 @@ class Session:
|
||||
|
||||
# 切换到下一个session
|
||||
def next_session(self):
|
||||
next_one = pkg.database.manager.get_inst().next_session(self.name, self.last_interact_timestamp)
|
||||
next_one = pkg.utils.context.get_database_manager().next_session(self.name, self.last_interact_timestamp)
|
||||
if next_one is None:
|
||||
return None
|
||||
else:
|
||||
@@ -266,8 +268,8 @@ class Session:
|
||||
return self
|
||||
|
||||
def list_history(self, capacity: int = 10, page: int = 0):
|
||||
return pkg.database.manager.get_inst().list_history(self.name, capacity, page,
|
||||
get_default_prompt())
|
||||
return pkg.utils.context.get_database_manager().list_history(self.name, capacity, page,
|
||||
get_default_prompt())
|
||||
|
||||
def draw_image(self, prompt: str):
|
||||
return pkg.openai.manager.get_inst().request_image(prompt)
|
||||
return pkg.utils.context.get_openai_manager().request_image(prompt)
|
||||
|
||||
Reference in New Issue
Block a user