From 97603e8441c723b23df70d66e87c3bd55169552c Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Wed, 19 Mar 2025 09:36:03 +0800 Subject: [PATCH] feat: tool loader abstraction --- pkg/command/operators/func.py | 1 - pkg/provider/session/sessionmgr.py | 1 - pkg/provider/tools/loader.py | 16 ++++- pkg/provider/tools/loaders/plugin.py | 88 ++++++++++++++++++++++++++++ pkg/provider/tools/toolmgr.py | 70 ++++++---------------- 5 files changed, 118 insertions(+), 58 deletions(-) create mode 100644 pkg/provider/tools/loaders/plugin.py diff --git a/pkg/command/operators/func.py b/pkg/command/operators/func.py index 404813eb..ae2ba4c1 100644 --- a/pkg/command/operators/func.py +++ b/pkg/command/operators/func.py @@ -16,7 +16,6 @@ class FuncOperator(operator.CommandOperator): all_functions = await self.ap.tool_mgr.get_all_functions( plugin_enabled=True, - plugin_status=plugin_context.RuntimeContainerStatus.INITIALIZED, ) for func in all_functions: diff --git a/pkg/provider/session/sessionmgr.py b/pkg/provider/session/sessionmgr.py index f328ffff..00523472 100644 --- a/pkg/provider/session/sessionmgr.py +++ b/pkg/provider/session/sessionmgr.py @@ -54,7 +54,6 @@ class SessionManager: use_model=await self.ap.model_mgr.get_model_by_name(self.ap.provider_cfg.data['model']), use_funcs=await self.ap.tool_mgr.get_all_functions( plugin_enabled=True, - plugin_status=plugin_context.RuntimeContainerStatus.INITIALIZED, ), ) session.conversations.append(conversation) diff --git a/pkg/provider/tools/loader.py b/pkg/provider/tools/loader.py index 94f90a2c..82e6440d 100644 --- a/pkg/provider/tools/loader.py +++ b/pkg/provider/tools/loader.py @@ -3,11 +3,11 @@ from __future__ import annotations import abc import typing -from ...core import app +from ...core import app, entities as core_entities from . import entities as tools_entities -preregistered_loaders = [] +preregistered_loaders: list[typing.Type[ToolLoader]] = [] def loader_class(name: str): """注册一个工具加载器 @@ -34,6 +34,16 @@ class ToolLoader(abc.ABC): pass @abc.abstractmethod - def get_tools(self) -> list[tools_entities.LLMFunction]: + async def get_tools(self, enabled: bool=True) -> list[tools_entities.LLMFunction]: """获取所有工具""" pass + + @abc.abstractmethod + async def has_tool(self, name: str) -> bool: + """检查工具是否存在""" + pass + + @abc.abstractmethod + async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any: + """执行工具调用""" + pass \ No newline at end of file diff --git a/pkg/provider/tools/loaders/plugin.py b/pkg/provider/tools/loaders/plugin.py new file mode 100644 index 00000000..da0bc555 --- /dev/null +++ b/pkg/provider/tools/loaders/plugin.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import typing +import traceback + +from .. import loader, entities as tools_entities +from ....core import app, entities as core_entities +from ....plugin import context as plugin_context + + +@loader.loader_class("plugin-tool-loader") +class PluginToolLoader(loader.ToolLoader): + """插件工具加载器。 + + 本加载器中不存储工具信息,仅负责从插件系统中获取工具信息。 + """ + + async def get_tools(self, enabled: bool=True) -> list[tools_entities.LLMFunction]: + + # 从插件系统获取工具(内容函数) + all_functions: list[tools_entities.LLMFunction] = [] + + for plugin in self.ap.plugin_mgr.plugins( + enabled=enabled, status=plugin_context.RuntimeContainerStatus.INITIALIZED + ): + all_functions.extend(plugin.content_functions) + + return all_functions + + async def has_tool(self, name: str) -> bool: + """检查工具是否存在""" + for plugin in self.ap.plugin_mgr.plugins( + enabled=True, status=plugin_context.RuntimeContainerStatus.INITIALIZED + ): + for function in plugin.content_functions: + if function.name == name: + return True + return False + + async def _get_function_and_plugin( + self, name: str + ) -> typing.Tuple[tools_entities.LLMFunction, plugin_context.BasePlugin]: + """获取函数和插件实例""" + for plugin in self.ap.plugin_mgr.plugins( + enabled=True, status=plugin_context.RuntimeContainerStatus.INITIALIZED + ): + for function in plugin.content_functions: + if function.name == name: + return function, plugin.plugin_inst + return None, None + + async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any: + + try: + + function, plugin = await self._get_function_and_plugin(name) + if function is None: + return None + + parameters = parameters.copy() + + parameters = {"query": query, **parameters} + + return await function.func(plugin, **parameters) + except Exception as e: + self.ap.logger.error(f"执行函数 {name} 时发生错误: {e}") + traceback.print_exc() + return f"error occurred when executing function {name}: {e}" + finally: + plugin = None + + for p in self.ap.plugin_mgr.plugins(): + if function in p.content_functions: + plugin = p + break + + if plugin is not None: + + await self.ap.ctr_mgr.usage.post_function_record( + plugin={ + "name": plugin.plugin_name, + "remote": plugin.plugin_source, + "version": plugin.plugin_version, + "author": plugin.plugin_author, + }, + function_name=function.name, + function_description=function.description, + ) \ No newline at end of file diff --git a/pkg/provider/tools/toolmgr.py b/pkg/provider/tools/toolmgr.py index c7bd0018..9986d3ab 100644 --- a/pkg/provider/tools/toolmgr.py +++ b/pkg/provider/tools/toolmgr.py @@ -4,8 +4,9 @@ import typing import traceback from ...core import app, entities as core_entities -from . import entities +from . import entities, loader as tools_loader from ...plugin import context as plugin_context +from .loaders import plugin class ToolManager: @@ -13,33 +14,26 @@ class ToolManager: ap: app.Application + loaders: list[tools_loader.ToolLoader] + def __init__(self, ap: app.Application): self.ap = ap self.all_functions = [] + self.loaders = [] async def initialize(self): - pass - async def get_function_and_plugin( - self, name: str - ) -> typing.Tuple[entities.LLMFunction, plugin_context.BasePlugin]: - """获取函数和插件实例""" - for plugin in self.ap.plugin_mgr.plugins( - enabled=True, status=plugin_context.RuntimeContainerStatus.INITIALIZED - ): - for function in plugin.content_functions: - if function.name == name: - return function, plugin.plugin_inst - return None, None + for loader_cls in tools_loader.preregistered_loaders: + loader_inst = loader_cls(self.ap) + await loader_inst.initialize() + self.loaders.append(loader_inst) - async def get_all_functions(self, plugin_enabled: bool=None, plugin_status: plugin_context.RuntimeContainerStatus=None) -> list[entities.LLMFunction]: + async def get_all_functions(self, plugin_enabled: bool=None) -> list[entities.LLMFunction]: """获取所有函数""" all_functions: list[entities.LLMFunction] = [] - for plugin in self.ap.plugin_mgr.plugins( - enabled=plugin_enabled, status=plugin_status - ): - all_functions.extend(plugin.content_functions) + for loader in self.loaders: + all_functions.extend(await loader.get_tools(plugin_enabled)) return all_functions @@ -102,38 +96,8 @@ class ToolManager: ) -> typing.Any: """执行函数调用""" - try: - - function, plugin = await self.get_function_and_plugin(name) - if function is None: - return None - - parameters = parameters.copy() - - parameters = {"query": query, **parameters} - - return await function.func(plugin, **parameters) - except Exception as e: - self.ap.logger.error(f"执行函数 {name} 时发生错误: {e}") - traceback.print_exc() - return f"error occurred when executing function {name}: {e}" - finally: - plugin = None - - for p in self.ap.plugin_mgr.plugins(): - if function in p.content_functions: - plugin = p - break - - if plugin is not None: - - await self.ap.ctr_mgr.usage.post_function_record( - plugin={ - "name": plugin.plugin_name, - "remote": plugin.plugin_source, - "version": plugin.plugin_version, - "author": plugin.plugin_author, - }, - function_name=function.name, - function_description=function.description, - ) \ No newline at end of file + for loader in self.loaders: + if await loader.has_tool(name): + return await loader.invoke_tool(query, name, parameters) + else: + raise ValueError(f"未找到工具: {name}") \ No newline at end of file