From bb4b897934c04a6361bd3784fb8e1ad146a00877 Mon Sep 17 00:00:00 2001 From: Rock Chin <1010553892@qq.com> Date: Sun, 26 Mar 2023 13:28:26 +0000 Subject: [PATCH] =?UTF-8?q?feat(dprompt.py):=20=E8=A7=A3=E8=80=A6=E5=AE=8C?= =?UTF-8?q?=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 5 +- pkg/openai/dprompt.py | 222 +++++++++++++++++++++----------------- pkg/openai/session.py | 4 +- pkg/qqbot/cmds/session.py | 104 ++++++++++-------- 4 files changed, 187 insertions(+), 148 deletions(-) diff --git a/main.py b/main.py index 16130fe6..43c3fc8d 100644 --- a/main.py +++ b/main.py @@ -191,9 +191,8 @@ def start(first_time_init=False): import pkg.openai.session import pkg.qqbot.manager import pkg.openai.dprompt - - pkg.openai.dprompt.read_prompt_from_file() - pkg.openai.dprompt.read_scenario_from_file() + + pkg.openai.dprompt.register_all() # 主启动流程 database = pkg.database.manager.DatabaseManager() diff --git a/pkg/openai/dprompt.py b/pkg/openai/dprompt.py index ce09ab80..adb0a4d8 100644 --- a/pkg/openai/dprompt.py +++ b/pkg/openai/dprompt.py @@ -1,121 +1,145 @@ # 多情景预设值管理 import json import logging +import config +import os -__current__ = "default" -"""当前默认使用的情景预设的名称 +# __current__ = "default" +# """当前默认使用的情景预设的名称 -由管理员使用`!default <名称>`指令切换 -""" +# 由管理员使用`!default <名称>`指令切换 +# """ -__prompts_from_files__ = {} -"""从文件中读取的情景预设值""" +# __prompts_from_files__ = {} +# """从文件中读取的情景预设值""" -__scenario_from_files__ = {} +# __scenario_from_files__ = {} -def read_prompt_from_file(): - """从文件读取预设值""" - # 读取prompts/目录下的所有文件,以文件名为键,文件内容为值 - # 保存在__prompts_from_files__中 - global __prompts_from_files__ - import os - - __prompts_from_files__ = {} - for file in os.listdir("prompts"): - with open(os.path.join("prompts", file), encoding="utf-8") as f: - __prompts_from_files__[file] = f.read() +__universal_first_reply__ = "ok, I'll follow your commands." +"""通用首次回复""" -def read_scenario_from_file(): - """从JSON文件读取情景预设""" - global __scenario_from_files__ - import os +class ScenarioMode: + """情景预设模式抽象类""" - __scenario_from_files__ = {} - for file in os.listdir("scenario"): - if file == "default-template.json": - continue - with open(os.path.join("scenario", file), encoding="utf-8") as f: - __scenario_from_files__[file] = json.load(f) + using_prompt_name = "default" + """新session创建时使用的prompt名称""" + + prompts: dict[str, list] = {} + + def __init__(self): + logging.debug("prompts: {}".format(self.prompts)) + + def list(self) -> dict[str, list]: + """获取所有情景预设的名称及内容""" + return self.prompts + + def get_prompt(self, name: str) -> tuple[list, str]: + """获取指定情景预设的名称及内容""" + for key in self.prompts: + if key.startswith(name): + return self.prompts[key], key + raise Exception("没有找到情景预设: {}".format(name)) + + def set_using_name(self, name: str) -> str: + """设置默认情景预设""" + for key in self.prompts: + if key.startswith(name): + self.using_prompt_name = key + return key + raise Exception("没有找到情景预设: {}".format(name)) + + def get_full_name(self, name: str) -> str: + """获取完整的情景预设名称""" + for key in self.prompts: + if key.startswith(name): + return key + raise Exception("没有找到情景预设: {}".format(name)) + + def get_using_name(self) -> str: + """获取默认情景预设""" + return self.using_prompt_name -def get_prompt_dict() -> dict: - """获取预设值字典""" - import config - default_prompt = config.default_prompt - if type(default_prompt) == str: - default_prompt = {"default": default_prompt} - elif type(default_prompt) == dict: - pass - else: - raise TypeError("default_prompt must be str or dict") +class NormalScenarioMode(ScenarioMode): + """普通情景预设模式""" - # 将文件中的预设值合并到default_prompt中 - for key in __prompts_from_files__: - default_prompt[key] = __prompts_from_files__[key] - - return default_prompt - - -def set_current(name): - global __current__ - for key in get_prompt_dict(): - if key.lower().startswith(name.lower()): - __current__ = key - return - raise KeyError("未找到情景预设: " + name) - - -def get_current(): - global __current__ - return __current__ - - -def set_to_default(): - global __current__ - default_dict = get_prompt_dict() - - if "default" in default_dict: - __current__ = "default" - else: - __current__ = list(default_dict.keys())[0] - - -def get_prompt(name: str = None) -> list: - global __scenario_from_files__ - import config - preset_mode = config.preset_mode - - """获取预设值""" - if name is None: - name = get_current() - - # JSON预设方式 - if preset_mode == 'full_scenario': - import os - - for key in __scenario_from_files__: - if key.lower().startswith(name.lower()): - logging.debug('成功加载情景预设从JSON文件: {}'.format(key)) - return __scenario_from_files__[key]['prompt'] + def __init__(self): + global __universal_first_reply__ + # 加载config中的default_prompt值 + if type(config.default_prompt) == str: + self.using_prompt_name = "default" + self.prompts = {"default": [ + { + "role": "user", + "content": config.default_prompt + },{ + "role": "assistant", + "content": __universal_first_reply__ + } + ]} - # 默认预设方式 - elif preset_mode == 'default' or preset_mode == 'normal': - - default_dict = get_prompt_dict() - - for key in default_dict: - if key.lower().startswith(name.lower()): - return [ + elif type(config.default_prompt) == dict: + for key in config.default_prompt: + self.prompts[key] = [ { "role": "user", - "content": default_dict[key] - }, - { + "content": config.default_prompt[key] + },{ "role": "assistant", - "content": "好的。" + "content": __universal_first_reply__ } ] - raise KeyError("未找到默认情景预设: " + name) + # 从prompts/目录下的文件中载入 + # 遍历文件 + for file in os.listdir("prompts"): + with open(os.path.join("prompts", file), encoding="utf-8") as f: + self.prompts[file] = [ + { + "role": "user", + "content": f.read() + },{ + "role": "assistant", + "content": __universal_first_reply__ + } + ] + + +class FullScenarioMode(ScenarioMode): + """完整情景预设模式""" + + def __init__(self): + """从json读取所有""" + # 遍历scenario/目录下的所有文件,以文件名为键,文件内容中的prompt为值 + for file in os.listdir("scenario"): + if file == "default-template.json": + continue + with open(os.path.join("scenario", file), encoding="utf-8") as f: + self.prompts[file] = json.load(f)["prompt"] + + super().__init__() + + +scenario_mode_mapping = {} +"""情景预设模式名称与对象的映射""" + + +def register_all(): + """注册所有情景预设模式,不使用装饰器,因为装饰器的方式不支持热重载""" + global scenario_mode_mapping + scenario_mode_mapping = { + "normal": NormalScenarioMode(), + "full_scenario": FullScenarioMode() + } + + +def mode_inst() -> ScenarioMode: + """获取指定名称的情景预设模式对象""" + import config + + if config.preset_mode == "default": + config.preset_mode = "normal" + + return scenario_mode_mapping[config.preset_mode] diff --git a/pkg/openai/session.py b/pkg/openai/session.py index 6bef950d..6575de6e 100644 --- a/pkg/openai/session.py +++ b/pkg/openai/session.py @@ -141,9 +141,9 @@ class Session: import pkg.openai.dprompt as dprompt if use_default is None: - use_default = dprompt.get_current() + use_default = dprompt.mode_inst().get_using_name() - current_default_prompt = dprompt.get_prompt(use_default) + current_default_prompt, _ = dprompt.mode_inst().get_prompt(use_default) return current_default_prompt def __init__(self, name: str): diff --git a/pkg/qqbot/cmds/session.py b/pkg/qqbot/cmds/session.py index 173eefd6..aa395016 100644 --- a/pkg/qqbot/cmds/session.py +++ b/pkg/qqbot/cmds/session.py @@ -7,6 +7,7 @@ import pkg.openai.session import pkg.utils.context import config + @command( "reset", "重置当前会话", @@ -14,19 +15,23 @@ import config [], False ) -def cmd_reset(cmd: str, params: list, session_name: str, +def cmd_reset(cmd: str, params: list, session_name: str, text_message: str, launcher_type: str, launcher_id: int, - sender_id: int, is_admin: bool) -> list: + sender_id: int, is_admin: bool) -> list: """重置会话""" reply = [] - + if len(params) == 0: pkg.openai.session.get_session(session_name).reset(explicit=True) reply = ["[bot]会话已重置"] else: - pkg.openai.session.get_session(session_name).reset(explicit=True, use_prompt=params[0]) - reply = ["[bot]会话已重置,使用场景预设:{}".format(params[0])] - + try: + import pkg.openai.dprompt as dprompt + pkg.openai.session.get_session(session_name).reset(explicit=True, use_prompt=params[0]) + reply = ["[bot]会话已重置,使用场景预设:{}".format(dprompt.mode_inst().get_full_name(params[0]))] + except Exception as e: + reply = ["[bot]会话重置失败,错误信息:{}".format(e)] + return reply @@ -37,9 +42,9 @@ def cmd_reset(cmd: str, params: list, session_name: str, [], False ) -def cmd_last(cmd: str, params: list, session_name: str, +def cmd_last(cmd: str, params: list, session_name: str, text_message: str, launcher_type: str, launcher_id: int, - sender_id: int, is_admin: bool) -> list: + sender_id: int, is_admin: bool) -> list: """切换到前一次会话""" reply = [] result = pkg.openai.session.get_session(session_name).last_session() @@ -52,6 +57,7 @@ def cmd_last(cmd: str, params: list, session_name: str, return reply + @command( "next", "切换到后一次会话", @@ -59,12 +65,12 @@ def cmd_last(cmd: str, params: list, session_name: str, [], False ) -def cmd_next(cmd: str, params: list, session_name: str, +def cmd_next(cmd: str, params: list, session_name: str, text_message: str, launcher_type: int, launcher_id: int, - sender_id: int, is_admin: bool) -> list: + sender_id: int, is_admin: bool) -> list: """切换到后一次会话""" reply = [] - + result = pkg.openai.session.get_session(session_name).next_session() if result is None: reply = ["[bot]没有后一次的对话"] @@ -84,13 +90,13 @@ def cmd_next(cmd: str, params: list, session_name: str, False ) def cmd_prompt(cmd: str, params: list, session_name: str, - text_message: str, launcher_type: str, launcher_id: int, - sender_id: int, is_admin: bool) -> list: + text_message: str, launcher_type: str, launcher_id: int, + sender_id: int, is_admin: bool) -> list: """获取当前会话的前文""" reply = [] - + msgs = "" - session:list = pkg.openai.session.get_session(session_name).prompt + session: list = pkg.openai.session.get_session(session_name).prompt for msg in session: if len(params) != 0 and params[0] in ['-all', '-a']: msgs = msgs + "{}: {}\n\n".format(msg['role'], msg['content']) @@ -111,11 +117,11 @@ def cmd_prompt(cmd: str, params: list, session_name: str, False ) def cmd_list(cmd: str, params: list, session_name: str, - text_message: str, launcher_type: str, launcher_id: int, - sender_id: int, is_admin: bool) -> list: + text_message: str, launcher_type: str, launcher_id: int, + sender_id: int, is_admin: bool) -> list: """列出当前会话的所有历史记录""" reply = [] - + pkg.openai.session.get_session(session_name).persistence() page = 0 @@ -143,12 +149,12 @@ def cmd_list(cmd: str, params: list, session_name: str, pkg.openai.session.get_session(session_name).persistence() if len(msg) >= 2: reply_str += "#{} 创建:{} {}\n".format(i + page * 10, - datetime_obj.strftime("%Y-%m-%d %H:%M:%S"), - msg[0]['content']) + datetime_obj.strftime("%Y-%m-%d %H:%M:%S"), + msg[0]['content']) else: reply_str += "#{} 创建:{} {}\n".format(i + page * 10, - datetime_obj.strftime("%Y-%m-%d %H:%M:%S"), - "无内容") + datetime_obj.strftime("%Y-%m-%d %H:%M:%S"), + "无内容") if results[i]['create_timestamp'] == pkg.openai.session.get_session( session_name).create_timestamp: current = i + page * 10 @@ -172,19 +178,19 @@ def cmd_list(cmd: str, params: list, session_name: str, False ) def cmd_resend(cmd: str, params: list, session_name: str, - text_message: str, launcher_type: str, launcher_id: int, - sender_id: int, is_admin: bool) -> list: + text_message: str, launcher_type: str, launcher_id: int, + sender_id: int, is_admin: bool) -> list: """重新获取上一次问题的回复""" reply = [] - + session = pkg.openai.session.get_session(session_name) to_send = session.undo() mgr = pkg.utils.context.get_qqbot_manager() reply = pkg.qqbot.message.process_normal_message(to_send, mgr, config, - launcher_type, launcher_id, sender_id) - + launcher_type, launcher_id, sender_id) + return reply @@ -196,11 +202,11 @@ def cmd_resend(cmd: str, params: list, session_name: str, False ) def cmd_del(cmd: str, params: list, session_name: str, - text_message: str, launcher_type: str, launcher_id: int, - sender_id: int, is_admin: bool) -> list: + text_message: str, launcher_type: str, launcher_id: int, + sender_id: int, is_admin: bool) -> list: """删除当前会话的历史记录""" reply = [] - + if len(params) == 0: reply = ["[bot]参数不足, 格式: !del <序号>\n可以通过!list查看序号"] else: @@ -226,28 +232,37 @@ def cmd_del(cmd: str, params: list, session_name: str, ) def cmd_default(cmd: str, params: list, session_name: str, text_message: str, launcher_type: str, launcher_id: int, - sender_id: int, is_admin: bool) -> list: + sender_id: int, is_admin: bool) -> list: """操作情景预设""" reply = [] if len(params) == 0: # 输出目前所有情景预设 import pkg.openai.dprompt as dprompt - reply_str = "[bot]当前所有情景预设:\n\n" - for key, value in dprompt.get_prompt_dict().items(): - reply_str += " - {}: {}\n".format(key, value) + reply_str = "[bot]当前所有情景预设({}模式):\n\n".format(config.preset_mode) - reply_str += "\n当前默认情景预设:{}\n".format(dprompt.get_current()) - reply_str += "请使用!default <情景预设>来设置默认情景预设" + prompts = dprompt.mode_inst().list() + + for key in prompts: + pro = prompts[key] + reply_str += "名称: {}".format(key) + + for r in pro: + reply_str += "\n - [{}]: {}".format(r['role'], r['content']) + + reply_str += "\n\n" + + reply_str += "\n当前默认情景预设:{}\n".format(dprompt.mode_inst().get_using_name()) + reply_str += "请使用 !default <情景预设名称> 来设置默认情景预设" reply = [reply_str] elif len(params) > 0 and is_admin: # 设置默认情景 import pkg.openai.dprompt as dprompt try: - dprompt.set_current(params[0]) - reply = ["[bot]已设置默认情景预设为:{}".format(dprompt.get_current())] - except KeyError: - reply = ["[bot]err: 未找到情景预设:{}".format(params[0])] + full_name = dprompt.mode_inst().set_using_name(params[0]) + reply = ["[bot]已设置默认情景预设为:{}".format(full_name)] + except Exception as e: + reply = ["[bot]err: {}".format(e)] else: reply = ["[bot]err: 仅管理员可设置默认情景预设"] @@ -262,13 +277,14 @@ def cmd_default(cmd: str, params: list, session_name: str, True ) def cmd_delhst(cmd: str, params: list, session_name: str, - text_message: str, launcher_type: str, launcher_id: int, - sender_id: int, is_admin: bool) -> list: + text_message: str, launcher_type: str, launcher_id: int, + sender_id: int, is_admin: bool) -> list: """删除指定会话的所有历史记录""" reply = [] - + if len(params) == 0: - reply = ["[bot]err:请输入要删除的会话名: group_<群号> 或者 person_, 或使用 !delhst all 删除所有会话的历史记录"] + reply = [ + "[bot]err:请输入要删除的会话名: group_<群号> 或者 person_, 或使用 !delhst all 删除所有会话的历史记录"] else: if params[0] == "all": pkg.utils.context.get_database_manager().delete_all_session_history()