From 6b8fa664f14dea5edd2f256ae4267fb3d2418153 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Mon, 31 Jul 2023 17:21:09 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9EPromptPreprocessing?= =?UTF-8?q?=E4=BA=8B=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/openai/session.py | 34 ++++++++++++++++++++++++++++------ pkg/plugin/host.py | 2 +- pkg/plugin/models.py | 14 ++++++++++++++ 3 files changed, 43 insertions(+), 7 deletions(-) diff --git a/pkg/openai/session.py b/pkg/openai/session.py index ac85c408..7bb0368f 100644 --- a/pkg/openai/session.py +++ b/pkg/openai/session.py @@ -214,7 +214,29 @@ class Session: config = pkg.utils.context.get_config() max_length = config.prompt_submit_length - prompts, _ = self.cut_out(text, max_length) + local_default_prompt = self.default_prompt.copy() + local_prompt = self.prompt.copy() + + # 触发PromptPreProcessing事件 + args = { + 'session_name': self.name, + 'default_prompt': self.default_prompt, + 'prompt': self.prompt, + 'text_message': text, + } + + event = pkg.plugin.host.emit(plugin_models.PromptPreProcessing, **args) + + if event.get_return_value('default_prompt') is not None: + local_default_prompt = event.get_return_value('default_prompt') + + if event.get_return_value('prompt') is not None: + local_prompt = event.get_return_value('prompt') + + if event.get_return_value('text_message') is not None: + text = event.get_return_value('text_message') + + prompts, _ = self.cut_out(text, max_length, local_default_prompt, local_prompt) res_text = "" @@ -301,7 +323,7 @@ class Session: return question # 构建对话体 - def cut_out(self, msg: str, max_tokens: int) -> tuple[list, list]: + def cut_out(self, msg: str, max_tokens: int, default_prompt: list, prompt: list) -> tuple[list, list]: """将现有prompt进行切割处理,使得新的prompt长度不超过max_tokens :return: (新的prompt, 新的token_counts) @@ -317,19 +339,19 @@ class Session: use_model = pkg.utils.context.get_config().completion_api_params['model'] - ptr = len(self.prompt) - 1 + ptr = len(prompt) - 1 # 直接从后向前扫描拼接,不管是否是整回合 while ptr >= 0: - if count_tokens(self.prompt[ptr:ptr+1]+changable_prompts, use_model) > max_tokens: + if count_tokens(prompt[ptr:ptr+1]+changable_prompts, use_model) > max_tokens: break - changable_prompts.insert(0, self.prompt[ptr]) + changable_prompts.insert(0, prompt[ptr]) ptr -= 1 # 将default_prompt和changable_prompts合并 - result_prompt = self.default_prompt + changable_prompts + result_prompt = default_prompt + changable_prompts # 添加当前问题 if msg: diff --git a/pkg/plugin/host.py b/pkg/plugin/host.py index 12269bf4..7c8ed7d5 100644 --- a/pkg/plugin/host.py +++ b/pkg/plugin/host.py @@ -266,7 +266,7 @@ class EventContext: self.__return_value__[key] = [] self.__return_value__[key].append(ret) - def get_return(self, key: str): + def get_return(self, key: str) -> list: """获取key的所有返回值""" if key in self.__return_value__: return self.__return_value__[key] diff --git a/pkg/plugin/models.py b/pkg/plugin/models.py index 40756757..c0f2ca8b 100644 --- a/pkg/plugin/models.py +++ b/pkg/plugin/models.py @@ -132,6 +132,20 @@ KeySwitched = "key_switched" key_list: list[str] api-key列表 """ +PromptPreProcessing = "prompt_pre_processing" +"""每回合调用接口前对prompt进行预处理时触发,此事件不支持阻止默认行为 + kwargs: + session_name: str 会话名称(_) + default_prompt: list 此session使用的情景预设内容 + prompt: list 此session现有的prompt内容 + text_message: str 用户发送的消息文本 + + returns (optional): + default_prompt: list 修改后的情景预设内容 + prompt: list 修改后的prompt内容 + text_message: str 修改后的消息文本 +""" + def on(*args, **kwargs): """注册事件监听器