From af8c21f3d4ee42be074b08ca791f8fb34618ebfe Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Thu, 21 Dec 2023 18:19:04 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=8C=E5=96=84=20=E6=8F=92=E4=BB=B6?= =?UTF-8?q?=E4=BA=8B=E4=BB=B6=E8=B0=83=E7=94=A8=E6=8A=A5=E5=91=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/plugin/host.py | 191 +++++++++++++++++++++++++++------------------ 1 file changed, 113 insertions(+), 78 deletions(-) diff --git a/pkg/plugin/host.py b/pkg/plugin/host.py index d65a0916..631012a2 100644 --- a/pkg/plugin/host.py +++ b/pkg/plugin/host.py @@ -84,23 +84,34 @@ def iter_plugins_name(): __current_module_path__ = "" -def walk_plugin_path(module, prefix='', path_prefix=''): +def walk_plugin_path(module, prefix="", path_prefix=""): global __current_module_path__ """遍历插件路径""" for item in pkgutil.iter_modules(module.__path__): if item.ispkg: logging.debug("扫描插件包: plugins/{}".format(path_prefix + item.name)) - walk_plugin_path(__import__(module.__name__ + '.' + item.name, fromlist=['']), - prefix + item.name + '.', path_prefix + item.name + '/') + walk_plugin_path( + __import__(module.__name__ + "." + item.name, fromlist=[""]), + prefix + item.name + ".", + path_prefix + item.name + "/", + ) else: try: - logging.debug("扫描插件模块: plugins/{}".format(path_prefix + item.name + '.py')) - __current_module_path__ = "plugins/"+path_prefix + item.name + '.py' + logging.debug( + "扫描插件模块: plugins/{}".format(path_prefix + item.name + ".py") + ) + __current_module_path__ = "plugins/" + path_prefix + item.name + ".py" - importlib.import_module(module.__name__ + '.' + item.name) - logging.debug('加载模块: plugins/{} 成功'.format(path_prefix + item.name + '.py')) + importlib.import_module(module.__name__ + "." + item.name) + logging.debug( + "加载模块: plugins/{} 成功".format(path_prefix + item.name + ".py") + ) except: - logging.error('加载模块: plugins/{} 失败: {}'.format(path_prefix + item.name + '.py', sys.exc_info())) + logging.error( + "加载模块: plugins/{} 失败: {}".format( + path_prefix + item.name + ".py", sys.exc_info() + ) + ) traceback.print_exc() @@ -108,7 +119,7 @@ def load_plugins(): """加载插件""" logging.debug("加载插件") PluginHost() - walk_plugin_path(__import__('plugins')) + walk_plugin_path(__import__("plugins")) logging.debug(__plugins__) @@ -141,14 +152,14 @@ def initialize_plugins(): # if not plugin['enabled']: # continue try: - models.__current_registering_plugin__ = plugin['name'] - plugin['instance'] = plugin["class"](plugin_host=context.get_plugin_host()) + models.__current_registering_plugin__ = plugin["name"] + plugin["instance"] = plugin["class"](plugin_host=context.get_plugin_host()) # logging.info("插件 {} 已初始化".format(plugin['name'])) - successfully_initialized_plugins.append(plugin['name']) + successfully_initialized_plugins.append(plugin["name"]) except: - logging.error("插件{}初始化时发生错误: {}".format(plugin['name'], sys.exc_info())) + logging.error("插件{}初始化时发生错误: {}".format(plugin["name"], sys.exc_info())) logging.debug(traceback.format_exc()) - + logging.info("以下插件已初始化: {}".format(", ".join(successfully_initialized_plugins))) @@ -172,9 +183,12 @@ def get_github_plugin_repo_label(repo_url: str) -> list[str]: """获取username, repo""" # 提取 username/repo , 正则表达式 - repo = re.findall(r'(?:https?://github\.com/|git@github\.com:)([^/]+/[^/]+?)(?:\.git|/|$)', repo_url) + repo = re.findall( + r"(?:https?://github\.com/|git@github\.com:)([^/]+/[^/]+?)(?:\.git|/|$)", + repo_url, + ) - if len(repo) > 0: # github + if len(repo) > 0: # github return repo[0].split("/") else: return None @@ -183,53 +197,52 @@ def get_github_plugin_repo_label(repo_url: str) -> list[str]: def download_plugin_source_code(repo_url: str, target_path: str) -> str: """下载插件源码""" # 检查源类型 - + # 提取 username/repo , 正则表达式 - repo = get_github_plugin_repo_label(repo_url) + repo = get_github_plugin_repo_label(repo_url) target_path += repo[1] - if repo is not None: # github + if repo is not None: # github logging.info("从 GitHub 下载插件源码...") zipball_url = f"https://api.github.com/repos/{'/'.join(repo)}/zipball/HEAD" zip_resp = requests.get( - url=zipball_url, - proxies=network.wrapper_proxies(), - stream=True + url=zipball_url, proxies=network.wrapper_proxies(), stream=True ) if zip_resp.status_code != 200: raise Exception("下载源码失败: {}".format(zip_resp.text)) - - if os.path.exists("temp/"+target_path): - shutil.rmtree("temp/"+target_path) + + if os.path.exists("temp/" + target_path): + shutil.rmtree("temp/" + target_path) if os.path.exists(target_path): shutil.rmtree(target_path) - os.makedirs("temp/"+target_path) + os.makedirs("temp/" + target_path) - with open("temp/"+target_path+"/source.zip", "wb") as f: + with open("temp/" + target_path + "/source.zip", "wb") as f: for chunk in zip_resp.iter_content(chunk_size=1024): if chunk: f.write(chunk) logging.info("下载完成, 解压...") import zipfile - with zipfile.ZipFile("temp/"+target_path+"/source.zip", 'r') as zip_ref: - zip_ref.extractall("temp/"+target_path) - os.remove("temp/"+target_path+"/source.zip") + + with zipfile.ZipFile("temp/" + target_path + "/source.zip", "r") as zip_ref: + zip_ref.extractall("temp/" + target_path) + os.remove("temp/" + target_path + "/source.zip") # 目标是 username-repo-hash , 用正则表达式提取完整的文件夹名,复制到 plugins/repo import glob # 获取解压后的文件夹名 - unzip_dir = glob.glob("temp/"+target_path+"/*")[0] + unzip_dir = glob.glob("temp/" + target_path + "/*")[0] # 复制到 plugins/repo - shutil.copytree(unzip_dir, target_path+"/") + shutil.copytree(unzip_dir, target_path + "/") # 删除解压后的文件夹 shutil.rmtree(unzip_dir) @@ -237,18 +250,20 @@ def download_plugin_source_code(repo_url: str, target_path: str) -> str: logging.info("解压完成") else: raise Exception("暂不支持的源类型,请使用 GitHub 仓库发行插件。") - + return repo[1] def check_requirements(path: str): # 检查此目录是否包含requirements.txt - if os.path.exists(path+"/requirements.txt"): + if os.path.exists(path + "/requirements.txt"): logging.info("检测到requirements.txt,正在安装依赖") import pkg.utils.pkgmgr - pkg.utils.pkgmgr.install_requirements(path+"/requirements.txt") + + pkg.utils.pkgmgr.install_requirements(path + "/requirements.txt") import pkg.utils.log as log + log.reset_logging() @@ -257,7 +272,7 @@ def install_plugin(repo_url: str): repo_label = download_plugin_source_code(repo_url, "plugins/") - check_requirements("plugins/"+repo_label) + check_requirements("plugins/" + repo_label) metadata.set_plugin_metadata(repo_label, repo_url, int(time.time()), "HEAD") @@ -266,16 +281,16 @@ def uninstall_plugin(plugin_name: str) -> str: """卸载插件""" if plugin_name not in __plugins__: raise Exception("插件不存在") - + # 获取文件夹路径 - plugin_path = __plugins__[plugin_name]['path'].replace("\\", "/") + plugin_path = __plugins__[plugin_name]["path"].replace("\\", "/") # 剪切路径为plugins/插件名 plugin_path = plugin_path.split("plugins/")[1].split("/")[0] # 删除文件夹 - shutil.rmtree("plugins/"+plugin_path) - return "plugins/"+plugin_path + shutil.rmtree("plugins/" + plugin_path) + return "plugins/" + plugin_path def update_plugin(plugin_name: str): @@ -288,11 +303,17 @@ def update_plugin(plugin_name: str): if meta == {}: raise Exception("没有此插件元数据信息,无法更新") - remote_url = meta['source'] - if remote_url == "https://github.com/RockChinQ/QChatGPT" or remote_url == "https://gitee.com/RockChin/QChatGPT" \ - or remote_url == "" or remote_url is None or remote_url == "http://github.com/RockChinQ/QChatGPT" or remote_url == "http://gitee.com/RockChin/QChatGPT": + remote_url = meta["source"] + if ( + remote_url == "https://github.com/RockChinQ/QChatGPT" + or remote_url == "https://gitee.com/RockChin/QChatGPT" + or remote_url == "" + or remote_url is None + or remote_url == "http://github.com/RockChinQ/QChatGPT" + or remote_url == "http://gitee.com/RockChin/QChatGPT" + ): raise Exception("插件没有远程地址记录,无法更新") - + # 重新安装插件 logging.info("正在重新安装插件以进行更新...") @@ -301,7 +322,7 @@ def update_plugin(plugin_name: str): def get_plugin_name_by_path_name(plugin_path_name: str) -> str: for k, v in __plugins__.items(): - if v['path'] == "plugins/"+plugin_path_name+"/main.py": + if v["path"] == "plugins/" + plugin_path_name + "/main.py": return k return None @@ -309,8 +330,8 @@ def get_plugin_name_by_path_name(plugin_path_name: str) -> str: def get_plugin_path_name_by_plugin_name(plugin_name: str) -> str: if plugin_name not in __plugins__: return None - - plugin_main_module_path = __plugins__[plugin_name]['path'] + + plugin_main_module_path = __plugins__[plugin_name]["path"] plugin_main_module_path = plugin_main_module_path.replace("\\", "/") @@ -319,8 +340,29 @@ def get_plugin_path_name_by_plugin_name(plugin_name: str) -> str: return spt[1] +def get_plugin_info_for_audit(plugin_name: str) -> dict: + """获取插件信息""" + if plugin_name not in __plugins__: + return {} + plugin = __plugins__[plugin_name] + + name = plugin["name"] + meta = metadata.get_plugin_metadata(get_plugin_path_name_by_plugin_name(name)) + remote = meta["source"] if meta != {} else "" + author = plugin["author"] + version = plugin["version"] + + return { + "name": name, + "remote": remote, + "author": author, + "version": version, + } + + class EventContext: """事件上下文""" + eid = 0 """事件编号""" @@ -395,6 +437,7 @@ class EventContext: def emit(event_name: str, **kwargs) -> EventContext: """触发事件""" import pkg.utils.context as context + if context.get_plugin_host() is None: return None return context.get_plugin_host().emit(event_name, **kwargs) @@ -446,8 +489,7 @@ class PluginHost: emitted_plugins = [] for plugin in iter_plugins(): - - if not plugin['enabled']: + if not plugin["enabled"]: continue # if plugin['instance'] is None: @@ -459,10 +501,10 @@ class PluginHost: # logging.error("插件 {} 初始化时发生错误: {}".format(plugin['name'], sys.exc_info())) # continue - if 'hooks' not in plugin or event_name not in plugin['hooks']: + if "hooks" not in plugin or event_name not in plugin["hooks"]: continue - emitted_plugins.append(plugin) + emitted_plugins.append(plugin['name']) hooks = [] if event_name in plugin["hooks"]: @@ -471,45 +513,37 @@ class PluginHost: try: already_prevented_default = event_context.is_prevented_default() - kwargs['host'] = context.get_plugin_host() - kwargs['event'] = event_context + kwargs["host"] = context.get_plugin_host() + kwargs["event"] = event_context - hook(plugin['instance'], **kwargs) + hook(plugin["instance"], **kwargs) - if event_context.is_prevented_default() and not already_prevented_default: - logging.debug("插件 {} 已要求阻止事件 {} 的默认行为".format(plugin['name'], event_name)) + if ( + event_context.is_prevented_default() + and not already_prevented_default + ): + logging.debug( + "插件 {} 已要求阻止事件 {} 的默认行为".format(plugin["name"], event_name) + ) except Exception as e: - logging.error("插件{}响应事件{}时发生错误".format(plugin['name'], event_name)) + logging.error("插件{}响应事件{}时发生错误".format(plugin["name"], event_name)) logging.error(traceback.format_exc()) # print("done:{}".format(plugin['name'])) if event_context.is_prevented_postorder(): - logging.debug("插件 {} 阻止了后序插件的执行".format(plugin['name'])) + logging.debug("插件 {} 阻止了后序插件的执行".format(plugin["name"])) break - logging.debug("事件 {} ({}) 处理完毕,返回值: {}".format(event_name, event_context.eid, - event_context.__return_value__)) + logging.debug( + "事件 {} ({}) 处理完毕,返回值: {}".format( + event_name, event_context.eid, event_context.__return_value__ + ) + ) + print(emitted_plugins) if len(emitted_plugins) > 0: - - plugins_info = [] - - for plugin in emitted_plugins: - name = plugin['name'] - meta = metadata.get_plugin_metadata(get_plugin_path_name_by_plugin_name(name)) - remote = meta['source'] if meta != {} else "" - author = plugin['author'] - version = plugin['version'] - - plugins_info.append( - { - "name": name, - "remote": remote, - "author": author, - "version": version, - } - ) + plugins_info = [get_plugin_info_for_audit(p) for p in emitted_plugins] context.get_center_v2_api().usage.post_event_record( plugins=plugins_info, @@ -518,5 +552,6 @@ class PluginHost: return event_context + if __name__ == "__main__": pass