feat: 运行时原动态引入config的地方现在均使用初始化时导入的config对象

This commit is contained in:
Rock Chin
2023-01-04 17:09:57 +08:00
parent b318f6d4f0
commit 95ad911a6c
8 changed files with 59 additions and 20 deletions
+1 -2
View File
@@ -6,7 +6,6 @@ from sqlite3 import Cursor
import sqlite3
import config
import pkg.utils.context
@@ -25,7 +24,6 @@ class DatabaseManager:
# 连接到数据库文件
def reconnect(self):
self.conn = sqlite3.connect('database.db', check_same_thread=False)
# self.conn.isolation_level = None
self.cursor = self.conn.cursor()
def close(self):
@@ -127,6 +125,7 @@ class DatabaseManager:
# 从数据库加载还没过期的session数据
def load_valid_sessions(self) -> dict:
# 从数据库中加载所有还没过期的session
config = pkg.utils.context.get_config()
self.execute("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
from `sessions` where `last_interact_timestamp` > {}
+3 -1
View File
@@ -5,7 +5,6 @@ import logging
import pkg.database.manager
import pkg.qqbot.manager
import pkg.utils.context
import config
class KeysManager:
@@ -34,6 +33,8 @@ class KeysManager:
def __init__(self, api_key):
# if hasattr(config, 'api_key_usage_threshold'):
# self.api_key_usage_threshold = config.api_key_usage_threshold
config = pkg.utils.context.get_config()
if hasattr(config, 'api_key_fee_threshold'):
self.api_key_fee_threshold = config.api_key_fee_threshold
self.load_fee()
@@ -108,6 +109,7 @@ class KeysManager:
self.fee[md5] += fee
config = pkg.utils.context.get_config()
if self.fee[md5] >= self.api_key_fee_threshold and \
hasattr(config, 'auto_switch_api_key') and config.auto_switch_api_key:
switch_result, key_name = self.auto_switch()
+2 -2
View File
@@ -2,8 +2,6 @@ import logging
import openai
import config
import pkg.openai.keymgr
import pkg.openai.pricing as pricing
import pkg.utils.context
@@ -34,6 +32,7 @@ class OpenAIInteract:
# 请求OpenAI Completion
def request_completion(self, prompt, stop):
config = pkg.utils.context.get_config()
response = openai.Completion.create(
prompt=prompt,
stop=stop,
@@ -53,6 +52,7 @@ class OpenAIInteract:
def request_image(self, prompt):
config = pkg.utils.context.get_config()
params = config.image_api_params if hasattr(config, "image_api_params") else self.default_image_api_params
response = openai.Image.create(
+6 -1
View File
@@ -2,7 +2,6 @@ import logging
import threading
import time
import config
import pkg.openai.manager
import pkg.database.manager
import pkg.utils.context
@@ -54,6 +53,7 @@ def dump_session(session_name: str):
# 从配置文件获取会话预设信息
def get_default_prompt():
import config
user_name = config.user_name if hasattr(config, 'user_name') and config.user_name != '' else 'You'
bot_name = config.bot_name if hasattr(config, 'bot_name') and config.bot_name != '' else 'Bot'
return user_name + ":{}\n".format(config.default_prompt if hasattr(config, 'default_prompt') \
@@ -85,6 +85,8 @@ class Session:
prompt = get_default_prompt()
import config
user_name = config.user_name if hasattr(config, 'user_name') and config.user_name != '' else 'You'
bot_name = config.bot_name if hasattr(config, 'bot_name') and config.bot_name != '' else 'Bot'
@@ -130,6 +132,8 @@ class Session:
# 不是此session已更换,退出
if self.create_timestamp != create_timestamp or self not in sessions.values():
return
config = pkg.utils.context.get_config()
if int(time.time()) - self.last_interact_timestamp > config.session_expire_time:
logging.info('session {} 已过期'.format(self.name))
self.reset(expired=True, schedule_new=False)
@@ -144,6 +148,7 @@ class Session:
self.last_interact_timestamp = int(time.time())
# max_rounds = config.prompt_submit_round_amount if hasattr(config, 'prompt_submit_round_amount') else 7
config = pkg.utils.context.get_config()
max_rounds = 1000 # 不再限制回合数
max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024
+4 -1
View File
@@ -7,7 +7,6 @@ import mirai.models.bus
from mirai import At, GroupMessage, MessageEvent, Mirai, Plain, StrangerMessage, WebSocketAdapter, HTTPAdapter, \
FriendMessage, Image
import config
import pkg.openai.session
import pkg.openai.manager
from func_timeout import FunctionTimedOut
@@ -26,6 +25,7 @@ def go(func, args=()):
# 检查消息是否符合泛响应匹配机制
def check_response_rule(text: str) -> (bool, str):
config = pkg.utils.context.get_config()
if not hasattr(config, 'response_rules'):
return False, ''
@@ -60,6 +60,7 @@ class QQBotManager:
self.timeout = timeout
self.retry = retry
config = pkg.utils.context.get_config()
if os.path.exists("sensitive.json") \
and config.sensitive_word_filter is not None \
and config.sensitive_word_filter:
@@ -134,6 +135,7 @@ class QQBotManager:
self.bot = bot
def send(self, event, msg, check_quote=True):
config = pkg.utils.context.get_config()
asyncio.run(
self.bot.send(event, msg, quote=True if hasattr(config,
"quote_origin") and config.quote_origin and check_quote else False))
@@ -216,6 +218,7 @@ class QQBotManager:
# 通知系统管理员
def notify_admin(self, message: str):
config = pkg.utils.context.get_config()
if hasattr(config, "admin_qq") and config.admin_qq != 0:
logging.info("通知管理员:{}".format(message))
send_task = self.bot.send_friend_message(config.admin_qq, "[bot]{}".format(message))
+3
View File
@@ -9,6 +9,9 @@ import openai
from mirai import Image, MessageChain
# 这里不使用动态引入config
# 因为在这里动态引入会卡死程序
# 而此模块静态引用config与动态引入的表现一致
import config
import pkg.openai.session
+9
View File
@@ -9,9 +9,18 @@ context = {
'qqbot.manager.QQBotManager': None,
},
'logger_handler': None,
'config': None,
}
def set_config(inst):
context['config'] = inst
def get_config():
return context['config']
def set_database_manager(inst):
context['inst']['database.manager.DatabaseManager'] = inst