From 5ada507c2baafe9906e8cb1fe02a9fb74e1a9baf Mon Sep 17 00:00:00 2001 From: Rock Chin <1010553892@qq.com> Date: Sat, 25 Feb 2023 17:05:39 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E4=BD=BF=E7=94=A8?= =?UTF-8?q?=E6=96=87=E4=BB=B6=E5=AD=98=E6=94=BE=E6=83=85=E6=99=AF=E9=A2=84?= =?UTF-8?q?=E8=AE=BE=20(#167)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + main.py | 6 ++++++ pkg/openai/dprompt.py | 29 +++++++++++++++++++++++++++-- 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 8cb06a46..f78cf750 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ banlist.py plugins/ !plugins/__init__.py /revcfg.py +prompts/ \ No newline at end of file diff --git a/main.py b/main.py index 6b22c55f..b7e08d02 100644 --- a/main.py +++ b/main.py @@ -82,6 +82,12 @@ def reset_logging(): def main(first_time_init=False): global known_exception_caught + # 检查并创建plugins、prompts目录 + check_path = ["plugins", "prompts"] + for path in check_path: + if not os.path.exists(path): + os.mkdir(path) + known_exception_caught = False try: # 导入config.py diff --git a/pkg/openai/dprompt.py b/pkg/openai/dprompt.py index 30fa53fa..f044dcdb 100644 --- a/pkg/openai/dprompt.py +++ b/pkg/openai/dprompt.py @@ -2,17 +2,39 @@ __current__ = "default" +__prompts_from_files__ = {} + + +def read_prompt_from_file() -> str: + """从文件读取预设值""" + # 读取prompts/目录下的所有文件,以文件名为键,文件内容为值 + # 保存在__prompts_from_files__中 + global __prompts_from_files__ + import os + + __prompts_from_files__ = {} + for file in os.listdir("prompts"): + with open(os.path.join("prompts", file), encoding="utf-8") as f: + __prompts_from_files__[file] = f.read() + + def get_prompt_dict() -> dict: """获取预设值字典""" import config default_prompt = config.default_prompt if type(default_prompt) == str: - return {"default": default_prompt} + default_prompt = {"default": default_prompt} elif type(default_prompt) == dict: - return default_prompt + pass else: raise TypeError("default_prompt must be str or dict") + # 将文件中的预设值合并到default_prompt中 + for key in __prompts_from_files__: + default_prompt[key] = __prompts_from_files__[key] + + return default_prompt + def set_current(name): global __current__ @@ -48,4 +70,7 @@ def get_prompt(name: str = None) -> str: for key in default_dict: if key.lower().startswith(name.lower()): return default_dict[key] + raise KeyError("未找到情景预设: " + name) + +read_prompt_from_file()