diff --git a/pkg/openai/session.py b/pkg/openai/session.py index ac85c408..7bb0368f 100644 --- a/pkg/openai/session.py +++ b/pkg/openai/session.py @@ -214,7 +214,29 @@ class Session: config = pkg.utils.context.get_config() max_length = config.prompt_submit_length - prompts, _ = self.cut_out(text, max_length) + local_default_prompt = self.default_prompt.copy() + local_prompt = self.prompt.copy() + + # 触发PromptPreProcessing事件 + args = { + 'session_name': self.name, + 'default_prompt': self.default_prompt, + 'prompt': self.prompt, + 'text_message': text, + } + + event = pkg.plugin.host.emit(plugin_models.PromptPreProcessing, **args) + + if event.get_return_value('default_prompt') is not None: + local_default_prompt = event.get_return_value('default_prompt') + + if event.get_return_value('prompt') is not None: + local_prompt = event.get_return_value('prompt') + + if event.get_return_value('text_message') is not None: + text = event.get_return_value('text_message') + + prompts, _ = self.cut_out(text, max_length, local_default_prompt, local_prompt) res_text = "" @@ -301,7 +323,7 @@ class Session: return question # 构建对话体 - def cut_out(self, msg: str, max_tokens: int) -> tuple[list, list]: + def cut_out(self, msg: str, max_tokens: int, default_prompt: list, prompt: list) -> tuple[list, list]: """将现有prompt进行切割处理,使得新的prompt长度不超过max_tokens :return: (新的prompt, 新的token_counts) @@ -317,19 +339,19 @@ class Session: use_model = pkg.utils.context.get_config().completion_api_params['model'] - ptr = len(self.prompt) - 1 + ptr = len(prompt) - 1 # 直接从后向前扫描拼接,不管是否是整回合 while ptr >= 0: - if count_tokens(self.prompt[ptr:ptr+1]+changable_prompts, use_model) > max_tokens: + if count_tokens(prompt[ptr:ptr+1]+changable_prompts, use_model) > max_tokens: break - changable_prompts.insert(0, self.prompt[ptr]) + changable_prompts.insert(0, prompt[ptr]) ptr -= 1 # 将default_prompt和changable_prompts合并 - result_prompt = self.default_prompt + changable_prompts + result_prompt = default_prompt + changable_prompts # 添加当前问题 if msg: diff --git a/pkg/plugin/host.py b/pkg/plugin/host.py index 12269bf4..7c8ed7d5 100644 --- a/pkg/plugin/host.py +++ b/pkg/plugin/host.py @@ -266,7 +266,7 @@ class EventContext: self.__return_value__[key] = [] self.__return_value__[key].append(ret) - def get_return(self, key: str): + def get_return(self, key: str) -> list: """获取key的所有返回值""" if key in self.__return_value__: return self.__return_value__[key] diff --git a/pkg/plugin/models.py b/pkg/plugin/models.py index 40756757..6a6ba9d3 100644 --- a/pkg/plugin/models.py +++ b/pkg/plugin/models.py @@ -132,6 +132,20 @@ KeySwitched = "key_switched" key_list: list[str] api-key列表 """ +PromptPreProcessing = "prompt_pre_processing" +"""每回合调用接口前对prompt进行预处理时触发,此事件不支持阻止默认行为 + kwargs: + session_name: str 会话名称(_) + default_prompt: list 此session使用的情景预设内容 + prompt: list 此session现有的prompt内容 + text_message: str 用户发送的消息文本 + + returns (optional): + default_prompt: list 修改后的情景预设内容 + prompt: list 修改后的prompt内容 + text_message: str 修改后的消息文本 +""" + def on(*args, **kwargs): """注册事件监听器 @@ -150,6 +164,32 @@ def func(*args, **kwargs): __current_registering_plugin__ = "" +def require_ver(ge: str, le: str="v999.9.9") -> bool: + """插件版本要求装饰器 + + Args: + ge (str): 最低版本要求 + le (str, optional): 最高版本要求 + + Returns: + bool: 是否满足要求, False时为无法获取版本号,True时为满足要求,报错为不满足要求 + """ + qchatgpt_version = "" + + from pkg.utils.updater import get_current_tag, compare_version_str + + try: + qchatgpt_version = get_current_tag() # 从updater模块获取版本号 + except: + return False + + if compare_version_str(qchatgpt_version, ge) < 0 or \ + (compare_version_str(qchatgpt_version, le) > 0): + raise Exception("QChatGPT 版本不满足要求,某些功能(可能是由插件提供的)无法正常使用。(要求版本:{}-{},但当前版本:{})".format(ge, le, qchatgpt_version)) + + return True + + class Plugin: """插件基类""" diff --git a/pkg/utils/updater.py b/pkg/utils/updater.py index 9a661e72..956a36bb 100644 --- a/pkg/utils/updater.py +++ b/pkg/utils/updater.py @@ -78,6 +78,34 @@ def get_current_tag() -> str: return current_tag +def compare_version_str(v0: str, v1: str) -> int: + """比较两个版本号""" + + # 删除版本号前的v + if v0.startswith("v"): + v0 = v0[1:] + if v1.startswith("v"): + v1 = v1[1:] + + v0:list = v0.split(".") + v1:list = v1.split(".") + + # 如果两个版本号节数不同,把短的后面用0补齐 + if len(v0) < len(v1): + v0.extend(["0"]*(len(v1)-len(v0))) + elif len(v0) > len(v1): + v1.extend(["0"]*(len(v0)-len(v1))) + + # 从高位向低位比较 + for i in range(len(v0)): + if int(v0[i]) > int(v1[i]): + return 1 + elif int(v0[i]) < int(v1[i]): + return -1 + + return 0 + + def update_all(cli: bool = False) -> bool: """检查更新并下载源码""" current_tag = get_current_tag() diff --git a/res/wiki/插件开发.md b/res/wiki/插件开发.md index 3e3f1ffa..f08e13d8 100644 --- a/res/wiki/插件开发.md +++ b/res/wiki/插件开发.md @@ -291,6 +291,21 @@ class HelloPlugin(Plugin): - 这仅仅是一个示例,需要更高效的网络访问能力支持插件,请查看[WebwlkrPlugin](https://github.com/RockChinQ/WebwlkrPlugin) +## 🔒版本要求 + +若您的插件对主程序的版本有要求,可以使用以下函数进行断言,若不符合版本,此函数将报错并打断此函数所在的流程: + +```python +require_ver("v2.5.1") # 要求最低版本为 v2.5.1 +``` + +```python +require_ver("v2.5.1", "v2.6.0") # 要求最低版本为 v2.5.1, 同时要求最高版本为 v2.6.0 +``` + +- 此函数在主程序`v2.5.1`中加入 +- 此函数声明在`pkg.plugin.models`模块中,在插件示例代码最前方已引入此模块所有内容,故可直接使用 + ## 📄API参考 ### 说明 @@ -435,6 +450,20 @@ KeySwitched = "key_switched" key_name: str 切换成功的api-key名称 key_list: list[str] api-key列表 """ + +PromptPreProcessing = "prompt_pre_processing" # 于v2.5.1加入 +"""每回合调用接口前对prompt进行预处理时触发,此事件不支持阻止默认行为 + kwargs: + session_name: str 会话名称(_) + default_prompt: list 此session使用的情景预设内容 + prompt: list 此session现有的prompt内容 + text_message: str 用户发送的消息文本 + + returns (optional): + default_prompt: list 修改后的情景预设内容 + prompt: list 修改后的prompt内容 + text_message: str 修改后的消息文本 +""" ``` ### host: PluginHost 详解