feat: add Tool component

This commit is contained in:
Junyan Qin
2025-07-06 21:03:33 +08:00
parent a60aa6f644
commit 5b044a1917
11 changed files with 84 additions and 66 deletions

View File

@@ -12,7 +12,7 @@ class PluginsRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
plugins = await self.ap.plugin_connector.handler.list_plugins()
plugins = await self.ap.plugin_connector.list_plugins()
return self.success(data={'plugins': plugins})

View File

@@ -35,7 +35,7 @@ class SystemRouterGroup(group.RouterGroup):
return self.success(data=task.to_dict())
@self.route('/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
@self.route('/debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
if not constants.debug_mode:
return self.http_status(403, 403, 'Forbidden')
@@ -45,3 +45,14 @@ class SystemRouterGroup(group.RouterGroup):
ap = self.ap
return self.success(data=exec(py_code, {'ap': ap}))
@self.route('/debug/tools/call', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
if not constants.debug_mode:
return self.http_status(403, 403, 'Forbidden')
data = await quart.request.json
return self.success(
data=await self.ap.tool_mgr.execute_func_call(data['tool_name'], data['tool_parameters'])
)

View File

@@ -11,9 +11,7 @@ class FuncOperator(operator.CommandOperator):
index = 1
all_functions = await self.ap.tool_mgr.get_all_functions(
plugin_enabled=True,
)
all_functions = await self.ap.tool_mgr.get_all_tools()
for func in all_functions:
reply_str += '{}. {}:\n{}\n\n'.format(

View File

@@ -60,9 +60,7 @@ class PreProcessor(stage.PipelineStage):
query.use_funcs = []
if llm_model.model_entity.abilities.__contains__('func_call'):
query.use_funcs = await self.ap.tool_mgr.get_all_functions(
plugin_enabled=True,
)
query.use_funcs = await self.ap.tool_mgr.get_all_tools()
query.variables = {
'session_id': f'{query.session.launcher_type.value}_{query.session.launcher_id}',

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
from typing import Any
import typing
import os
import sys
@@ -11,8 +12,10 @@ from . import handler
from ..utils import platform
from langbot_plugin.runtime.io.controllers.stdio import client as stdio_client_controller
from langbot_plugin.runtime.io.controllers.ws import client as ws_client_controller
from langbot_plugin.api.entities import events, context
from langbot_plugin.api.entities import events
from langbot_plugin.api.entities import context
import langbot_plugin.runtime.io.connection as base_connection
from langbot_plugin.api.definition.components.manifest import ComponentManifest
class PluginRuntimeConnector:
@@ -91,6 +94,9 @@ class PluginRuntimeConnector:
async def initialize_plugins(self):
pass
async def list_plugins(self) -> list[dict[str, Any]]:
return await self.handler.list_plugins()
async def emit_event(
self,
event: events.BaseEventModel,
@@ -104,3 +110,11 @@ class PluginRuntimeConnector:
event_ctx = context.EventContext.parse_from_dict(event_ctx_result['event_context'])
return event_ctx
async def list_tools(self) -> list[ComponentManifest]:
list_tools_data = await self.handler.list_tools()
return [ComponentManifest.model_validate(tool) for tool in list_tools_data]
async def call_tool(self, tool_name: str, parameters: dict[str, Any]) -> dict[str, Any]:
return await self.handler.call_tool(tool_name, parameters)

View File

@@ -94,3 +94,26 @@ class RuntimeConnectionHandler(handler.Handler):
)
return result
async def list_tools(self) -> list[dict[str, Any]]:
"""List tools"""
result = await self.call_action(
LangBotToRuntimeAction.LIST_TOOLS,
{},
timeout=10,
)
return result['tools']
async def call_tool(self, tool_name: str, parameters: dict[str, Any]) -> dict[str, Any]:
"""Call tool"""
result = await self.call_action(
LangBotToRuntimeAction.CALL_TOOL,
{
'tool_name': tool_name,
'tool_parameters': parameters,
},
timeout=30,
)
return result['tool_response']

View File

@@ -43,7 +43,7 @@ class LocalAgentRunner(runner.RequestRunner):
parameters = json.loads(func.arguments)
func_ret = await self.ap.tool_mgr.execute_func_call(query, func.name, parameters)
func_ret = await self.ap.tool_mgr.execute_func_call(func.name, parameters)
msg = llm_entities.Message(
role='tool',

View File

@@ -5,7 +5,6 @@ import typing
from ...core import app
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregistered_loaders: list[typing.Type[ToolLoader]] = []
@@ -36,7 +35,7 @@ class ToolLoader(abc.ABC):
pass
@abc.abstractmethod
async def get_tools(self, enabled: bool = True) -> list[resource_tool.LLMTool]:
async def get_tools(self) -> list[resource_tool.LLMTool]:
"""获取所有工具"""
pass
@@ -46,7 +45,7 @@ class ToolLoader(abc.ABC):
pass
@abc.abstractmethod
async def invoke_tool(self, query: pipeline_query.Query, name: str, parameters: dict) -> typing.Any:
async def invoke_tool(self, name: str, parameters: dict) -> typing.Any:
"""执行工具调用"""
pass

View File

@@ -10,7 +10,6 @@ from mcp.client.sse import sse_client
from .. import loader
from ....core import app
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
class RuntimeMCPSession:
@@ -84,7 +83,7 @@ class RuntimeMCPSession:
for tool in tools.tools:
async def func(query: pipeline_query.Query, *, _tool=tool, **kwargs):
async def func(*, _tool=tool, **kwargs):
result = await self.session.call_tool(_tool.name, kwargs)
if result.isError:
raise Exception(result.content[0].text)
@@ -132,7 +131,7 @@ class MCPLoader(loader.ToolLoader):
# self.ap.event_loop.create_task(session.initialize())
self.sessions[server_config['name']] = session
async def get_tools(self, enabled: bool = True) -> list[resource_tool.LLMTool]:
async def get_tools(self) -> list[resource_tool.LLMTool]:
all_functions = []
for session in self.sessions.values():
@@ -145,11 +144,11 @@ class MCPLoader(loader.ToolLoader):
async def has_tool(self, name: str) -> bool:
return name in [f.name for f in self._last_listed_functions]
async def invoke_tool(self, query: pipeline_query.Query, name: str, parameters: dict) -> typing.Any:
async def invoke_tool(self, name: str, parameters: dict) -> typing.Any:
for server_name, session in self.sessions.items():
for function in session.functions:
if function.name == name:
return await function.func(query, **parameters)
return await function.func(**parameters)
raise ValueError(f'未找到工具: {name}')

View File

@@ -4,9 +4,7 @@ import typing
import traceback
from .. import loader
from ....plugin import context as plugin_context
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@loader.loader_class('plugin-tool-loader')
@@ -16,63 +14,42 @@ class PluginToolLoader(loader.ToolLoader):
本加载器中不存储工具信息,仅负责从插件系统中获取工具信息。
"""
async def get_tools(self, enabled: bool = True) -> list[resource_tool.LLMTool]:
async def get_tools(self) -> list[resource_tool.LLMTool]:
# 从插件系统获取工具(内容函数)
all_functions: list[resource_tool.LLMTool] = []
for plugin in self.ap.plugin_mgr.plugins(
enabled=enabled, status=plugin_context.RuntimeContainerStatus.INITIALIZED
):
all_functions.extend(plugin.tools)
for tool in await self.ap.plugin_connector.list_tools():
tool_obj = resource_tool.LLMTool(
name=tool.metadata.name,
human_desc=tool.metadata.description.en_US,
description=tool.spec['llm_prompt'],
parameters=tool.spec['parameters'],
func=lambda parameters: {},
)
all_functions.append(tool_obj)
return all_functions
async def has_tool(self, name: str) -> bool:
"""检查工具是否存在"""
for plugin in self.ap.plugin_mgr.plugins(
enabled=True, status=plugin_context.RuntimeContainerStatus.INITIALIZED
):
for function in plugin.tools:
if function.name == name:
return True
for tool in await self.ap.plugin_connector.list_tools():
if tool.metadata.name == name:
return True
return False
async def _get_function_and_plugin(
self, name: str
) -> typing.Tuple[resource_tool.LLMTool, plugin_context.BasePlugin]:
"""获取函数和插件实例"""
for plugin in self.ap.plugin_mgr.plugins(
enabled=True, status=plugin_context.RuntimeContainerStatus.INITIALIZED
):
for function in plugin.tools:
if function.name == name:
return function, plugin.plugin_inst
return None, None
async def _get_tool(self, name: str) -> resource_tool.LLMTool:
for tool in await self.ap.plugin_connector.list_tools():
if tool.metadata.name == name:
return tool
return None
async def invoke_tool(self, query: pipeline_query.Query, name: str, parameters: dict) -> typing.Any:
async def invoke_tool(self, name: str, parameters: dict) -> typing.Any:
try:
function, plugin = await self._get_function_and_plugin(name)
if function is None:
return None
parameters = parameters.copy()
parameters = {'query': query, **parameters}
return await function.func(plugin, **parameters)
return await self.ap.plugin_connector.call_tool(name, parameters)
except Exception as e:
self.ap.logger.error(f'执行函数 {name} 时发生错误: {e}')
traceback.print_exc()
return f'error occurred when executing function {name}: {e}'
finally:
plugin = None
for p in self.ap.plugin_mgr.plugins():
if function in p.tools:
plugin = p
break
# TODO statistics
async def shutdown(self):
"""关闭工具"""

View File

@@ -7,7 +7,6 @@ from . import loader as tools_loader
from ...utils import importutil
from . import loaders
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
importutil.import_modules_in_pkg(loaders)
@@ -30,12 +29,12 @@ class ToolManager:
await loader_inst.initialize()
self.loaders.append(loader_inst)
async def get_all_functions(self, plugin_enabled: bool = None) -> list[resource_tool.LLMTool]:
async def get_all_tools(self) -> list[resource_tool.LLMTool]:
"""获取所有函数"""
all_functions: list[resource_tool.LLMTool] = []
for loader in self.loaders:
all_functions.extend(await loader.get_tools(plugin_enabled))
all_functions.extend(await loader.get_tools())
return all_functions
@@ -91,12 +90,12 @@ class ToolManager:
return tools
async def execute_func_call(self, query: pipeline_query.Query, name: str, parameters: dict) -> typing.Any:
async def execute_func_call(self, name: str, parameters: dict) -> typing.Any:
"""执行函数调用"""
for loader in self.loaders:
if await loader.has_tool(name):
return await loader.invoke_tool(query, name, parameters)
return await loader.invoke_tool(name, parameters)
else:
raise ValueError(f'未找到工具: {name}')