refactor: 适配配置管理器读取方式

This commit is contained in:
RockChinQ
2023-11-26 23:58:06 +08:00
parent 549a7eff7f
commit 3e17bbb90f
20 changed files with 147 additions and 112 deletions

View File

@@ -218,6 +218,10 @@ async def start_process(first_time_init=False):
except Exception as e: except Exception as e:
print("更新openai库失败:{}, 请忽略或自行更新".format(e)) print("更新openai库失败:{}, 请忽略或自行更新".format(e))
# 初始化文字转图片
from pkg.utils import text2img
text2img.initialize()
known_exception_caught = False known_exception_caught = False
try: try:
try: try:

View File

@@ -47,10 +47,10 @@ class DataGatherer:
def thread_func(): def thread_func():
try: try:
config = context.get_config() config = context.get_config_manager().data
if not config.report_usage: if not config['report_usage']:
return return
res = requests.get("http://reports.rockchin.top:18989/usage?service_name=qchatgpt.{}&version={}&count={}&msg_source={}".format(subservice_name, self.version_str, count, config.msg_source_adapter)) res = requests.get("http://reports.rockchin.top:18989/usage?service_name=qchatgpt.{}&version={}&count={}&msg_source={}".format(subservice_name, self.version_str, count, config['msg_source_adapter']))
if res.status_code != 200 or res.text != "ok": if res.status_code != 200 or res.text != "ok":
logging.warning("report to server failed, status_code: {}, text: {}".format(res.status_code, res.text)) logging.warning("report to server failed, status_code: {}, text: {}".format(res.status_code, res.text))
except: except:

View File

@@ -144,11 +144,11 @@ class DatabaseManager:
# 从数据库加载还没过期的session数据 # 从数据库加载还没过期的session数据
def load_valid_sessions(self) -> dict: def load_valid_sessions(self) -> dict:
# 从数据库中加载所有还没过期的session # 从数据库中加载所有还没过期的session
config = context.get_config() config = context.get_config_manager().data
self.__execute__(""" self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts` select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
from `sessions` where `last_interact_timestamp` > {} from `sessions` where `last_interact_timestamp` > {}
""".format(int(time.time()) - config.session_expire_time)) """.format(int(time.time()) - config['session_expire_time']))
results = self.cursor.fetchall() results = self.cursor.fetchall()
sessions = {} sessions = {}
for result in results: for result in results:

View File

@@ -3,6 +3,8 @@ import logging
import openai import openai
from ...utils import context
class RequestBase: class RequestBase:
@@ -14,7 +16,6 @@ class RequestBase:
raise NotImplementedError raise NotImplementedError
def _next_key(self): def _next_key(self):
import pkg.utils.context as context
switched, name = context.get_openai_manager().key_mgr.auto_switch() switched, name = context.get_openai_manager().key_mgr.auto_switch()
logging.debug("切换api-key: switched={}, name={}".format(switched, name)) logging.debug("切换api-key: switched={}, name={}".format(switched, name))
self.client.api_key = context.get_openai_manager().key_mgr.get_using_key() self.client.api_key = context.get_openai_manager().key_mgr.get_using_key()
@@ -22,12 +23,12 @@ class RequestBase:
def _req(self, **kwargs): def _req(self, **kwargs):
"""处理代理问题""" """处理代理问题"""
logging.debug("请求接口参数: %s", str(kwargs)) logging.debug("请求接口参数: %s", str(kwargs))
import config config = context.get_config_manager().data
ret = self.req_func(**kwargs) ret = self.req_func(**kwargs)
logging.debug("接口请求返回:%s", str(ret)) logging.debug("接口请求返回:%s", str(ret))
if config.switch_strategy == 'active': if config['switch_strategy'] == 'active':
self._next_key() self._next_key()
return ret return ret

View File

@@ -1,9 +1,10 @@
# 多情景预设值管理 # 多情景预设值管理
import json import json
import logging import logging
import config
import os import os
from ..utils import context
# __current__ = "default" # __current__ = "default"
# """当前默认使用的情景预设的名称 # """当前默认使用的情景预设的名称
@@ -62,22 +63,24 @@ class NormalScenarioMode(ScenarioMode):
"""普通情景预设模式""" """普通情景预设模式"""
def __init__(self): def __init__(self):
config = context.get_config_manager().data
# 加载config中的default_prompt值 # 加载config中的default_prompt值
if type(config.default_prompt) == str: if type(config['default_prompt']) == str:
self.using_prompt_name = "default" self.using_prompt_name = "default"
self.prompts = {"default": [ self.prompts = {"default": [
{ {
"role": "system", "role": "system",
"content": config.default_prompt "content": config['default_prompt']
} }
]} ]}
elif type(config.default_prompt) == dict: elif type(config['default_prompt']) == dict:
for key in config.default_prompt: for key in config['default_prompt']:
self.prompts[key] = [ self.prompts[key] = [
{ {
"role": "system", "role": "system",
"content": config.default_prompt[key] "content": config['default_prompt'][key]
} }
] ]
@@ -123,9 +126,9 @@ def register_all():
def mode_inst() -> ScenarioMode: def mode_inst() -> ScenarioMode:
"""获取指定名称的情景预设模式对象""" """获取指定名称的情景预设模式对象"""
import config config = context.get_config_manager().data
if config.preset_mode == "default": if config['preset_mode'] == "default":
config.preset_mode = "normal" config['preset_mode'] = "normal"
return scenario_mode_mapping[config.preset_mode] return scenario_mode_mapping[config['preset_mode']]

View File

@@ -43,13 +43,13 @@ class OpenAIInteract:
"""请求补全接口回复= """请求补全接口回复=
""" """
# 选择接口请求类 # 选择接口请求类
config = context.get_config() config = context.get_config_manager().data
request: api_model.RequestBase request: api_model.RequestBase
model: str = config.completion_api_params['model'] model: str = config['completion_api_params']['model']
cp_parmas = config.completion_api_params.copy() cp_parmas = config['completion_api_params'].copy()
del cp_parmas['model'] del cp_parmas['model']
request = modelmgr.select_request_cls(self.client, model, messages, cp_parmas) request = modelmgr.select_request_cls(self.client, model, messages, cp_parmas)
@@ -74,8 +74,8 @@ class OpenAIInteract:
Returns: Returns:
dict: 响应 dict: 响应
""" """
config = context.get_config() config = context.get_config_manager().data
params = config.image_api_params params = config['image_api_params']
response = openai.Image.create( response = openai.Image.create(
prompt=prompt, prompt=prompt,

View File

@@ -36,11 +36,11 @@ def reset_session_prompt(session_name, prompt):
f.write(prompt) f.write(prompt)
f.close() f.close()
# 生成新数据 # 生成新数据
config = context.get_config() config = context.get_config_manager().data
prompt = [ prompt = [
{ {
'role': 'system', 'role': 'system',
'content': config.default_prompt['default'] if type(config.default_prompt) == dict else config.default_prompt 'content': config['default_prompt']['default'] if type(config['default_prompt']) == dict else config['default_prompt']
} }
] ]
# 警告 # 警告
@@ -170,15 +170,15 @@ class Session:
if self.create_timestamp != create_timestamp or self not in sessions.values(): if self.create_timestamp != create_timestamp or self not in sessions.values():
return return
config = context.get_config() config = context.get_config_manager().data
if int(time.time()) - self.last_interact_timestamp > config.session_expire_time: if int(time.time()) - self.last_interact_timestamp > config['session_expire_time']:
logging.info('session {} 已过期'.format(self.name)) logging.info('session {} 已过期'.format(self.name))
# 触发插件事件 # 触发插件事件
args = { args = {
'session_name': self.name, 'session_name': self.name,
'session': self, 'session': self,
'session_expire_time': config.session_expire_time 'session_expire_time': config['session_expire_time']
} }
event = plugin_host.emit(plugin_models.SessionExpired, **args) event = plugin_host.emit(plugin_models.SessionExpired, **args)
if event.is_prevented_default(): if event.is_prevented_default():
@@ -216,8 +216,8 @@ class Session:
if event.is_prevented_default(): if event.is_prevented_default():
return None, None, None return None, None, None
config = context.get_config() config = context.get_config_manager().data
max_length = config.prompt_submit_length max_length = config['prompt_submit_length']
local_default_prompt = self.default_prompt.copy() local_default_prompt = self.default_prompt.copy()
local_prompt = self.prompt.copy() local_prompt = self.prompt.copy()
@@ -254,7 +254,7 @@ class Session:
funcs = [] funcs = []
trace_func_calls = config.trace_function_calls trace_func_calls = config['trace_function_calls']
botmgr = context.get_qqbot_manager() botmgr = context.get_qqbot_manager()
session_name_spt: list[str] = self.name.split("_") session_name_spt: list[str] = self.name.split("_")
@@ -381,7 +381,7 @@ class Session:
# 包装目前的对话回合内容 # 包装目前的对话回合内容
changable_prompts = [] changable_prompts = []
use_model = context.get_config().completion_api_params['model'] use_model = context.get_config_manager().data['completion_api_params']['model']
ptr = len(prompt) - 1 ptr = len(prompt) - 1

View File

@@ -9,7 +9,7 @@ from mirai.models.message import ForwardMessageNode
from mirai.models.base import MiraiBaseModel from mirai.models.base import MiraiBaseModel
from ..utils import text2img from ..utils import text2img
import config from ..utils import context
class ForwardMessageDiaplay(MiraiBaseModel): class ForwardMessageDiaplay(MiraiBaseModel):
@@ -64,13 +64,16 @@ def text_to_image(text: str) -> MessageComponent:
def check_text(text: str) -> list: def check_text(text: str) -> list:
"""检查文本是否为长消息,并转换成该使用的消息链组件""" """检查文本是否为长消息,并转换成该使用的消息链组件"""
if len(text) > config.blob_message_threshold:
config = context.get_config_manager().data
if len(text) > config['blob_message_threshold']:
# logging.info("长消息: {}".format(text)) # logging.info("长消息: {}".format(text))
if config.blob_message_strategy == 'image': if config['blob_message_strategy'] == 'image':
# 转换成图片 # 转换成图片
return [text_to_image(text)] return [text_to_image(text)]
elif config.blob_message_strategy == 'forward': elif config['blob_message_strategy'] == 'forward':
# 包装转发消息 # 包装转发消息
display = ForwardMessageDiaplay( display = ForwardMessageDiaplay(
@@ -82,7 +85,7 @@ def check_text(text: str) -> list:
) )
node = ForwardMessageNode( node = ForwardMessageNode(
sender_id=config.mirai_http_api_config['qq'], sender_id=config['mirai_http_api_config']['qq'],
sender_name='bot', sender_name='bot',
message_chain=MessageChain([text]) message_chain=MessageChain([text])
) )

View File

@@ -3,7 +3,7 @@ import logging
import mirai import mirai
from .. import aamgr from .. import aamgr
import config from ....utils import context
@aamgr.AbstractCommandNode.register( @aamgr.AbstractCommandNode.register(
@@ -30,8 +30,8 @@ class DrawCommand(aamgr.AbstractCommandNode):
logging.debug("draw_image result:{}".format(res)) logging.debug("draw_image result:{}".format(res))
reply = [mirai.Image(url=res['data'][0]['url'])] reply = [mirai.Image(url=res['data'][0]['url'])]
if not (hasattr(config, 'include_image_description') config = context.get_config_manager().data
and not config.include_image_description): if config['include_image_description']:
reply.append(" ".join(ctx.params)) reply.append(" ".join(ctx.params))
return True, reply return True, reply

View File

@@ -1,4 +1,6 @@
from .. import aamgr from .. import aamgr
from ....utils import context
@aamgr.AbstractCommandNode.register( @aamgr.AbstractCommandNode.register(
parent=None, parent=None,
@@ -15,12 +17,13 @@ class DefaultCommand(aamgr.AbstractCommandNode):
session_name = ctx.session_name session_name = ctx.session_name
params = ctx.params params = ctx.params
reply = [] reply = []
import config
config = context.get_config_manager().data
if len(params) == 0: if len(params) == 0:
# 输出目前所有情景预设 # 输出目前所有情景预设
import pkg.openai.dprompt as dprompt import pkg.openai.dprompt as dprompt
reply_str = "[bot]当前所有情景预设({}模式):\n\n".format(config.preset_mode) reply_str = "[bot]当前所有情景预设({}模式):\n\n".format(config['preset_mode'])
prompts = dprompt.mode_inst().list() prompts = dprompt.mode_inst().list()

View File

@@ -4,7 +4,7 @@ import logging
from ..qqbot.cmds import aamgr as cmdmgr from ..qqbot.cmds import aamgr as cmdmgr
def process_command(session_name: str, text_message: str, mgr, config, def process_command(session_name: str, text_message: str, mgr, config: dict,
launcher_type: str, launcher_id: int, sender_id: int, is_admin: bool) -> list: launcher_type: str, launcher_id: int, sender_id: int, is_admin: bool) -> list:
reply = [] reply = []
try: try:

View File

@@ -4,6 +4,8 @@ import requests
import json import json
import logging import logging
from ..utils import context
class ReplyFilter: class ReplyFilter:
sensitive_words = [] sensitive_words = []
@@ -20,12 +22,13 @@ class ReplyFilter:
self.sensitive_words = sensitive_words self.sensitive_words = sensitive_words
self.mask = mask self.mask = mask
self.mask_word = mask_word self.mask_word = mask_word
import config
self.baidu_check = config.baidu_check config = context.get_config_manager().data
self.baidu_api_key = config.baidu_api_key
self.baidu_secret_key = config.baidu_secret_key self.baidu_check = config['baidu_check']
self.inappropriate_message_tips = config.inappropriate_message_tips self.baidu_api_key = config['baidu_api_key']
self.baidu_secret_key = config['baidu_secret_key']
self.inappropriate_message_tips = config['inappropriate_message_tips']
def is_illegal(self, message: str) -> bool: def is_illegal(self, message: str) -> bool:
processed = self.process(message) processed = self.process(message)

View File

@@ -1,16 +1,18 @@
import re import re
from ..utils import context
def ignore(msg: str) -> bool: def ignore(msg: str) -> bool:
"""检查消息是否应该被忽略""" """检查消息是否应该被忽略"""
import config config = context.get_config_manager().data
if 'prefix' in config.ignore_rules: if 'prefix' in config['ignore_rules']:
for rule in config.ignore_rules['prefix']: for rule in config['ignore_rules']['prefix']:
if msg.startswith(rule): if msg.startswith(rule):
return True return True
if 'regexp' in config.ignore_rules: if 'regexp' in config['ignore_rules']:
for rule in config.ignore_rules['regexp']: for rule in config['ignore_rules']['regexp']:
if re.search(rule, msg): if re.search(rule, msg):
return True return True

View File

@@ -13,15 +13,15 @@ import tips as tips_custom
def handle_exception(notify_admin: str = "", set_reply: str = "") -> list: def handle_exception(notify_admin: str = "", set_reply: str = "") -> list:
"""处理异常当notify_admin不为空时会通知管理员返回通知用户的消息""" """处理异常当notify_admin不为空时会通知管理员返回通知用户的消息"""
import config config = context.get_config_manager().data
context.get_qqbot_manager().notify_admin(notify_admin) context.get_qqbot_manager().notify_admin(notify_admin)
if config.hide_exce_info_to_user: if config['hide_exce_info_to_user']:
return [tips_custom.alter_tip_message] if tips_custom.alter_tip_message else [] return [tips_custom.alter_tip_message] if tips_custom.alter_tip_message else []
else: else:
return [set_reply] return [set_reply]
def process_normal_message(text_message: str, mgr, config, launcher_type: str, def process_normal_message(text_message: str, mgr, config: dict, launcher_type: str,
launcher_id: int, sender_id: int) -> list: launcher_id: int, sender_id: int) -> list:
session_name = f"{launcher_type}_{launcher_id}" session_name = f"{launcher_type}_{launcher_id}"
logging.info("[{}]发送消息:{}".format(session_name, text_message[:min(20, len(text_message))] + ( logging.info("[{}]发送消息:{}".format(session_name, text_message[:min(20, len(text_message))] + (
@@ -39,7 +39,7 @@ def process_normal_message(text_message: str, mgr, config, launcher_type: str,
reply = handle_exception(notify_admin=f"{session_name},多次尝试失败。", set_reply=f"[bot]多次尝试失败,请重试或联系管理员") reply = handle_exception(notify_admin=f"{session_name},多次尝试失败。", set_reply=f"[bot]多次尝试失败,请重试或联系管理员")
break break
try: try:
prefix = "[GPT]" if config.show_prefix else "" prefix = "[GPT]" if config['show_prefix'] else ""
text, finish_reason, funcs = session.query(text_message) text, finish_reason, funcs = session.query(text_message)
@@ -118,7 +118,7 @@ def process_normal_message(text_message: str, mgr, config, launcher_type: str,
reply = handle_exception("{}会话调用API失败:{}".format(session_name, e), reply = handle_exception("{}会话调用API失败:{}".format(session_name, e),
"[bot]err:RateLimitError,请重试或联系作者,或等待修复") "[bot]err:RateLimitError,请重试或联系作者,或等待修复")
except openai.BadRequestError as e: except openai.BadRequestError as e:
if config.auto_reset and "This model's maximum context length is" in str(e): if config['auto_reset'] and "This model's maximum context length is" in str(e):
session.reset(persist=True) session.reset(persist=True)
reply = [tips_custom.session_auto_reset_message] reply = [tips_custom.session_auto_reset_message]
else: else:

View File

@@ -1,6 +1,7 @@
# 此模块提供了消息处理的具体逻辑的接口 # 此模块提供了消息处理的具体逻辑的接口
import asyncio import asyncio
import time import time
import traceback
import mirai import mirai
import logging import logging
@@ -28,11 +29,11 @@ processing = []
def is_admin(qq: int) -> bool: def is_admin(qq: int) -> bool:
"""兼容list和int类型的管理员判断""" """兼容list和int类型的管理员判断"""
import config config = context.get_config_manager().data
if type(config.admin_qq) == list: if type(config['admin_qq']) == list:
return qq in config.admin_qq return qq in config['admin_qq']
else: else:
return qq == config.admin_qq return qq == config['admin_qq']
def process_message(launcher_type: str, launcher_id: int, text_message: str, message_chain: mirai.MessageChain, def process_message(launcher_type: str, launcher_id: int, text_message: str, message_chain: mirai.MessageChain,
@@ -53,9 +54,9 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
logging.info("根据忽略规则忽略消息: {}".format(text_message)) logging.info("根据忽略规则忽略消息: {}".format(text_message))
return [] return []
import config config = context.get_config_manager().data
if not config.wait_last_done and session_name in processing: if not config['wait_last_done'] and session_name in processing:
return mirai.MessageChain([mirai.Plain(tips_custom.message_drop_tip)]) return mirai.MessageChain([mirai.Plain(tips_custom.message_drop_tip)])
# 检查是否被禁言 # 检查是否被禁言
@@ -65,8 +66,7 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
logging.info("机器人被禁言,跳过消息处理(group_{})".format(launcher_id)) logging.info("机器人被禁言,跳过消息处理(group_{})".format(launcher_id))
return reply return reply
import config if config['income_msg_check']:
if config.income_msg_check:
if mgr.reply_filter.is_illegal(text_message): if mgr.reply_filter.is_illegal(text_message):
return mirai.MessageChain(mirai.Plain("[bot] 消息中存在不合适的内容, 请更换措辞")) return mirai.MessageChain(mirai.Plain("[bot] 消息中存在不合适的内容, 请更换措辞"))
@@ -81,8 +81,6 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
# 处理消息 # 处理消息
try: try:
config = context.get_config()
processing.append(session_name) processing.append(session_name)
try: try:
if text_message.startswith('!') or text_message.startswith(""): # 指令 if text_message.startswith('!') or text_message.startswith(""): # 指令
@@ -114,7 +112,7 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
else: # 消息 else: # 消息
# 限速丢弃检查 # 限速丢弃检查
# print(ratelimit.__crt_minute_usage__[session_name]) # print(ratelimit.__crt_minute_usage__[session_name])
if config.rate_limit_strategy == "drop": if config['rate_limit_strategy'] == "drop":
if ratelimit.is_reach_limit(session_name): if ratelimit.is_reach_limit(session_name):
logging.info("根据限速策略丢弃[{}]消息: {}".format(session_name, text_message)) logging.info("根据限速策略丢弃[{}]消息: {}".format(session_name, text_message))
@@ -144,7 +142,7 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
mgr, config, launcher_type, launcher_id, sender_id) mgr, config, launcher_type, launcher_id, sender_id)
# 限速等待时间 # 限速等待时间
if config.rate_limit_strategy == "wait": if config['rate_limit_strategy'] == "wait":
time.sleep(ratelimit.get_rest_wait_time(session_name, time.time() - before)) time.sleep(ratelimit.get_rest_wait_time(session_name, time.time() - before))
ratelimit.add_usage(session_name) ratelimit.add_usage(session_name)
@@ -167,13 +165,13 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
openai_session.get_session(session_name).release_response_lock() openai_session.get_session(session_name).release_response_lock()
# 检查延迟时间 # 检查延迟时间
if config.force_delay_range[1] == 0: if config['force_delay_range'][1] == 0:
delay_time = 0 delay_time = 0
else: else:
import random import random
# 从延迟范围中随机取一个值(浮点) # 从延迟范围中随机取一个值(浮点)
rdm = random.uniform(config.force_delay_range[0], config.force_delay_range[1]) rdm = random.uniform(config['force_delay_range'][0], config['force_delay_range'][1])
spent = time.time() - start_time spent = time.time() - start_time

View File

@@ -3,6 +3,9 @@ import time
import logging import logging
import threading import threading
from ..utils import context
__crt_minute_usage__ = {} __crt_minute_usage__ = {}
"""当前分钟每个会话的对话次数""" """当前分钟每个会话的对话次数"""
@@ -12,16 +15,16 @@ __timer_thr__: threading.Thread = None
def get_limitation(session_name: str) -> int: def get_limitation(session_name: str) -> int:
"""获取会话的限制次数""" """获取会话的限制次数"""
import config config = context.get_config_manager().data
if type(config.rate_limitation) == dict: if type(config['rate_limitation']) == dict:
# 如果被指定了 # 如果被指定了
if session_name in config.rate_limitation: if session_name in config['rate_limitation']:
return config.rate_limitation[session_name] return config['rate_limitation'][session_name]
else: else:
return config.rate_limitation["default"] return config['rate_limitation']["default"]
elif type(config.rate_limitation) == int: elif type(config['rate_limitation']) == int:
return config.rate_limitation return config['rate_limitation']
def add_usage(session_name: str): def add_usage(session_name: str):

View File

@@ -10,6 +10,7 @@ import nakuru.entities.components as nkc
from .. import adapter as adapter_model from .. import adapter as adapter_model
from ...qqbot import blob from ...qqbot import blob
from ...utils import context
class NakuruProjectMessageConverter(adapter_model.MessageConverter): class NakuruProjectMessageConverter(adapter_model.MessageConverter):
@@ -172,12 +173,14 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
self.listener_list = [] self.listener_list = []
# nakuru库有bug这个接口没法带access_token会失败 # nakuru库有bug这个接口没法带access_token会失败
# 所以目前自行发请求 # 所以目前自行发请求
import config
config = context.get_config_manager().data
import requests import requests
resp = requests.get( resp = requests.get(
url="http://{}:{}/get_login_info".format(config.nakuru_config['host'], config.nakuru_config['http_port']), url="http://{}:{}/get_login_info".format(config['nakuru_config']['host'], config['nakuru_config']['http_port']),
headers={ headers={
'Authorization': "Bearer " + config.nakuru_config['token'] if 'token' in config.nakuru_config else "" 'Authorization': "Bearer " + config['nakuru_config']['token'] if 'token' in config['nakuru_config']else ""
}, },
timeout=5 timeout=5
) )
@@ -270,7 +273,7 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
logging.debug("注册监听器: " + str(event_type) + " -> " + str(callback)) logging.debug("注册监听器: " + str(event_type) + " -> " + str(callback))
# 包装函数 # 包装函数
async def listener_wrapper(app: nakuru.CQHTTP, source: self.event_converter.yiri2target(event_type)): async def listener_wrapper(app: nakuru.CQHTTP, source: NakuruProjectAdapter.event_converter.yiri2target(event_type)):
callback(self.event_converter.target2yiri(source)) callback(self.event_converter.target2yiri(source))
# 将包装函数和原函数的对应关系存入列表 # 将包装函数和原函数的对应关系存入列表

View File

@@ -3,6 +3,8 @@ import time
import logging import logging
import shutil import shutil
from . import context
log_file_name = "qchatgpt.log" log_file_name = "qchatgpt.log"
@@ -36,7 +38,6 @@ def init_runtime_log_file():
def reset_logging(): def reset_logging():
global log_file_name global log_file_name
import config
import pkg.utils.context import pkg.utils.context
import colorlog import colorlog
@@ -46,7 +47,11 @@ def reset_logging():
for handler in logging.getLogger().handlers: for handler in logging.getLogger().handlers:
logging.getLogger().removeHandler(handler) logging.getLogger().removeHandler(handler)
logging.basicConfig(level=config.logging_level, # 设置日志输出格式 config_mgr = context.get_config_manager()
logging_level = logging.INFO if config_mgr is None else config_mgr.data['logging_level']
logging.basicConfig(level=logging_level, # 设置日志输出格式
filename=log_file_name, # log日志输出的文件位置和文件名 filename=log_file_name, # log日志输出的文件位置和文件名
format="[%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s", format="[%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s",
# 日志输出的格式 # 日志输出的格式
@@ -54,7 +59,7 @@ def reset_logging():
datefmt="%Y-%m-%d %H:%M:%S" # 时间输出的格式 datefmt="%Y-%m-%d %H:%M:%S" # 时间输出的格式
) )
sh = logging.StreamHandler() sh = logging.StreamHandler()
sh.setLevel(config.logging_level) sh.setLevel(logging_level)
sh.setFormatter(colorlog.ColoredFormatter( sh.setFormatter(colorlog.ColoredFormatter(
fmt="%(log_color)s[%(asctime)s.%(msecs)03d] %(filename)s (%(lineno)d) - [%(levelname)s] : " fmt="%(log_color)s[%(asctime)s.%(msecs)03d] %(filename)s (%(lineno)d) - [%(levelname)s] : "
"%(message)s", "%(message)s",

View File

@@ -1,9 +1,11 @@
from . import context
def wrapper_proxies() -> dict: def wrapper_proxies() -> dict:
"""获取代理""" """获取代理"""
import config config = context.get_config_manager().data
return { return {
"http": config.openai_config['proxy'], "http": config['openai_config']['proxy'],
"https": config.openai_config['proxy'] "https": config['openai_config']['proxy']
} if 'proxy' in config.openai_config and (config.openai_config['proxy'] is not None) else None } if 'proxy' in config['openai_config'] and (config['openai_config']['proxy'] is not None) else None

View File

@@ -1,37 +1,42 @@
import logging import logging
import re import re
import os import os
import config
import traceback import traceback
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
from ..utils import context
text_render_font: ImageFont = None text_render_font: ImageFont = None
if config.blob_message_strategy == "image": # 仅在启用了image时才加载字体 def initialize():
use_font = config.font_path config = context.get_config_manager().data
try:
# 检查是否存在 if config['blob_message_strategy'] == "image": # 仅在启用了image时才加载字体
if not os.path.exists(use_font): use_font = config['font_path']
# 若是windows系统使用微软雅黑 try:
if os.name == "nt":
use_font = "C:/Windows/Fonts/msyh.ttc" # 检查是否存在
if not os.path.exists(use_font): if not os.path.exists(use_font):
logging.warn("未找到字体文件且无法使用Windows自带字体更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。") # 若是windows系统使用微软雅黑
config.blob_message_strategy = "forward" if os.name == "nt":
use_font = "C:/Windows/Fonts/msyh.ttc"
if not os.path.exists(use_font):
logging.warn("未找到字体文件且无法使用Windows自带字体更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。")
config['blob_message_strategy'] = "forward"
else:
logging.info("使用Windows自带字体" + use_font)
text_render_font = ImageFont.truetype(use_font, 32, encoding="utf-8")
else: else:
logging.info("使用Windows自带字体" + use_font) logging.warn("未找到字体文件,且无法使用Windows自带字体更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。")
text_render_font = ImageFont.truetype(use_font, 32, encoding="utf-8") config['blob_message_strategy'] = "forward"
else: else:
logging.warn("未找到字体文件且无法使用Windows自带字体更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。") text_render_font = ImageFont.truetype(use_font, 32, encoding="utf-8")
config.blob_message_strategy = "forward" except:
else: traceback.print_exc()
text_render_font = ImageFont.truetype(use_font, 32, encoding="utf-8") logging.error("加载字体文件失败({})更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。".format(use_font))
except: config['blob_message_strategy'] = "forward"
traceback.print_exc()
logging.error("加载字体文件失败({})更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。".format(use_font))
config.blob_message_strategy = "forward"
def indexNumber(path=''): def indexNumber(path=''):