mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-10 15:56:03 +00:00
style: introduce ruff as linter and formatter (#1356)
* style: remove necessary imports * style: fix F841 * style: fix F401 * style: fix F811 * style: fix E402 * style: fix E721 * style: fix E722 * style: fix E722 * style: fix F541 * style: ruff format * style: all passed * style: add ruff in deps * style: more ignores in ruff.toml * style: add pre-commit
This commit is contained in:
committed by
GitHub
parent
09e70d70e9
commit
209f16af76
@@ -1,13 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import typing
|
||||
import asyncio
|
||||
|
||||
import pydantic.v1 as pydantic
|
||||
|
||||
from ...core import entities as core_entities
|
||||
|
||||
|
||||
class LLMFunction(pydantic.BaseModel):
|
||||
"""函数"""
|
||||
|
||||
@@ -9,9 +9,10 @@ from . import entities as tools_entities
|
||||
|
||||
preregistered_loaders: list[typing.Type[ToolLoader]] = []
|
||||
|
||||
|
||||
def loader_class(name: str):
|
||||
"""注册一个工具加载器
|
||||
"""
|
||||
"""注册一个工具加载器"""
|
||||
|
||||
def decorator(cls: typing.Type[ToolLoader]) -> typing.Type[ToolLoader]:
|
||||
cls.name = name
|
||||
preregistered_loaders.append(cls)
|
||||
@@ -22,7 +23,7 @@ def loader_class(name: str):
|
||||
|
||||
class ToolLoader(abc.ABC):
|
||||
"""工具加载器"""
|
||||
|
||||
|
||||
name: str = None
|
||||
|
||||
ap: app.Application
|
||||
@@ -34,7 +35,7 @@ class ToolLoader(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_tools(self, enabled: bool=True) -> list[tools_entities.LLMFunction]:
|
||||
async def get_tools(self, enabled: bool = True) -> list[tools_entities.LLMFunction]:
|
||||
"""获取所有工具"""
|
||||
pass
|
||||
|
||||
@@ -44,11 +45,13 @@ class ToolLoader(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
|
||||
async def invoke_tool(
|
||||
self, query: core_entities.Query, name: str, parameters: dict
|
||||
) -> typing.Any:
|
||||
"""执行工具调用"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def shutdown(self):
|
||||
"""关闭工具"""
|
||||
pass
|
||||
pass
|
||||
|
||||
@@ -30,7 +30,7 @@ class RuntimeMCPSession:
|
||||
self.server_name = server_name
|
||||
self.server_config = server_config
|
||||
self.ap = ap
|
||||
|
||||
|
||||
self.session = None
|
||||
|
||||
self.exit_stack = AsyncExitStack()
|
||||
@@ -38,9 +38,9 @@ class RuntimeMCPSession:
|
||||
|
||||
async def _init_stdio_python_server(self):
|
||||
server_params = StdioServerParameters(
|
||||
command=self.server_config["command"],
|
||||
args=self.server_config["args"],
|
||||
env=self.server_config["env"],
|
||||
command=self.server_config['command'],
|
||||
args=self.server_config['args'],
|
||||
env=self.server_config['env'],
|
||||
)
|
||||
|
||||
stdio_transport = await self.exit_stack.enter_async_context(
|
||||
@@ -58,12 +58,12 @@ class RuntimeMCPSession:
|
||||
async def _init_sse_server(self):
|
||||
sse_transport = await self.exit_stack.enter_async_context(
|
||||
sse_client(
|
||||
self.server_config["url"],
|
||||
headers=self.server_config.get("headers", {}),
|
||||
timeout=self.server_config.get("timeout", 10),
|
||||
self.server_config['url'],
|
||||
headers=self.server_config.get('headers', {}),
|
||||
timeout=self.server_config.get('timeout', 10),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
sseio, write = sse_transport
|
||||
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
@@ -73,18 +73,22 @@ class RuntimeMCPSession:
|
||||
await self.session.initialize()
|
||||
|
||||
async def initialize(self):
|
||||
self.ap.logger.debug(f"初始化 MCP 会话: {self.server_name} {self.server_config}")
|
||||
self.ap.logger.debug(
|
||||
f'初始化 MCP 会话: {self.server_name} {self.server_config}'
|
||||
)
|
||||
|
||||
if self.server_config["mode"] == "stdio":
|
||||
if self.server_config['mode'] == 'stdio':
|
||||
await self._init_stdio_python_server()
|
||||
elif self.server_config["mode"] == "sse":
|
||||
elif self.server_config['mode'] == 'sse':
|
||||
await self._init_sse_server()
|
||||
else:
|
||||
raise ValueError(f"无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}")
|
||||
|
||||
raise ValueError(
|
||||
f'无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}'
|
||||
)
|
||||
|
||||
tools = await self.session.list_tools()
|
||||
|
||||
self.ap.logger.debug(f"获取 MCP 工具: {tools}")
|
||||
self.ap.logger.debug(f'获取 MCP 工具: {tools}')
|
||||
|
||||
for tool in tools.tools:
|
||||
|
||||
@@ -93,25 +97,28 @@ class RuntimeMCPSession:
|
||||
if result.isError:
|
||||
raise Exception(result.content[0].text)
|
||||
return result.content[0].text
|
||||
|
||||
|
||||
func.__name__ = tool.name
|
||||
|
||||
self.functions.append(tools_entities.LLMFunction(
|
||||
name=tool.name,
|
||||
human_desc=tool.description,
|
||||
description=tool.description,
|
||||
parameters=tool.inputSchema,
|
||||
func=func,
|
||||
))
|
||||
self.functions.append(
|
||||
tools_entities.LLMFunction(
|
||||
name=tool.name,
|
||||
human_desc=tool.description,
|
||||
description=tool.description,
|
||||
parameters=tool.inputSchema,
|
||||
func=func,
|
||||
)
|
||||
)
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭工具"""
|
||||
await self.session._exit_stack.aclose()
|
||||
|
||||
@loader.loader_class("mcp")
|
||||
|
||||
@loader.loader_class('mcp')
|
||||
class MCPLoader(loader.ToolLoader):
|
||||
"""MCP 工具加载器。
|
||||
|
||||
|
||||
在此加载器中管理所有与 MCP Server 的连接。
|
||||
"""
|
||||
|
||||
@@ -125,16 +132,17 @@ class MCPLoader(loader.ToolLoader):
|
||||
self._last_listed_functions = []
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
for server_config in self.ap.instance_config.data.get("mcp", {}).get("servers", []):
|
||||
if not server_config["enable"]:
|
||||
for server_config in self.ap.instance_config.data.get('mcp', {}).get(
|
||||
'servers', []
|
||||
):
|
||||
if not server_config['enable']:
|
||||
continue
|
||||
session = RuntimeMCPSession(server_config["name"], server_config, self.ap)
|
||||
session = RuntimeMCPSession(server_config['name'], server_config, self.ap)
|
||||
await session.initialize()
|
||||
# self.ap.event_loop.create_task(session.initialize())
|
||||
self.sessions[server_config["name"]] = session
|
||||
self.sessions[server_config['name']] = session
|
||||
|
||||
async def get_tools(self, enabled: bool=True) -> list[tools_entities.LLMFunction]:
|
||||
async def get_tools(self, enabled: bool = True) -> list[tools_entities.LLMFunction]:
|
||||
all_functions = []
|
||||
|
||||
for session in self.sessions.values():
|
||||
@@ -147,13 +155,15 @@ class MCPLoader(loader.ToolLoader):
|
||||
async def has_tool(self, name: str) -> bool:
|
||||
return name in [f.name for f in self._last_listed_functions]
|
||||
|
||||
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
|
||||
async def invoke_tool(
|
||||
self, query: core_entities.Query, name: str, parameters: dict
|
||||
) -> typing.Any:
|
||||
for server_name, session in self.sessions.items():
|
||||
for function in session.functions:
|
||||
if function.name == name:
|
||||
return await function.func(query, **parameters)
|
||||
|
||||
raise ValueError(f"未找到工具: {name}")
|
||||
raise ValueError(f'未找到工具: {name}')
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭工具"""
|
||||
|
||||
@@ -4,19 +4,18 @@ import typing
|
||||
import traceback
|
||||
|
||||
from .. import loader, entities as tools_entities
|
||||
from ....core import app, entities as core_entities
|
||||
from ....core import entities as core_entities
|
||||
from ....plugin import context as plugin_context
|
||||
|
||||
|
||||
@loader.loader_class("plugin-tool-loader")
|
||||
@loader.loader_class('plugin-tool-loader')
|
||||
class PluginToolLoader(loader.ToolLoader):
|
||||
"""插件工具加载器。
|
||||
|
||||
|
||||
本加载器中不存储工具信息,仅负责从插件系统中获取工具信息。
|
||||
"""
|
||||
|
||||
async def get_tools(self, enabled: bool=True) -> list[tools_entities.LLMFunction]:
|
||||
|
||||
async def get_tools(self, enabled: bool = True) -> list[tools_entities.LLMFunction]:
|
||||
# 从插件系统获取工具(内容函数)
|
||||
all_functions: list[tools_entities.LLMFunction] = []
|
||||
|
||||
@@ -49,23 +48,23 @@ class PluginToolLoader(loader.ToolLoader):
|
||||
return function, plugin.plugin_inst
|
||||
return None, None
|
||||
|
||||
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
|
||||
|
||||
async def invoke_tool(
|
||||
self, query: core_entities.Query, name: str, parameters: dict
|
||||
) -> typing.Any:
|
||||
try:
|
||||
|
||||
function, plugin = await self._get_function_and_plugin(name)
|
||||
if function is None:
|
||||
return None
|
||||
|
||||
parameters = parameters.copy()
|
||||
|
||||
parameters = {"query": query, **parameters}
|
||||
parameters = {'query': query, **parameters}
|
||||
|
||||
return await function.func(plugin, **parameters)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f"执行函数 {name} 时发生错误: {e}")
|
||||
self.ap.logger.error(f'执行函数 {name} 时发生错误: {e}')
|
||||
traceback.print_exc()
|
||||
return f"error occurred when executing function {name}: {e}"
|
||||
return f'error occurred when executing function {name}: {e}'
|
||||
finally:
|
||||
plugin = None
|
||||
|
||||
@@ -75,13 +74,12 @@ class PluginToolLoader(loader.ToolLoader):
|
||||
break
|
||||
|
||||
if plugin is not None:
|
||||
|
||||
await self.ap.ctr_mgr.usage.post_function_record(
|
||||
plugin={
|
||||
"name": plugin.plugin_name,
|
||||
"remote": plugin.plugin_repository,
|
||||
"version": plugin.plugin_version,
|
||||
"author": plugin.plugin_author,
|
||||
'name': plugin.plugin_name,
|
||||
'remote': plugin.plugin_repository,
|
||||
'version': plugin.plugin_version,
|
||||
'author': plugin.plugin_author,
|
||||
},
|
||||
function_name=function.name,
|
||||
function_description=function.description,
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import traceback
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from . import entities, loader as tools_loader
|
||||
from ...plugin import context as plugin_context
|
||||
from .loaders import plugin, mcp
|
||||
from ...utils import importutil
|
||||
from . import loaders
|
||||
|
||||
importutil.import_modules_in_pkg(loaders)
|
||||
|
||||
|
||||
class ToolManager:
|
||||
@@ -22,13 +23,14 @@ class ToolManager:
|
||||
self.loaders = []
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
for loader_cls in tools_loader.preregistered_loaders:
|
||||
loader_inst = loader_cls(self.ap)
|
||||
await loader_inst.initialize()
|
||||
self.loaders.append(loader_inst)
|
||||
|
||||
async def get_all_functions(self, plugin_enabled: bool=None) -> list[entities.LLMFunction]:
|
||||
async def get_all_functions(
|
||||
self, plugin_enabled: bool = None
|
||||
) -> list[entities.LLMFunction]:
|
||||
"""获取所有函数"""
|
||||
all_functions: list[entities.LLMFunction] = []
|
||||
|
||||
@@ -37,17 +39,19 @@ class ToolManager:
|
||||
|
||||
return all_functions
|
||||
|
||||
async def generate_tools_for_openai(self, use_funcs: list[entities.LLMFunction]) -> list:
|
||||
async def generate_tools_for_openai(
|
||||
self, use_funcs: list[entities.LLMFunction]
|
||||
) -> list:
|
||||
"""生成函数列表"""
|
||||
tools = []
|
||||
|
||||
for function in use_funcs:
|
||||
function_schema = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function.name,
|
||||
"description": function.description,
|
||||
"parameters": function.parameters,
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': function.name,
|
||||
'description': function.description,
|
||||
'parameters': function.parameters,
|
||||
},
|
||||
}
|
||||
tools.append(function_schema)
|
||||
@@ -83,9 +87,9 @@ class ToolManager:
|
||||
|
||||
for function in use_funcs:
|
||||
function_schema = {
|
||||
"name": function.name,
|
||||
"description": function.description,
|
||||
"input_schema": function.parameters,
|
||||
'name': function.name,
|
||||
'description': function.description,
|
||||
'input_schema': function.parameters,
|
||||
}
|
||||
tools.append(function_schema)
|
||||
|
||||
@@ -100,7 +104,7 @@ class ToolManager:
|
||||
if await loader.has_tool(name):
|
||||
return await loader.invoke_tool(query, name, parameters)
|
||||
else:
|
||||
raise ValueError(f"未找到工具: {name}")
|
||||
raise ValueError(f'未找到工具: {name}')
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭所有工具"""
|
||||
|
||||
Reference in New Issue
Block a user