style: introduce ruff as linter and formatter (#1356)

* style: remove necessary imports

* style: fix F841

* style: fix F401

* style: fix F811

* style: fix E402

* style: fix E721

* style: fix E722

* style: fix E722

* style: fix F541

* style: ruff format

* style: all passed

* style: add ruff in deps

* style: more ignores in ruff.toml

* style: add pre-commit
This commit is contained in:
Junyan Qin (Chin)
2025-04-29 17:24:07 +08:00
committed by GitHub
parent 09e70d70e9
commit 209f16af76
240 changed files with 5307 additions and 4689 deletions
+42 -32
View File
@@ -30,7 +30,7 @@ class RuntimeMCPSession:
self.server_name = server_name
self.server_config = server_config
self.ap = ap
self.session = None
self.exit_stack = AsyncExitStack()
@@ -38,9 +38,9 @@ class RuntimeMCPSession:
async def _init_stdio_python_server(self):
server_params = StdioServerParameters(
command=self.server_config["command"],
args=self.server_config["args"],
env=self.server_config["env"],
command=self.server_config['command'],
args=self.server_config['args'],
env=self.server_config['env'],
)
stdio_transport = await self.exit_stack.enter_async_context(
@@ -58,12 +58,12 @@ class RuntimeMCPSession:
async def _init_sse_server(self):
sse_transport = await self.exit_stack.enter_async_context(
sse_client(
self.server_config["url"],
headers=self.server_config.get("headers", {}),
timeout=self.server_config.get("timeout", 10),
self.server_config['url'],
headers=self.server_config.get('headers', {}),
timeout=self.server_config.get('timeout', 10),
)
)
sseio, write = sse_transport
self.session = await self.exit_stack.enter_async_context(
@@ -73,18 +73,22 @@ class RuntimeMCPSession:
await self.session.initialize()
async def initialize(self):
self.ap.logger.debug(f"初始化 MCP 会话: {self.server_name} {self.server_config}")
self.ap.logger.debug(
f'初始化 MCP 会话: {self.server_name} {self.server_config}'
)
if self.server_config["mode"] == "stdio":
if self.server_config['mode'] == 'stdio':
await self._init_stdio_python_server()
elif self.server_config["mode"] == "sse":
elif self.server_config['mode'] == 'sse':
await self._init_sse_server()
else:
raise ValueError(f"无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}")
raise ValueError(
f'无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}'
)
tools = await self.session.list_tools()
self.ap.logger.debug(f"获取 MCP 工具: {tools}")
self.ap.logger.debug(f'获取 MCP 工具: {tools}')
for tool in tools.tools:
@@ -93,25 +97,28 @@ class RuntimeMCPSession:
if result.isError:
raise Exception(result.content[0].text)
return result.content[0].text
func.__name__ = tool.name
self.functions.append(tools_entities.LLMFunction(
name=tool.name,
human_desc=tool.description,
description=tool.description,
parameters=tool.inputSchema,
func=func,
))
self.functions.append(
tools_entities.LLMFunction(
name=tool.name,
human_desc=tool.description,
description=tool.description,
parameters=tool.inputSchema,
func=func,
)
)
async def shutdown(self):
"""关闭工具"""
await self.session._exit_stack.aclose()
@loader.loader_class("mcp")
@loader.loader_class('mcp')
class MCPLoader(loader.ToolLoader):
"""MCP 工具加载器。
在此加载器中管理所有与 MCP Server 的连接。
"""
@@ -125,16 +132,17 @@ class MCPLoader(loader.ToolLoader):
self._last_listed_functions = []
async def initialize(self):
for server_config in self.ap.instance_config.data.get("mcp", {}).get("servers", []):
if not server_config["enable"]:
for server_config in self.ap.instance_config.data.get('mcp', {}).get(
'servers', []
):
if not server_config['enable']:
continue
session = RuntimeMCPSession(server_config["name"], server_config, self.ap)
session = RuntimeMCPSession(server_config['name'], server_config, self.ap)
await session.initialize()
# self.ap.event_loop.create_task(session.initialize())
self.sessions[server_config["name"]] = session
self.sessions[server_config['name']] = session
async def get_tools(self, enabled: bool=True) -> list[tools_entities.LLMFunction]:
async def get_tools(self, enabled: bool = True) -> list[tools_entities.LLMFunction]:
all_functions = []
for session in self.sessions.values():
@@ -147,13 +155,15 @@ class MCPLoader(loader.ToolLoader):
async def has_tool(self, name: str) -> bool:
return name in [f.name for f in self._last_listed_functions]
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
async def invoke_tool(
self, query: core_entities.Query, name: str, parameters: dict
) -> typing.Any:
for server_name, session in self.sessions.items():
for function in session.functions:
if function.name == name:
return await function.func(query, **parameters)
raise ValueError(f"未找到工具: {name}")
raise ValueError(f'未找到工具: {name}')
async def shutdown(self):
"""关闭工具"""
+14 -16
View File
@@ -4,19 +4,18 @@ import typing
import traceback
from .. import loader, entities as tools_entities
from ....core import app, entities as core_entities
from ....core import entities as core_entities
from ....plugin import context as plugin_context
@loader.loader_class("plugin-tool-loader")
@loader.loader_class('plugin-tool-loader')
class PluginToolLoader(loader.ToolLoader):
"""插件工具加载器。
本加载器中不存储工具信息,仅负责从插件系统中获取工具信息。
"""
async def get_tools(self, enabled: bool=True) -> list[tools_entities.LLMFunction]:
async def get_tools(self, enabled: bool = True) -> list[tools_entities.LLMFunction]:
# 从插件系统获取工具(内容函数)
all_functions: list[tools_entities.LLMFunction] = []
@@ -49,23 +48,23 @@ class PluginToolLoader(loader.ToolLoader):
return function, plugin.plugin_inst
return None, None
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
async def invoke_tool(
self, query: core_entities.Query, name: str, parameters: dict
) -> typing.Any:
try:
function, plugin = await self._get_function_and_plugin(name)
if function is None:
return None
parameters = parameters.copy()
parameters = {"query": query, **parameters}
parameters = {'query': query, **parameters}
return await function.func(plugin, **parameters)
except Exception as e:
self.ap.logger.error(f"执行函数 {name} 时发生错误: {e}")
self.ap.logger.error(f'执行函数 {name} 时发生错误: {e}')
traceback.print_exc()
return f"error occurred when executing function {name}: {e}"
return f'error occurred when executing function {name}: {e}'
finally:
plugin = None
@@ -75,13 +74,12 @@ class PluginToolLoader(loader.ToolLoader):
break
if plugin is not None:
await self.ap.ctr_mgr.usage.post_function_record(
plugin={
"name": plugin.plugin_name,
"remote": plugin.plugin_repository,
"version": plugin.plugin_version,
"author": plugin.plugin_author,
'name': plugin.plugin_name,
'remote': plugin.plugin_repository,
'version': plugin.plugin_version,
'author': plugin.plugin_author,
},
function_name=function.name,
function_description=function.description,