From 5b044a1917ddc4061325e859bf16478ec9920c70 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Sun, 6 Jul 2025 21:03:33 +0800 Subject: [PATCH] feat: add Tool component --- pkg/api/http/controller/groups/plugins.py | 2 +- pkg/api/http/controller/groups/system.py | 13 ++++- pkg/command/operators/func.py | 4 +- pkg/pipeline/preproc/preproc.py | 4 +- pkg/plugin/connector.py | 16 +++++- pkg/plugin/handler.py | 23 +++++++++ pkg/provider/runners/localagent.py | 2 +- pkg/provider/tools/loader.py | 5 +- pkg/provider/tools/loaders/mcp.py | 9 ++-- pkg/provider/tools/loaders/plugin.py | 63 +++++++---------------- pkg/provider/tools/toolmgr.py | 9 ++-- 11 files changed, 84 insertions(+), 66 deletions(-) diff --git a/pkg/api/http/controller/groups/plugins.py b/pkg/api/http/controller/groups/plugins.py index 4551cb07..86ad25e8 100644 --- a/pkg/api/http/controller/groups/plugins.py +++ b/pkg/api/http/controller/groups/plugins.py @@ -12,7 +12,7 @@ class PluginsRouterGroup(group.RouterGroup): async def initialize(self) -> None: @self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN) async def _() -> str: - plugins = await self.ap.plugin_connector.handler.list_plugins() + plugins = await self.ap.plugin_connector.list_plugins() return self.success(data={'plugins': plugins}) diff --git a/pkg/api/http/controller/groups/system.py b/pkg/api/http/controller/groups/system.py index 1089626d..979d60b2 100644 --- a/pkg/api/http/controller/groups/system.py +++ b/pkg/api/http/controller/groups/system.py @@ -35,7 +35,7 @@ class SystemRouterGroup(group.RouterGroup): return self.success(data=task.to_dict()) - @self.route('/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) + @self.route('/debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) async def _() -> str: if not constants.debug_mode: return self.http_status(403, 403, 'Forbidden') @@ -45,3 +45,14 @@ class SystemRouterGroup(group.RouterGroup): ap = self.ap return self.success(data=exec(py_code, {'ap': ap})) + + @self.route('/debug/tools/call', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) + async def _() -> str: + if not constants.debug_mode: + return self.http_status(403, 403, 'Forbidden') + + data = await quart.request.json + + return self.success( + data=await self.ap.tool_mgr.execute_func_call(data['tool_name'], data['tool_parameters']) + ) diff --git a/pkg/command/operators/func.py b/pkg/command/operators/func.py index 648cc5e2..48dbd316 100644 --- a/pkg/command/operators/func.py +++ b/pkg/command/operators/func.py @@ -11,9 +11,7 @@ class FuncOperator(operator.CommandOperator): index = 1 - all_functions = await self.ap.tool_mgr.get_all_functions( - plugin_enabled=True, - ) + all_functions = await self.ap.tool_mgr.get_all_tools() for func in all_functions: reply_str += '{}. {}:\n{}\n\n'.format( diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py index 5b82bdf4..b48ced64 100644 --- a/pkg/pipeline/preproc/preproc.py +++ b/pkg/pipeline/preproc/preproc.py @@ -60,9 +60,7 @@ class PreProcessor(stage.PipelineStage): query.use_funcs = [] if llm_model.model_entity.abilities.__contains__('func_call'): - query.use_funcs = await self.ap.tool_mgr.get_all_functions( - plugin_enabled=True, - ) + query.use_funcs = await self.ap.tool_mgr.get_all_tools() query.variables = { 'session_id': f'{query.session.launcher_type.value}_{query.session.launcher_id}', diff --git a/pkg/plugin/connector.py b/pkg/plugin/connector.py index 7b1bcdc2..c7496e70 100644 --- a/pkg/plugin/connector.py +++ b/pkg/plugin/connector.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio +from typing import Any import typing import os import sys @@ -11,8 +12,10 @@ from . import handler from ..utils import platform from langbot_plugin.runtime.io.controllers.stdio import client as stdio_client_controller from langbot_plugin.runtime.io.controllers.ws import client as ws_client_controller -from langbot_plugin.api.entities import events, context +from langbot_plugin.api.entities import events +from langbot_plugin.api.entities import context import langbot_plugin.runtime.io.connection as base_connection +from langbot_plugin.api.definition.components.manifest import ComponentManifest class PluginRuntimeConnector: @@ -91,6 +94,9 @@ class PluginRuntimeConnector: async def initialize_plugins(self): pass + async def list_plugins(self) -> list[dict[str, Any]]: + return await self.handler.list_plugins() + async def emit_event( self, event: events.BaseEventModel, @@ -104,3 +110,11 @@ class PluginRuntimeConnector: event_ctx = context.EventContext.parse_from_dict(event_ctx_result['event_context']) return event_ctx + + async def list_tools(self) -> list[ComponentManifest]: + list_tools_data = await self.handler.list_tools() + + return [ComponentManifest.model_validate(tool) for tool in list_tools_data] + + async def call_tool(self, tool_name: str, parameters: dict[str, Any]) -> dict[str, Any]: + return await self.handler.call_tool(tool_name, parameters) diff --git a/pkg/plugin/handler.py b/pkg/plugin/handler.py index afa0bbbc..16c95770 100644 --- a/pkg/plugin/handler.py +++ b/pkg/plugin/handler.py @@ -94,3 +94,26 @@ class RuntimeConnectionHandler(handler.Handler): ) return result + + async def list_tools(self) -> list[dict[str, Any]]: + """List tools""" + result = await self.call_action( + LangBotToRuntimeAction.LIST_TOOLS, + {}, + timeout=10, + ) + + return result['tools'] + + async def call_tool(self, tool_name: str, parameters: dict[str, Any]) -> dict[str, Any]: + """Call tool""" + result = await self.call_action( + LangBotToRuntimeAction.CALL_TOOL, + { + 'tool_name': tool_name, + 'tool_parameters': parameters, + }, + timeout=30, + ) + + return result['tool_response'] diff --git a/pkg/provider/runners/localagent.py b/pkg/provider/runners/localagent.py index 5a879bcb..950f7756 100644 --- a/pkg/provider/runners/localagent.py +++ b/pkg/provider/runners/localagent.py @@ -43,7 +43,7 @@ class LocalAgentRunner(runner.RequestRunner): parameters = json.loads(func.arguments) - func_ret = await self.ap.tool_mgr.execute_func_call(query, func.name, parameters) + func_ret = await self.ap.tool_mgr.execute_func_call(func.name, parameters) msg = llm_entities.Message( role='tool', diff --git a/pkg/provider/tools/loader.py b/pkg/provider/tools/loader.py index 658fdeb6..f3d65fd2 100644 --- a/pkg/provider/tools/loader.py +++ b/pkg/provider/tools/loader.py @@ -5,7 +5,6 @@ import typing from ...core import app import langbot_plugin.api.entities.builtin.resource.tool as resource_tool -import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query preregistered_loaders: list[typing.Type[ToolLoader]] = [] @@ -36,7 +35,7 @@ class ToolLoader(abc.ABC): pass @abc.abstractmethod - async def get_tools(self, enabled: bool = True) -> list[resource_tool.LLMTool]: + async def get_tools(self) -> list[resource_tool.LLMTool]: """获取所有工具""" pass @@ -46,7 +45,7 @@ class ToolLoader(abc.ABC): pass @abc.abstractmethod - async def invoke_tool(self, query: pipeline_query.Query, name: str, parameters: dict) -> typing.Any: + async def invoke_tool(self, name: str, parameters: dict) -> typing.Any: """执行工具调用""" pass diff --git a/pkg/provider/tools/loaders/mcp.py b/pkg/provider/tools/loaders/mcp.py index 577c704e..36fa9751 100644 --- a/pkg/provider/tools/loaders/mcp.py +++ b/pkg/provider/tools/loaders/mcp.py @@ -10,7 +10,6 @@ from mcp.client.sse import sse_client from .. import loader from ....core import app import langbot_plugin.api.entities.builtin.resource.tool as resource_tool -import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query class RuntimeMCPSession: @@ -84,7 +83,7 @@ class RuntimeMCPSession: for tool in tools.tools: - async def func(query: pipeline_query.Query, *, _tool=tool, **kwargs): + async def func(*, _tool=tool, **kwargs): result = await self.session.call_tool(_tool.name, kwargs) if result.isError: raise Exception(result.content[0].text) @@ -132,7 +131,7 @@ class MCPLoader(loader.ToolLoader): # self.ap.event_loop.create_task(session.initialize()) self.sessions[server_config['name']] = session - async def get_tools(self, enabled: bool = True) -> list[resource_tool.LLMTool]: + async def get_tools(self) -> list[resource_tool.LLMTool]: all_functions = [] for session in self.sessions.values(): @@ -145,11 +144,11 @@ class MCPLoader(loader.ToolLoader): async def has_tool(self, name: str) -> bool: return name in [f.name for f in self._last_listed_functions] - async def invoke_tool(self, query: pipeline_query.Query, name: str, parameters: dict) -> typing.Any: + async def invoke_tool(self, name: str, parameters: dict) -> typing.Any: for server_name, session in self.sessions.items(): for function in session.functions: if function.name == name: - return await function.func(query, **parameters) + return await function.func(**parameters) raise ValueError(f'未找到工具: {name}') diff --git a/pkg/provider/tools/loaders/plugin.py b/pkg/provider/tools/loaders/plugin.py index 7dfaea97..94296470 100644 --- a/pkg/provider/tools/loaders/plugin.py +++ b/pkg/provider/tools/loaders/plugin.py @@ -4,9 +4,7 @@ import typing import traceback from .. import loader -from ....plugin import context as plugin_context import langbot_plugin.api.entities.builtin.resource.tool as resource_tool -import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @loader.loader_class('plugin-tool-loader') @@ -16,63 +14,42 @@ class PluginToolLoader(loader.ToolLoader): 本加载器中不存储工具信息,仅负责从插件系统中获取工具信息。 """ - async def get_tools(self, enabled: bool = True) -> list[resource_tool.LLMTool]: + async def get_tools(self) -> list[resource_tool.LLMTool]: # 从插件系统获取工具(内容函数) all_functions: list[resource_tool.LLMTool] = [] - for plugin in self.ap.plugin_mgr.plugins( - enabled=enabled, status=plugin_context.RuntimeContainerStatus.INITIALIZED - ): - all_functions.extend(plugin.tools) + for tool in await self.ap.plugin_connector.list_tools(): + tool_obj = resource_tool.LLMTool( + name=tool.metadata.name, + human_desc=tool.metadata.description.en_US, + description=tool.spec['llm_prompt'], + parameters=tool.spec['parameters'], + func=lambda parameters: {}, + ) + all_functions.append(tool_obj) return all_functions async def has_tool(self, name: str) -> bool: """检查工具是否存在""" - for plugin in self.ap.plugin_mgr.plugins( - enabled=True, status=plugin_context.RuntimeContainerStatus.INITIALIZED - ): - for function in plugin.tools: - if function.name == name: - return True + for tool in await self.ap.plugin_connector.list_tools(): + if tool.metadata.name == name: + return True return False - async def _get_function_and_plugin( - self, name: str - ) -> typing.Tuple[resource_tool.LLMTool, plugin_context.BasePlugin]: - """获取函数和插件实例""" - for plugin in self.ap.plugin_mgr.plugins( - enabled=True, status=plugin_context.RuntimeContainerStatus.INITIALIZED - ): - for function in plugin.tools: - if function.name == name: - return function, plugin.plugin_inst - return None, None + async def _get_tool(self, name: str) -> resource_tool.LLMTool: + for tool in await self.ap.plugin_connector.list_tools(): + if tool.metadata.name == name: + return tool + return None - async def invoke_tool(self, query: pipeline_query.Query, name: str, parameters: dict) -> typing.Any: + async def invoke_tool(self, name: str, parameters: dict) -> typing.Any: try: - function, plugin = await self._get_function_and_plugin(name) - if function is None: - return None - - parameters = parameters.copy() - - parameters = {'query': query, **parameters} - - return await function.func(plugin, **parameters) + return await self.ap.plugin_connector.call_tool(name, parameters) except Exception as e: self.ap.logger.error(f'执行函数 {name} 时发生错误: {e}') traceback.print_exc() return f'error occurred when executing function {name}: {e}' - finally: - plugin = None - - for p in self.ap.plugin_mgr.plugins(): - if function in p.tools: - plugin = p - break - - # TODO statistics async def shutdown(self): """关闭工具""" diff --git a/pkg/provider/tools/toolmgr.py b/pkg/provider/tools/toolmgr.py index e1105750..43960aba 100644 --- a/pkg/provider/tools/toolmgr.py +++ b/pkg/provider/tools/toolmgr.py @@ -7,7 +7,6 @@ from . import loader as tools_loader from ...utils import importutil from . import loaders import langbot_plugin.api.entities.builtin.resource.tool as resource_tool -import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query importutil.import_modules_in_pkg(loaders) @@ -30,12 +29,12 @@ class ToolManager: await loader_inst.initialize() self.loaders.append(loader_inst) - async def get_all_functions(self, plugin_enabled: bool = None) -> list[resource_tool.LLMTool]: + async def get_all_tools(self) -> list[resource_tool.LLMTool]: """获取所有函数""" all_functions: list[resource_tool.LLMTool] = [] for loader in self.loaders: - all_functions.extend(await loader.get_tools(plugin_enabled)) + all_functions.extend(await loader.get_tools()) return all_functions @@ -91,12 +90,12 @@ class ToolManager: return tools - async def execute_func_call(self, query: pipeline_query.Query, name: str, parameters: dict) -> typing.Any: + async def execute_func_call(self, name: str, parameters: dict) -> typing.Any: """执行函数调用""" for loader in self.loaders: if await loader.has_tool(name): - return await loader.invoke_tool(query, name, parameters) + return await loader.invoke_tool(name, parameters) else: raise ValueError(f'未找到工具: {name}')