diff --git a/src/langbot/pkg/provider/tools/loader.py b/src/langbot/pkg/provider/tools/loader.py index e90f07b3..5945f80d 100644 --- a/src/langbot/pkg/provider/tools/loader.py +++ b/src/langbot/pkg/provider/tools/loader.py @@ -4,12 +4,15 @@ import abc import typing from typing import TYPE_CHECKING +from langbot_plugin.api.definition.components.manifest import ComponentManifest from langbot_plugin.api.entities.events import pipeline_query import langbot_plugin.api.entities.builtin.resource.tool as resource_tool if TYPE_CHECKING: from ...core import app +ToolLookupResult = resource_tool.LLMTool | ComponentManifest + preregistered_loaders: list[typing.Type[ToolLoader]] = [] @@ -43,6 +46,17 @@ class ToolLoader(abc.ABC): """获取所有工具""" pass + async def get_tool(self, name: str) -> ToolLookupResult | None: + """Get one tool by name. + + Loaders with a cheaper direct lookup can override this method. The + default keeps simple loaders working by searching their public list. + """ + for tool in await self.get_tools(): + if tool.name == name: + return tool + return None + @abc.abstractmethod async def has_tool(self, name: str) -> bool: """检查工具是否存在""" diff --git a/src/langbot/pkg/provider/tools/loaders/mcp.py b/src/langbot/pkg/provider/tools/loaders/mcp.py index 4e12ca42..835b7025 100644 --- a/src/langbot/pkg/provider/tools/loaders/mcp.py +++ b/src/langbot/pkg/provider/tools/loaders/mcp.py @@ -525,7 +525,7 @@ class MCPLoader(loader.ToolLoader): return True return False - async def _get_tool(self, name: str) -> resource_tool.LLMTool | None: + async def get_tool(self, name: str) -> resource_tool.LLMTool | None: """Get tool by name. Args: diff --git a/src/langbot/pkg/provider/tools/loaders/plugin.py b/src/langbot/pkg/provider/tools/loaders/plugin.py index 7e6aab82..860c21e6 100644 --- a/src/langbot/pkg/provider/tools/loaders/plugin.py +++ b/src/langbot/pkg/provider/tools/loaders/plugin.py @@ -3,6 +3,7 @@ from __future__ import annotations import typing import traceback +from langbot_plugin.api.definition.components.manifest import ComponentManifest from langbot_plugin.api.entities.events import pipeline_query from .. import loader @@ -39,7 +40,7 @@ class PluginToolLoader(loader.ToolLoader): return True return False - async def _get_tool(self, name: str) -> resource_tool.LLMTool: + async def get_tool(self, name: str) -> ComponentManifest | None: for tool in await self.ap.plugin_connector.list_tools(): if tool.metadata.name == name: return tool diff --git a/src/langbot/pkg/provider/tools/toolmgr.py b/src/langbot/pkg/provider/tools/toolmgr.py index 35d0c84c..fba71ed9 100644 --- a/src/langbot/pkg/provider/tools/toolmgr.py +++ b/src/langbot/pkg/provider/tools/toolmgr.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING import langbot_plugin.api.entities.builtin.resource.tool as resource_tool from langbot_plugin.api.entities.events import pipeline_query +from . import loader as tool_loader from .errors import ToolNotFoundError if TYPE_CHECKING: @@ -69,7 +70,7 @@ class ToolManager: return all_functions - async def get_tool_by_name(self, name: str) -> resource_tool.LLMTool | None: + async def get_tool_by_name(self, name: str) -> tool_loader.ToolLookupResult | None: """Get tool by name from any active loader. Args: @@ -78,28 +79,18 @@ class ToolManager: Returns: LLMTool if found, None otherwise """ - for tool_loader in ( + for active_loader in ( self.native_tool_loader, self.plugin_tool_loader, self.mcp_tool_loader, self.skill_tool_loader, ): - tool = await self._get_tool_from_loader(tool_loader, name) + tool = await active_loader.get_tool(name) if tool: return tool return None - async def _get_tool_from_loader(self, tool_loader: typing.Any, name: str) -> resource_tool.LLMTool | None: - if hasattr(tool_loader, '_get_tool'): - return await tool_loader._get_tool(name) - - for tool in await tool_loader.get_tools(): - if tool.name == name: - return tool - - return None - async def generate_tools_for_openai(self, use_funcs: list[resource_tool.LLMTool]) -> list: tools = [] diff --git a/tests/unit_tests/provider/test_tool_manager.py b/tests/unit_tests/provider/test_tool_manager.py index 8e8439f5..1bb406b4 100644 --- a/tests/unit_tests/provider/test_tool_manager.py +++ b/tests/unit_tests/provider/test_tool_manager.py @@ -265,7 +265,7 @@ class TestToolManagerExecuteFuncCall: @pytest.mark.asyncio async def test_execute_raises_when_tool_not_found(self, mock_app_with_loaders, sample_query): - """Test that execute_func_call raises ValueError when tool not found.""" + """Test that execute_func_call raises ToolNotFoundError when tool not found.""" toolmgr = get_toolmgr_module() mock_app, mock_plugin_loader, mock_mcp_loader = mock_app_with_loaders @@ -275,7 +275,7 @@ class TestToolManagerExecuteFuncCall: manager = toolmgr.ToolManager(mock_app) self._wire_loaders(manager, mock_app, mock_plugin_loader, mock_mcp_loader) - with pytest.raises(ValueError, match='未找到工具'): + with pytest.raises(toolmgr.ToolNotFoundError, match='Tool not found: unknown_tool'): await manager.execute_func_call('unknown_tool', {}, sample_query) @pytest.mark.asyncio diff --git a/tests/unit_tests/provider/test_tool_manager_native.py b/tests/unit_tests/provider/test_tool_manager_native.py index 2d654fac..f0ad78ae 100644 --- a/tests/unit_tests/provider/test_tool_manager_native.py +++ b/tests/unit_tests/provider/test_tool_manager_native.py @@ -9,6 +9,7 @@ import pytest import langbot_plugin.api.entities.builtin.resource.tool as resource_tool +from langbot.pkg.provider.tools.loader import ToolLoader from langbot.pkg.provider.tools.loaders.native import ( _DEFAULT_TOOL_RESULT_MAX_BYTES, _GLOB_MAX_MATCHES, @@ -26,6 +27,12 @@ class StubLoader: async def get_tools(self, *_args, **_kwargs): return self._tools + async def get_tool(self, name: str): + for tool in self._tools: + if tool.name == name: + return tool + return None + async def has_tool(self, name: str) -> bool: return any(tool.name == name for tool in self._tools) @@ -36,6 +43,14 @@ class StubLoader: return None +class DirectLookupLoader(StubLoader): + async def get_tools(self, *_args, **_kwargs): + raise AssertionError('ToolManager should use the loader get_tool contract') + + async def get_tool(self, name: str): + return make_tool(name) if name == 'direct_tool' else None + + def make_tool(name: str) -> resource_tool.LLMTool: return resource_tool.LLMTool( name=name, @@ -46,6 +61,32 @@ def make_tool(name: str) -> resource_tool.LLMTool: ) +class ListOnlyLoader(ToolLoader): + async def get_tools(self, *_args, **_kwargs): + return [make_tool('listed_tool')] + + async def has_tool(self, name: str) -> bool: + return name == 'listed_tool' + + async def invoke_tool(self, name: str, parameters: dict, query): + return parameters + + async def shutdown(self): + return None + + +@pytest.mark.asyncio +async def test_tool_loader_get_tool_falls_back_to_public_tool_list(): + loader = ListOnlyLoader(SimpleNamespace()) + + tool = await loader.get_tool('listed_tool') + missing_tool = await loader.get_tool('missing_tool') + + assert tool is not None + assert tool.name == 'listed_tool' + assert missing_tool is None + + @pytest.mark.asyncio async def test_tool_manager_omits_skill_authoring_tools_by_default(): manager = ToolManager(SimpleNamespace()) @@ -103,6 +144,20 @@ async def test_tool_manager_get_tool_by_name_resolves_native_and_skill_tools(): assert skill_tool.name == 'activate' +@pytest.mark.asyncio +async def test_tool_manager_uses_loader_get_tool_contract(): + manager = ToolManager(SimpleNamespace()) + manager.native_tool_loader = StubLoader([]) + manager.skill_tool_loader = StubLoader([]) + manager.plugin_tool_loader = DirectLookupLoader() + manager.mcp_tool_loader = StubLoader([]) + + tool = await manager.get_tool_by_name('direct_tool') + + assert tool is not None + assert tool.name == 'direct_tool' + + @pytest.mark.asyncio async def test_native_tool_loader_hides_tools_when_box_unavailable(): loader = NativeToolLoader(SimpleNamespace(box_service=SimpleNamespace(available=False)))