From 658eb278c445bc7fb088819503b504e368489a85 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Sat, 16 Nov 2024 16:13:02 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84=E9=83=A8?= =?UTF-8?q?=E5=88=86=E6=8F=92=E4=BB=B6=E7=AE=A1=E7=90=86=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/api/http/controller/groups/plugins.py | 4 +- pkg/command/operators/func.py | 11 +- pkg/command/operators/plugin.py | 8 +- pkg/core/app.py | 4 +- pkg/pipeline/controller.py | 18 ++- pkg/platform/manager.py | 78 ++++-------- pkg/plugin/context.py | 29 ++++- pkg/plugin/loader.py | 5 +- pkg/plugin/loaders/classic.py | 23 ++-- pkg/plugin/manager.py | 138 +++++++++++++++------- pkg/provider/session/sessionmgr.py | 6 +- pkg/provider/tools/entities.py | 2 - pkg/provider/tools/toolmgr.py | 51 ++++---- 13 files changed, 219 insertions(+), 158 deletions(-) diff --git a/pkg/api/http/controller/groups/plugins.py b/pkg/api/http/controller/groups/plugins.py index ef30f3a9..86ac7ec6 100644 --- a/pkg/api/http/controller/groups/plugins.py +++ b/pkg/api/http/controller/groups/plugins.py @@ -15,7 +15,7 @@ class PluginsRouterGroup(group.RouterGroup): async def initialize(self) -> None: @self.route('', methods=['GET']) async def _() -> str: - plugins = self.ap.plugin_mgr.plugins + plugins = self.ap.plugin_mgr.plugins() plugins_data = [plugin.model_dump() for plugin in plugins] @@ -27,7 +27,7 @@ class PluginsRouterGroup(group.RouterGroup): async def _(author: str, plugin_name: str) -> str: data = await quart.request.json target_enabled = data.get('target_enabled') - await self.ap.plugin_mgr.update_plugin_status(plugin_name, target_enabled) + await self.ap.plugin_mgr.update_plugin_switch(plugin_name, target_enabled) return self.success() @self.route('///update', methods=['POST']) diff --git a/pkg/command/operators/func.py b/pkg/command/operators/func.py index 33031bfb..404813eb 100644 --- a/pkg/command/operators/func.py +++ b/pkg/command/operators/func.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import AsyncGenerator from .. import operator, entities, cmdmgr +from ...plugin import context as plugin_context @operator.operator_class(name="func", help="查看所有已注册的内容函数", usage='!func') @@ -9,16 +10,18 @@ class FuncOperator(operator.CommandOperator): async def execute( self, context: entities.ExecuteContext ) -> AsyncGenerator[entities.CommandReturn, None]: - reply_str = "当前已加载的内容函数: \n\n" + reply_str = "当前已启用的内容函数: \n\n" index = 1 - all_functions = await self.ap.tool_mgr.get_all_functions() + all_functions = await self.ap.tool_mgr.get_all_functions( + plugin_enabled=True, + plugin_status=plugin_context.RuntimeContainerStatus.INITIALIZED, + ) for func in all_functions: - reply_str += "{}. {}{}:\n{}\n\n".format( + reply_str += "{}. {}:\n{}\n\n".format( index, - ("(已禁用) " if not func.enable else ""), func.name, func.description, ) diff --git a/pkg/command/operators/plugin.py b/pkg/command/operators/plugin.py index 715c7726..e50d0ba2 100644 --- a/pkg/command/operators/plugin.py +++ b/pkg/command/operators/plugin.py @@ -18,7 +18,7 @@ class PluginOperator(operator.CommandOperator): context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - plugin_list = self.ap.plugin_mgr.plugins + plugin_list = self.ap.plugin_mgr.plugins() reply_str = "所有插件({}):\n".format(len(plugin_list)) idx = 0 for plugin in plugin_list: @@ -110,7 +110,7 @@ class PluginUpdateAllOperator(operator.CommandOperator): try: plugins = [ p.plugin_name - for p in self.ap.plugin_mgr.plugins + for p in self.ap.plugin_mgr.plugins() ] if plugins: @@ -182,7 +182,7 @@ class PluginEnableOperator(operator.CommandOperator): plugin_name = context.crt_params[0] try: - if await self.ap.plugin_mgr.update_plugin_status(plugin_name, True): + if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, True): yield entities.CommandReturn(text="已启用插件: {}".format(plugin_name)) else: yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name))) @@ -210,7 +210,7 @@ class PluginDisableOperator(operator.CommandOperator): plugin_name = context.crt_params[0] try: - if await self.ap.plugin_mgr.update_plugin_status(plugin_name, False): + if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, False): yield entities.CommandReturn(text="已禁用插件: {}".format(plugin_name)) else: yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name))) diff --git a/pkg/core/app.py b/pkg/core/app.py index ead769f1..907acfb5 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -141,7 +141,7 @@ class Application: tips = f""" ======================================= -✨ 您可通过以下方式访问管理面板: +✨ 您可通过以下方式访问管理面板 🏠 本地地址:http://{host_ip}:{port}/ 🌐 公网地址:http://{public_ip}:{port}/ @@ -150,6 +150,8 @@ class Application: 🔗 若要使用公网地址访问,请阅读以下须知 1. 公网地址仅供参考,请以您的主机公网 IP 为准; 2. 要使用公网地址访问,请确保您的主机具有公网 IP,并且系统防火墙已放行 {port} 端口; + +🤯 WebUI 仍处于 Beta 测试阶段,如有问题或建议请反馈到 https://github.com/RockChinQ/LangBot/issues ======================================= """.strip() for line in tips.split("\n"): diff --git a/pkg/pipeline/controller.py b/pkg/pipeline/controller.py index 92ba8173..3113b3bf 100644 --- a/pkg/pipeline/controller.py +++ b/pkg/pipeline/controller.py @@ -163,6 +163,23 @@ class Controller: async def process_query(self, query: entities.Query): """处理请求 """ + + # ======== 触发 MessageReceived 事件 ======== + event_type = events.PersonMessageReceived if query.launcher_type == entities.LauncherTypes.PERSON else events.GroupMessageReceived + + event_ctx = await self.ap.plugin_mgr.emit_event( + event=event_type( + launcher_type=query.launcher_type.value, + launcher_id=query.launcher_id, + sender_id=query.sender_id, + message_chain=query.message_chain, + query=query + ) + ) + + if event_ctx.is_prevented_default(): + return + self.ap.logger.debug(f"Processing query {query}") try: @@ -170,7 +187,6 @@ class Controller: except Exception as e: self.ap.logger.error(f"处理请求时出错 query_id={query.query_id} stage={query.current_stage.inst_name} : {e}") self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") - # traceback.print_exc() finally: self.ap.logger.debug(f"Query {query} processed") diff --git a/pkg/platform/manager.py b/pkg/platform/manager.py index b33c1e55..6a391a1f 100644 --- a/pkg/platform/manager.py +++ b/pkg/platform/manager.py @@ -41,72 +41,36 @@ class PlatformManager: async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessageSourceAdapter): - event_ctx = await self.ap.plugin_mgr.emit_event( - event=events.PersonMessageReceived( - launcher_type='person', - launcher_id=event.sender.id, - sender_id=event.sender.id, - message_chain=event.message_chain, - query=None - ) + await self.ap.query_pool.add_query( + launcher_type=core_entities.LauncherTypes.PERSON, + launcher_id=event.sender.id, + sender_id=event.sender.id, + message_event=event, + message_chain=event.message_chain, + adapter=adapter ) - if not event_ctx.is_prevented_default(): - - await self.ap.query_pool.add_query( - launcher_type=core_entities.LauncherTypes.PERSON, - launcher_id=event.sender.id, - sender_id=event.sender.id, - message_event=event, - message_chain=event.message_chain, - adapter=adapter - ) - async def on_stranger_message(event: platform_events.StrangerMessage, adapter: msadapter.MessageSourceAdapter): - event_ctx = await self.ap.plugin_mgr.emit_event( - event=events.PersonMessageReceived( - launcher_type='person', - launcher_id=event.sender.id, - sender_id=event.sender.id, - message_chain=event.message_chain, - query=None - ) + await self.ap.query_pool.add_query( + launcher_type=core_entities.LauncherTypes.PERSON, + launcher_id=event.sender.id, + sender_id=event.sender.id, + message_event=event, + message_chain=event.message_chain, + adapter=adapter ) - if not event_ctx.is_prevented_default(): - - await self.ap.query_pool.add_query( - launcher_type=core_entities.LauncherTypes.PERSON, - launcher_id=event.sender.id, - sender_id=event.sender.id, - message_event=event, - message_chain=event.message_chain, - adapter=adapter - ) - async def on_group_message(event: platform_events.GroupMessage, adapter: msadapter.MessageSourceAdapter): - event_ctx = await self.ap.plugin_mgr.emit_event( - event=events.GroupMessageReceived( - launcher_type='group', - launcher_id=event.group.id, - sender_id=event.sender.id, - message_chain=event.message_chain, - query=None - ) + await self.ap.query_pool.add_query( + launcher_type=core_entities.LauncherTypes.GROUP, + launcher_id=event.group.id, + sender_id=event.sender.id, + message_event=event, + message_chain=event.message_chain, + adapter=adapter ) - - if not event_ctx.is_prevented_default(): - - await self.ap.query_pool.add_query( - launcher_type=core_entities.LauncherTypes.GROUP, - launcher_id=event.group.id, - sender_id=event.sender.id, - message_event=event, - message_chain=event.message_chain, - adapter=adapter - ) index = 0 diff --git a/pkg/plugin/context.py b/pkg/plugin/context.py index 46ffb4ac..d4cd58d8 100644 --- a/pkg/plugin/context.py +++ b/pkg/plugin/context.py @@ -3,6 +3,7 @@ from __future__ import annotations import typing import abc import pydantic +import enum from . import events from ..provider.tools import entities as tools_entities @@ -85,10 +86,18 @@ class BasePlugin(metaclass=abc.ABCMeta): """应用程序对象""" def __init__(self, host: APIHost): + """初始化阶段被调用""" self.host = host async def initialize(self): - """初始化插件""" + """初始化阶段被调用""" + pass + + async def destroy(self): + """释放/禁用插件时被调用""" + pass + + def __del__(self): pass @@ -247,6 +256,16 @@ class EventContext: EventContext.eid += 1 +class RuntimeContainerStatus(enum.Enum): + """插件容器状态""" + + MOUNTED = "mounted" + """已加载进内存,所有位于运行时记录中的 RuntimeContainer 至少是这个状态""" + + INITIALIZED = "initialized" + """已初始化""" + + class RuntimeContainer(pydantic.BaseModel): """运行时的插件容器 @@ -294,6 +313,9 @@ class RuntimeContainer(pydantic.BaseModel): content_functions: list[tools_entities.LLMFunction] = [] """内容函数""" + status: RuntimeContainerStatus = RuntimeContainerStatus.MOUNTED + """插件状态""" + class Config: arbitrary_types_allowed = True @@ -318,9 +340,6 @@ class RuntimeContainer(pydantic.BaseModel): self.priority = setting['priority'] self.enabled = setting['enabled'] - for function in self.content_functions: - function.enable = self.enabled - def model_dump(self, *args, **kwargs): return { 'name': self.plugin_name, @@ -342,9 +361,9 @@ class RuntimeContainer(pydantic.BaseModel): 'human_desc': function.human_desc, 'description': function.description, 'parameters': function.parameters, - 'enable': function.enable, 'func': function.func.__name__, } for function in self.content_functions ], + 'status': self.status.value, } diff --git a/pkg/plugin/loader.py b/pkg/plugin/loader.py index d5f4a20c..44ded4ac 100644 --- a/pkg/plugin/loader.py +++ b/pkg/plugin/loader.py @@ -13,13 +13,16 @@ class PluginLoader(metaclass=abc.ABCMeta): ap: app.Application + plugins: list[context.RuntimeContainer] + def __init__(self, ap: app.Application): self.ap = ap + self.plugins = [] async def initialize(self): pass @abc.abstractmethod - async def load_plugins(self) -> list[context.RuntimeContainer]: + async def load_plugins(self): pass diff --git a/pkg/plugin/loaders/classic.py b/pkg/plugin/loaders/classic.py index b2553c1a..b3710c9e 100644 --- a/pkg/plugin/loaders/classic.py +++ b/pkg/plugin/loaders/classic.py @@ -20,7 +20,14 @@ class PluginLoader(loader.PluginLoader): _current_container: context.RuntimeContainer = None - containers: list[context.RuntimeContainer] = [] + plugins: list[context.RuntimeContainer] = [] + + def __init__(self, ap): + self.ap = ap + self.plugins = [] + self._current_pkg_path = '' + self._current_module_path = '' + self._current_container = None async def initialize(self): """初始化""" @@ -77,8 +84,10 @@ class PluginLoader(loader.PluginLoader): } # 把 ctx.event 所有的属性都放到 args 里 - for k, v in ctx.event.dict().items(): - args[k] = v + # for k, v in ctx.event.dict().items(): + # args[k] = v + for attr_name in ctx.event.__dict__.keys(): + args[attr_name] = getattr(ctx.event, attr_name) func(plugin, **args) @@ -113,7 +122,6 @@ class PluginLoader(loader.PluginLoader): name=function_name, human_desc='', description=function_schema['description'], - enable=True, parameters=function_schema['parameters'], func=handler, ) @@ -153,7 +161,6 @@ class PluginLoader(loader.PluginLoader): name=function_name, human_desc='', description=function_schema['description'], - enable=True, parameters=function_schema['parameters'], func=func, ) @@ -189,15 +196,13 @@ class PluginLoader(loader.PluginLoader): importlib.import_module(module.__name__ + "." + item.name) if self._current_container is not None: - self.containers.append(self._current_container) + self.plugins.append(self._current_container) self.ap.logger.debug(f'插件 {self._current_container} 已加载') except: self.ap.logger.error(f'加载插件模块 {prefix + item.name} 时发生错误') traceback.print_exc() - async def load_plugins(self) -> list[context.RuntimeContainer]: + async def load_plugins(self): """加载插件 """ await self._walk_plugin_path(__import__("plugins", fromlist=[""])) - - return self.containers diff --git a/pkg/plugin/manager.py b/pkg/plugin/manager.py index ea9c4371..19d0ee49 100644 --- a/pkg/plugin/manager.py +++ b/pkg/plugin/manager.py @@ -22,7 +22,22 @@ class PluginManager: api_host: context.APIHost - plugins: list[context.RuntimeContainer] + def plugins( + self, + enabled: bool=None, + status: context.RuntimeContainerStatus=None, + ) -> list[context.RuntimeContainer]: + """获取插件列表 + """ + plugins = self.loader.plugins + + if enabled is not None: + plugins = [plugin for plugin in plugins if plugin.enabled == enabled] + + if status is not None: + plugins = [plugin for plugin in plugins if plugin.status == status] + + return plugins def __init__(self, ap: app.Application): self.ap = ap @@ -30,7 +45,6 @@ class PluginManager: self.installer = github.GitHubRepoInstaller(ap) self.setting = setting.SettingManager(ap) self.api_host = context.APIHost(ap) - self.plugins = [] async def initialize(self): await self.loader.initialize() @@ -41,27 +55,57 @@ class PluginManager: setattr(models, 'require_ver', self.api_host.require_ver) async def load_plugins(self): - self.plugins = await self.loader.load_plugins() + await self.loader.load_plugins() - await self.setting.sync_setting(self.plugins) + await self.setting.sync_setting(self.loader.plugins) # 按优先级倒序 - self.plugins.sort(key=lambda x: x.priority, reverse=True) + self.loader.plugins.sort(key=lambda x: x.priority, reverse=True) - self.ap.logger.debug(f'优先级排序后的插件列表 {self.plugins}') + self.ap.logger.debug(f'优先级排序后的插件列表 {self.loader.plugins}') + + async def initialize_plugin(self, plugin: context.RuntimeContainer): + self.ap.logger.debug(f'初始化插件 {plugin.plugin_name}') + plugin.plugin_inst = plugin.plugin_class(self.api_host) + plugin.plugin_inst.ap = self.ap + plugin.plugin_inst.host = self.api_host + await plugin.plugin_inst.initialize() + plugin.status = context.RuntimeContainerStatus.INITIALIZED async def initialize_plugins(self): - for plugin in self.plugins: + for plugin in self.plugins(): + if not plugin.enabled: + self.ap.logger.debug(f'插件 {plugin.plugin_name} 未启用,跳过初始化') + continue try: - plugin.plugin_inst = plugin.plugin_class(self.api_host) - plugin.plugin_inst.ap = self.ap - plugin.plugin_inst.host = self.api_host - await plugin.plugin_inst.initialize() + await self.initialize_plugin(plugin) except Exception as e: self.ap.logger.error(f'插件 {plugin.plugin_name} 初始化失败: {e}') self.ap.logger.exception(e) continue + async def destroy_plugin(self, plugin: context.RuntimeContainer): + if plugin.status != context.RuntimeContainerStatus.INITIALIZED: + return + + self.ap.logger.debug(f'释放插件 {plugin.plugin_name}') + await plugin.plugin_inst.destroy() + plugin.plugin_inst = None + plugin.status = context.RuntimeContainerStatus.MOUNTED + + async def destroy_plugins(self): + for plugin in self.plugins(): + if plugin.status != context.RuntimeContainerStatus.INITIALIZED: + self.ap.logger.debug(f'插件 {plugin.plugin_name} 未初始化,跳过释放') + continue + + try: + await self.destroy_plugin(plugin) + except Exception as e: + self.ap.logger.error(f'插件 {plugin.plugin_name} 释放失败: {e}') + self.ap.logger.exception(e) + continue + async def install_plugin( self, plugin_source: str, @@ -127,7 +171,7 @@ class PluginManager: def get_plugin_by_name(self, plugin_name: str) -> context.RuntimeContainer: """通过插件名获取插件 """ - for plugin in self.plugins: + for plugin in self.plugins(): if plugin.plugin_name == plugin_name: return plugin return None @@ -143,30 +187,32 @@ class PluginManager: emitted_plugins: list[context.RuntimeContainer] = [] - for plugin in self.plugins: - if plugin.enabled: - if event.__class__ in plugin.event_handlers: - self.ap.logger.debug(f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__}') - - is_prevented_default_before_call = ctx.is_prevented_default() + for plugin in self.plugins( + enabled=True, + status=context.RuntimeContainerStatus.INITIALIZED + ): + if event.__class__ in plugin.event_handlers: + self.ap.logger.debug(f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__}') + + is_prevented_default_before_call = ctx.is_prevented_default() - try: - await plugin.event_handlers[event.__class__]( - plugin.plugin_inst, - ctx - ) - except Exception as e: - self.ap.logger.error(f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__} 时发生错误: {e}') - self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") - - emitted_plugins.append(plugin) + try: + await plugin.event_handlers[event.__class__]( + plugin.plugin_inst, + ctx + ) + except Exception as e: + self.ap.logger.error(f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__} 时发生错误: {e}') + self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") + + emitted_plugins.append(plugin) - if not is_prevented_default_before_call and ctx.is_prevented_default(): - self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了默认行为执行') + if not is_prevented_default_before_call and ctx.is_prevented_default(): + self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了默认行为执行') - if ctx.is_prevented_postorder(): - self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了后序插件的执行') - break + if ctx.is_prevented_postorder(): + self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了后序插件的执行') + break for key in ctx.__return_value__.keys(): if hasattr(ctx.event, key): @@ -191,16 +237,22 @@ class PluginManager: return ctx - async def update_plugin_status(self, plugin_name: str, new_status: bool): + async def update_plugin_switch(self, plugin_name: str, new_status: bool): if self.get_plugin_by_name(plugin_name) is not None: - for plugin in self.plugins: + for plugin in self.plugins(): if plugin.plugin_name == plugin_name: - plugin.enabled = new_status - - for func in plugin.content_functions: - func.enable = new_status + if plugin.enabled == new_status: + return False - await self.setting.dump_container_setting(self.plugins) + # 初始化/释放插件 + if new_status: + await self.initialize_plugin(plugin) + else: + await self.destroy_plugin(plugin) + + plugin.enabled = new_status + + await self.setting.dump_container_setting(self.loader.plugins) break @@ -214,11 +266,11 @@ class PluginManager: plugin_name = plugin.get('name') plugin_priority = plugin.get('priority') - for plugin in self.plugins: + for plugin in self.loader.plugins: if plugin.plugin_name == plugin_name: plugin.priority = plugin_priority break - self.plugins.sort(key=lambda x: x.priority, reverse=True) + self.loader.plugins.sort(key=lambda x: x.priority, reverse=True) - await self.setting.dump_container_setting(self.plugins) + await self.setting.dump_container_setting(self.loader.plugins) diff --git a/pkg/provider/session/sessionmgr.py b/pkg/provider/session/sessionmgr.py index 9130cbe2..f328ffff 100644 --- a/pkg/provider/session/sessionmgr.py +++ b/pkg/provider/session/sessionmgr.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio from ...core import app, entities as core_entities +from ...plugin import context as plugin_context class SessionManager: @@ -51,7 +52,10 @@ class SessionManager: prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name), messages=[], use_model=await self.ap.model_mgr.get_model_by_name(self.ap.provider_cfg.data['model']), - use_funcs=await self.ap.tool_mgr.get_all_functions(), + use_funcs=await self.ap.tool_mgr.get_all_functions( + plugin_enabled=True, + plugin_status=plugin_context.RuntimeContainerStatus.INITIALIZED, + ), ) session.conversations.append(conversation) session.using_conversation = conversation diff --git a/pkg/provider/tools/entities.py b/pkg/provider/tools/entities.py index 52867291..8f09ab21 100644 --- a/pkg/provider/tools/entities.py +++ b/pkg/provider/tools/entities.py @@ -20,8 +20,6 @@ class LLMFunction(pydantic.BaseModel): description: str """给LLM识别的函数描述""" - enable: typing.Optional[bool] = True - parameters: dict func: typing.Callable diff --git a/pkg/provider/tools/toolmgr.py b/pkg/provider/tools/toolmgr.py index 5e780c50..3e412eaf 100644 --- a/pkg/provider/tools/toolmgr.py +++ b/pkg/provider/tools/toolmgr.py @@ -20,28 +20,25 @@ class ToolManager: async def initialize(self): pass - async def get_function(self, name: str) -> entities.LLMFunction: - """获取函数""" - for function in await self.get_all_functions(): - if function.name == name: - return function - return None - async def get_function_and_plugin( self, name: str ) -> typing.Tuple[entities.LLMFunction, plugin_context.BasePlugin]: - """获取函数和插件""" - for plugin in self.ap.plugin_mgr.plugins: + """获取函数和插件实例""" + for plugin in self.ap.plugin_mgr.plugins( + enabled=True, status=plugin_context.RuntimeContainerStatus.INITIALIZED + ): for function in plugin.content_functions: if function.name == name: return function, plugin.plugin_inst return None, None - async def get_all_functions(self) -> list[entities.LLMFunction]: + async def get_all_functions(self, plugin_enabled: bool=None, plugin_status: plugin_context.RuntimeContainerStatus=None) -> list[entities.LLMFunction]: """获取所有函数""" all_functions: list[entities.LLMFunction] = [] - for plugin in self.ap.plugin_mgr.plugins: + for plugin in self.ap.plugin_mgr.plugins( + enabled=plugin_enabled, status=plugin_status + ): all_functions.extend(plugin.content_functions) return all_functions @@ -51,16 +48,15 @@ class ToolManager: tools = [] for function in use_funcs: - if function.enable: - function_schema = { - "type": "function", - "function": { - "name": function.name, - "description": function.description, - "parameters": function.parameters, - }, - } - tools.append(function_schema) + function_schema = { + "type": "function", + "function": { + "name": function.name, + "description": function.description, + "parameters": function.parameters, + }, + } + tools.append(function_schema) return tools @@ -92,13 +88,12 @@ class ToolManager: tools = [] for function in use_funcs: - if function.enable: - function_schema = { - "name": function.name, - "description": function.description, - "input_schema": function.parameters, - } - tools.append(function_schema) + function_schema = { + "name": function.name, + "description": function.description, + "input_schema": function.parameters, + } + tools.append(function_schema) return tools