From 07923f71bd639a0024f8ac4a72b3891030ce3ab6 Mon Sep 17 00:00:00 2001 From: Rock Chin <1010553892@qq.com> Date: Wed, 4 Jan 2023 21:46:01 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E8=81=8A=E5=A4=A9?= =?UTF-8?q?=E7=AA=97=E7=AE=A1=E7=90=86=E5=91=98=E4=BF=AE=E6=94=B9=E8=BF=90?= =?UTF-8?q?=E8=A1=8C=E6=97=B6=E9=85=8D=E7=BD=AE=E9=A1=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/openai/session.py | 30 +++++++++---------- pkg/qqbot/process.py | 68 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 81 insertions(+), 17 deletions(-) diff --git a/pkg/openai/session.py b/pkg/openai/session.py index 4bb5f970..6dffdf47 100644 --- a/pkg/openai/session.py +++ b/pkg/openai/session.py @@ -51,16 +51,6 @@ def dump_session(session_name: str): del sessions[session_name] -# 从配置文件获取会话预设信息 -def get_default_prompt(): - import config - user_name = config.user_name if hasattr(config, 'user_name') and config.user_name != '' else 'You' - bot_name = config.bot_name if hasattr(config, 'bot_name') and config.bot_name != '' else 'Bot' - return user_name + ":{}\n".format(config.default_prompt if hasattr(config, 'default_prompt') \ - and config.default_prompt != "" else '') + \ - bot_name + ":好的\n" - - # def blocked_func(lock: threading.Lock): # # def decorator(func): @@ -83,7 +73,7 @@ def get_default_prompt(): class Session: name = '' - prompt = get_default_prompt() + prompt = "" import config @@ -111,6 +101,15 @@ class Session: self.response_lock.release() logging.debug('{},lock release successfully,{}'.format(self.name, self.response_lock)) + # 从配置文件获取会话预设信息 + def get_default_prompt(self): + config = pkg.utils.context.get_config() + user_name = config.user_name if hasattr(config, 'user_name') and config.user_name != '' else 'You' + bot_name = config.bot_name if hasattr(config, 'bot_name') and config.bot_name != '' else 'Bot' + return user_name + ":{}\n".format(config.default_prompt if hasattr(config, 'default_prompt') \ + and config.default_prompt != "" else '') + \ + bot_name + ":好的\n" + def __init__(self, name: str): self.name = name self.create_timestamp = int(time.time()) @@ -118,6 +117,7 @@ class Session: self.schedule() self.response_lock = threading.Lock() + self.prompt = self.get_default_prompt() # 设定检查session最后一次对话是否超过过期时间的计时器 def schedule(self): @@ -206,7 +206,7 @@ class Session: # 持久化session def persistence(self): - if self.prompt == get_default_prompt(): + if self.prompt == self.get_default_prompt(): return db_inst = pkg.utils.context.get_database_manager() @@ -221,14 +221,14 @@ class Session: # 重置session def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True): - if self.prompt != get_default_prompt(): + if self.prompt != self.get_default_prompt(): self.persistence() if explicit: pkg.utils.context.get_database_manager().explicit_close_session(self.name, self.create_timestamp) if expired: pkg.utils.context.get_database_manager().set_session_expired(self.name, self.create_timestamp) - self.prompt = get_default_prompt() + self.prompt = self.get_default_prompt() self.create_timestamp = int(time.time()) self.last_interact_timestamp = int(time.time()) self.just_switched_to_exist_session = False @@ -274,7 +274,7 @@ class Session: def list_history(self, capacity: int = 10, page: int = 0): return pkg.utils.context.get_database_manager().list_history(self.name, capacity, page, - get_default_prompt()) + self.get_default_prompt()) def draw_image(self, prompt: str): return pkg.utils.context.get_openai_manager().request_image(prompt) diff --git a/pkg/qqbot/process.py b/pkg/qqbot/process.py index 736caaf4..fd754e6e 100644 --- a/pkg/qqbot/process.py +++ b/pkg/qqbot/process.py @@ -1,6 +1,7 @@ # 此模块提供了消息处理的具体逻辑的接口 import asyncio import datetime +import json import threading from func_timeout import func_set_timeout @@ -12,7 +13,7 @@ from mirai import Image, MessageChain # 这里不使用动态引入config # 因为在这里动态引入会卡死程序 # 而此模块静态引用config与动态引入的表现一致 -import config +import config as config_init_import import pkg.openai.session import pkg.openai.manager @@ -23,7 +24,7 @@ import pkg.utils.context processing = [] -@func_set_timeout(config.process_message_timeout) +@func_set_timeout(config_init_import.process_message_timeout) def process_message(launcher_type: str, launcher_id: int, text_message: str, message_chain: MessageChain, sender_id: int) -> MessageChain: global processing @@ -51,6 +52,8 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes processing.append(session_name) + config = pkg.utils.context.get_config() + try: if text_message.startswith('!') or text_message.startswith("!"): # 指令 @@ -189,6 +192,67 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes pkg.utils.context.get_qqbot_manager().notify_admin("更新完成") threading.Thread(target=update_task, daemon=True).start() + elif cmd == 'cfg' and launcher_type == 'person' and launcher_id == config.admin_qq: + reply_str = "" + if len(params) == 0: + reply = ["[bot]err:请输入配置项"] + else: + cfg_name = params[0] + if cfg_name == 'all': + reply_str = "[bot]所有配置项:\n\n" + for cfg in dir(config): + print(cfg) + if not cfg.startswith('__') and not cfg == 'logging': + # 根据配置项类型进行格式化,如果是字典则转换为json并格式化 + if isinstance(getattr(config, cfg), str): + reply_str += "{}: \"{}\"\n".format(cfg, getattr(config, cfg)) + elif isinstance(getattr(config, cfg), dict): + # 不进行unicode转义,并格式化 + reply_str += "{}: {}\n".format(cfg, + json.dumps(getattr(config, cfg), + ensure_ascii=False, indent=4)) + else: + reply_str += "{}: {}\n".format(cfg, getattr(config, cfg)) + reply = [reply_str] + elif cfg_name in dir(config): + if len(params) == 1: + # 按照配置项类型进行格式化 + if isinstance(getattr(config, cfg_name), str): + reply_str = "[bot]配置项{}: \"{}\"\n".format(cfg_name, getattr(config, cfg_name)) + elif isinstance(getattr(config, cfg_name), dict): + reply_str = "[bot]配置项{}: {}\n".format(cfg_name, + json.dumps(getattr(config, cfg_name), + ensure_ascii=False, indent=4)) + else: + reply_str = "[bot]配置项{}: {}\n".format(cfg_name, getattr(config, cfg_name)) + reply = [reply_str] + else: + cfg_value = " ".join(params[1:]) + # 类型转换,如果是json则转换为字典 + if cfg_value == 'true': + cfg_value = True + elif cfg_value == 'false': + cfg_value = False + elif cfg_value.isdigit(): + cfg_value = int(cfg_value) + elif cfg_value.startswith('{') and cfg_value.endswith('}'): + cfg_value = json.loads(cfg_value) + else: + try: + cfg_value = float(cfg_value) + except ValueError: + pass + + # 检查类型是否匹配 + if isinstance(getattr(config, cfg_name), type(cfg_value)): + setattr(config, cfg_name, cfg_value) + pkg.utils.context.set_config(config) + reply = ["[bot]配置项{}修改成功".format(cfg_name)] + else: + reply = ["[bot]err:配置项{}类型不匹配".format(cfg_name)] + + else: + reply = ["[bot]err:未找到配置项 {}".format(cfg_name)] else: reply = ["[bot]err:未知的指令或权限不足: "+cmd] except Exception as e: