mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
feat: add Tool component
This commit is contained in:
@@ -12,7 +12,7 @@ class PluginsRouterGroup(group.RouterGroup):
|
||||
async def initialize(self) -> None:
|
||||
@self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
plugins = await self.ap.plugin_connector.handler.list_plugins()
|
||||
plugins = await self.ap.plugin_connector.list_plugins()
|
||||
|
||||
return self.success(data={'plugins': plugins})
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ class SystemRouterGroup(group.RouterGroup):
|
||||
|
||||
return self.success(data=task.to_dict())
|
||||
|
||||
@self.route('/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
@self.route('/debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
if not constants.debug_mode:
|
||||
return self.http_status(403, 403, 'Forbidden')
|
||||
@@ -45,3 +45,14 @@ class SystemRouterGroup(group.RouterGroup):
|
||||
ap = self.ap
|
||||
|
||||
return self.success(data=exec(py_code, {'ap': ap}))
|
||||
|
||||
@self.route('/debug/tools/call', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
if not constants.debug_mode:
|
||||
return self.http_status(403, 403, 'Forbidden')
|
||||
|
||||
data = await quart.request.json
|
||||
|
||||
return self.success(
|
||||
data=await self.ap.tool_mgr.execute_func_call(data['tool_name'], data['tool_parameters'])
|
||||
)
|
||||
|
||||
@@ -11,9 +11,7 @@ class FuncOperator(operator.CommandOperator):
|
||||
|
||||
index = 1
|
||||
|
||||
all_functions = await self.ap.tool_mgr.get_all_functions(
|
||||
plugin_enabled=True,
|
||||
)
|
||||
all_functions = await self.ap.tool_mgr.get_all_tools()
|
||||
|
||||
for func in all_functions:
|
||||
reply_str += '{}. {}:\n{}\n\n'.format(
|
||||
|
||||
@@ -60,9 +60,7 @@ class PreProcessor(stage.PipelineStage):
|
||||
query.use_funcs = []
|
||||
|
||||
if llm_model.model_entity.abilities.__contains__('func_call'):
|
||||
query.use_funcs = await self.ap.tool_mgr.get_all_functions(
|
||||
plugin_enabled=True,
|
||||
)
|
||||
query.use_funcs = await self.ap.tool_mgr.get_all_tools()
|
||||
|
||||
query.variables = {
|
||||
'session_id': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
import typing
|
||||
import os
|
||||
import sys
|
||||
@@ -11,8 +12,10 @@ from . import handler
|
||||
from ..utils import platform
|
||||
from langbot_plugin.runtime.io.controllers.stdio import client as stdio_client_controller
|
||||
from langbot_plugin.runtime.io.controllers.ws import client as ws_client_controller
|
||||
from langbot_plugin.api.entities import events, context
|
||||
from langbot_plugin.api.entities import events
|
||||
from langbot_plugin.api.entities import context
|
||||
import langbot_plugin.runtime.io.connection as base_connection
|
||||
from langbot_plugin.api.definition.components.manifest import ComponentManifest
|
||||
|
||||
|
||||
class PluginRuntimeConnector:
|
||||
@@ -91,6 +94,9 @@ class PluginRuntimeConnector:
|
||||
async def initialize_plugins(self):
|
||||
pass
|
||||
|
||||
async def list_plugins(self) -> list[dict[str, Any]]:
|
||||
return await self.handler.list_plugins()
|
||||
|
||||
async def emit_event(
|
||||
self,
|
||||
event: events.BaseEventModel,
|
||||
@@ -104,3 +110,11 @@ class PluginRuntimeConnector:
|
||||
event_ctx = context.EventContext.parse_from_dict(event_ctx_result['event_context'])
|
||||
|
||||
return event_ctx
|
||||
|
||||
async def list_tools(self) -> list[ComponentManifest]:
|
||||
list_tools_data = await self.handler.list_tools()
|
||||
|
||||
return [ComponentManifest.model_validate(tool) for tool in list_tools_data]
|
||||
|
||||
async def call_tool(self, tool_name: str, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
return await self.handler.call_tool(tool_name, parameters)
|
||||
|
||||
@@ -94,3 +94,26 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def list_tools(self) -> list[dict[str, Any]]:
|
||||
"""List tools"""
|
||||
result = await self.call_action(
|
||||
LangBotToRuntimeAction.LIST_TOOLS,
|
||||
{},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
return result['tools']
|
||||
|
||||
async def call_tool(self, tool_name: str, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Call tool"""
|
||||
result = await self.call_action(
|
||||
LangBotToRuntimeAction.CALL_TOOL,
|
||||
{
|
||||
'tool_name': tool_name,
|
||||
'tool_parameters': parameters,
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
return result['tool_response']
|
||||
|
||||
@@ -43,7 +43,7 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
|
||||
parameters = json.loads(func.arguments)
|
||||
|
||||
func_ret = await self.ap.tool_mgr.execute_func_call(query, func.name, parameters)
|
||||
func_ret = await self.ap.tool_mgr.execute_func_call(func.name, parameters)
|
||||
|
||||
msg = llm_entities.Message(
|
||||
role='tool',
|
||||
|
||||
@@ -5,7 +5,6 @@ import typing
|
||||
|
||||
from ...core import app
|
||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
preregistered_loaders: list[typing.Type[ToolLoader]] = []
|
||||
@@ -36,7 +35,7 @@ class ToolLoader(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_tools(self, enabled: bool = True) -> list[resource_tool.LLMTool]:
|
||||
async def get_tools(self) -> list[resource_tool.LLMTool]:
|
||||
"""获取所有工具"""
|
||||
pass
|
||||
|
||||
@@ -46,7 +45,7 @@ class ToolLoader(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def invoke_tool(self, query: pipeline_query.Query, name: str, parameters: dict) -> typing.Any:
|
||||
async def invoke_tool(self, name: str, parameters: dict) -> typing.Any:
|
||||
"""执行工具调用"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ from mcp.client.sse import sse_client
|
||||
from .. import loader
|
||||
from ....core import app
|
||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
class RuntimeMCPSession:
|
||||
@@ -84,7 +83,7 @@ class RuntimeMCPSession:
|
||||
|
||||
for tool in tools.tools:
|
||||
|
||||
async def func(query: pipeline_query.Query, *, _tool=tool, **kwargs):
|
||||
async def func(*, _tool=tool, **kwargs):
|
||||
result = await self.session.call_tool(_tool.name, kwargs)
|
||||
if result.isError:
|
||||
raise Exception(result.content[0].text)
|
||||
@@ -132,7 +131,7 @@ class MCPLoader(loader.ToolLoader):
|
||||
# self.ap.event_loop.create_task(session.initialize())
|
||||
self.sessions[server_config['name']] = session
|
||||
|
||||
async def get_tools(self, enabled: bool = True) -> list[resource_tool.LLMTool]:
|
||||
async def get_tools(self) -> list[resource_tool.LLMTool]:
|
||||
all_functions = []
|
||||
|
||||
for session in self.sessions.values():
|
||||
@@ -145,11 +144,11 @@ class MCPLoader(loader.ToolLoader):
|
||||
async def has_tool(self, name: str) -> bool:
|
||||
return name in [f.name for f in self._last_listed_functions]
|
||||
|
||||
async def invoke_tool(self, query: pipeline_query.Query, name: str, parameters: dict) -> typing.Any:
|
||||
async def invoke_tool(self, name: str, parameters: dict) -> typing.Any:
|
||||
for server_name, session in self.sessions.items():
|
||||
for function in session.functions:
|
||||
if function.name == name:
|
||||
return await function.func(query, **parameters)
|
||||
return await function.func(**parameters)
|
||||
|
||||
raise ValueError(f'未找到工具: {name}')
|
||||
|
||||
|
||||
@@ -4,9 +4,7 @@ import typing
|
||||
import traceback
|
||||
|
||||
from .. import loader
|
||||
from ....plugin import context as plugin_context
|
||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
@loader.loader_class('plugin-tool-loader')
|
||||
@@ -16,63 +14,42 @@ class PluginToolLoader(loader.ToolLoader):
|
||||
本加载器中不存储工具信息,仅负责从插件系统中获取工具信息。
|
||||
"""
|
||||
|
||||
async def get_tools(self, enabled: bool = True) -> list[resource_tool.LLMTool]:
|
||||
async def get_tools(self) -> list[resource_tool.LLMTool]:
|
||||
# 从插件系统获取工具(内容函数)
|
||||
all_functions: list[resource_tool.LLMTool] = []
|
||||
|
||||
for plugin in self.ap.plugin_mgr.plugins(
|
||||
enabled=enabled, status=plugin_context.RuntimeContainerStatus.INITIALIZED
|
||||
):
|
||||
all_functions.extend(plugin.tools)
|
||||
for tool in await self.ap.plugin_connector.list_tools():
|
||||
tool_obj = resource_tool.LLMTool(
|
||||
name=tool.metadata.name,
|
||||
human_desc=tool.metadata.description.en_US,
|
||||
description=tool.spec['llm_prompt'],
|
||||
parameters=tool.spec['parameters'],
|
||||
func=lambda parameters: {},
|
||||
)
|
||||
all_functions.append(tool_obj)
|
||||
|
||||
return all_functions
|
||||
|
||||
async def has_tool(self, name: str) -> bool:
|
||||
"""检查工具是否存在"""
|
||||
for plugin in self.ap.plugin_mgr.plugins(
|
||||
enabled=True, status=plugin_context.RuntimeContainerStatus.INITIALIZED
|
||||
):
|
||||
for function in plugin.tools:
|
||||
if function.name == name:
|
||||
return True
|
||||
for tool in await self.ap.plugin_connector.list_tools():
|
||||
if tool.metadata.name == name:
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _get_function_and_plugin(
|
||||
self, name: str
|
||||
) -> typing.Tuple[resource_tool.LLMTool, plugin_context.BasePlugin]:
|
||||
"""获取函数和插件实例"""
|
||||
for plugin in self.ap.plugin_mgr.plugins(
|
||||
enabled=True, status=plugin_context.RuntimeContainerStatus.INITIALIZED
|
||||
):
|
||||
for function in plugin.tools:
|
||||
if function.name == name:
|
||||
return function, plugin.plugin_inst
|
||||
return None, None
|
||||
async def _get_tool(self, name: str) -> resource_tool.LLMTool:
|
||||
for tool in await self.ap.plugin_connector.list_tools():
|
||||
if tool.metadata.name == name:
|
||||
return tool
|
||||
return None
|
||||
|
||||
async def invoke_tool(self, query: pipeline_query.Query, name: str, parameters: dict) -> typing.Any:
|
||||
async def invoke_tool(self, name: str, parameters: dict) -> typing.Any:
|
||||
try:
|
||||
function, plugin = await self._get_function_and_plugin(name)
|
||||
if function is None:
|
||||
return None
|
||||
|
||||
parameters = parameters.copy()
|
||||
|
||||
parameters = {'query': query, **parameters}
|
||||
|
||||
return await function.func(plugin, **parameters)
|
||||
return await self.ap.plugin_connector.call_tool(name, parameters)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'执行函数 {name} 时发生错误: {e}')
|
||||
traceback.print_exc()
|
||||
return f'error occurred when executing function {name}: {e}'
|
||||
finally:
|
||||
plugin = None
|
||||
|
||||
for p in self.ap.plugin_mgr.plugins():
|
||||
if function in p.tools:
|
||||
plugin = p
|
||||
break
|
||||
|
||||
# TODO statistics
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭工具"""
|
||||
|
||||
@@ -7,7 +7,6 @@ from . import loader as tools_loader
|
||||
from ...utils import importutil
|
||||
from . import loaders
|
||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
importutil.import_modules_in_pkg(loaders)
|
||||
|
||||
@@ -30,12 +29,12 @@ class ToolManager:
|
||||
await loader_inst.initialize()
|
||||
self.loaders.append(loader_inst)
|
||||
|
||||
async def get_all_functions(self, plugin_enabled: bool = None) -> list[resource_tool.LLMTool]:
|
||||
async def get_all_tools(self) -> list[resource_tool.LLMTool]:
|
||||
"""获取所有函数"""
|
||||
all_functions: list[resource_tool.LLMTool] = []
|
||||
|
||||
for loader in self.loaders:
|
||||
all_functions.extend(await loader.get_tools(plugin_enabled))
|
||||
all_functions.extend(await loader.get_tools())
|
||||
|
||||
return all_functions
|
||||
|
||||
@@ -91,12 +90,12 @@ class ToolManager:
|
||||
|
||||
return tools
|
||||
|
||||
async def execute_func_call(self, query: pipeline_query.Query, name: str, parameters: dict) -> typing.Any:
|
||||
async def execute_func_call(self, name: str, parameters: dict) -> typing.Any:
|
||||
"""执行函数调用"""
|
||||
|
||||
for loader in self.loaders:
|
||||
if await loader.has_tool(name):
|
||||
return await loader.invoke_tool(query, name, parameters)
|
||||
return await loader.invoke_tool(name, parameters)
|
||||
else:
|
||||
raise ValueError(f'未找到工具: {name}')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user