style: restrict line-length

This commit is contained in:
Junyan Qin
2025-05-10 18:04:58 +08:00
parent b30016ed08
commit 055b389353
134 changed files with 1096 additions and 2595 deletions

View File

@@ -45,9 +45,7 @@ class ToolLoader(abc.ABC):
pass
@abc.abstractmethod
async def invoke_tool(
self, query: core_entities.Query, name: str, parameters: dict
) -> typing.Any:
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
"""执行工具调用"""
pass

View File

@@ -43,15 +43,11 @@ class RuntimeMCPSession:
env=self.server_config['env'],
)
stdio_transport = await self.exit_stack.enter_async_context(
stdio_client(server_params)
)
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
stdio, write = stdio_transport
self.session = await self.exit_stack.enter_async_context(
ClientSession(stdio, write)
)
self.session = await self.exit_stack.enter_async_context(ClientSession(stdio, write))
await self.session.initialize()
@@ -66,25 +62,19 @@ class RuntimeMCPSession:
sseio, write = sse_transport
self.session = await self.exit_stack.enter_async_context(
ClientSession(sseio, write)
)
self.session = await self.exit_stack.enter_async_context(ClientSession(sseio, write))
await self.session.initialize()
async def initialize(self):
self.ap.logger.debug(
f'初始化 MCP 会话: {self.server_name} {self.server_config}'
)
self.ap.logger.debug(f'初始化 MCP 会话: {self.server_name} {self.server_config}')
if self.server_config['mode'] == 'stdio':
await self._init_stdio_python_server()
elif self.server_config['mode'] == 'sse':
await self._init_sse_server()
else:
raise ValueError(
f'无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}'
)
raise ValueError(f'无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}')
tools = await self.session.list_tools()
@@ -132,9 +122,7 @@ class MCPLoader(loader.ToolLoader):
self._last_listed_functions = []
async def initialize(self):
for server_config in self.ap.instance_config.data.get('mcp', {}).get(
'servers', []
):
for server_config in self.ap.instance_config.data.get('mcp', {}).get('servers', []):
if not server_config['enable']:
continue
session = RuntimeMCPSession(server_config['name'], server_config, self.ap)
@@ -155,9 +143,7 @@ class MCPLoader(loader.ToolLoader):
async def has_tool(self, name: str) -> bool:
return name in [f.name for f in self._last_listed_functions]
async def invoke_tool(
self, query: core_entities.Query, name: str, parameters: dict
) -> typing.Any:
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
for server_name, session in self.sessions.items():
for function in session.functions:
if function.name == name:

View File

@@ -48,9 +48,7 @@ class PluginToolLoader(loader.ToolLoader):
return function, plugin.plugin_inst
return None, None
async def invoke_tool(
self, query: core_entities.Query, name: str, parameters: dict
) -> typing.Any:
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
try:
function, plugin = await self._get_function_and_plugin(name)
if function is None:

View File

@@ -28,9 +28,7 @@ class ToolManager:
await loader_inst.initialize()
self.loaders.append(loader_inst)
async def get_all_functions(
self, plugin_enabled: bool = None
) -> list[entities.LLMFunction]:
async def get_all_functions(self, plugin_enabled: bool = None) -> list[entities.LLMFunction]:
"""获取所有函数"""
all_functions: list[entities.LLMFunction] = []
@@ -39,9 +37,7 @@ class ToolManager:
return all_functions
async def generate_tools_for_openai(
self, use_funcs: list[entities.LLMFunction]
) -> list:
async def generate_tools_for_openai(self, use_funcs: list[entities.LLMFunction]) -> list:
"""生成函数列表"""
tools = []
@@ -58,9 +54,7 @@ class ToolManager:
return tools
async def generate_tools_for_anthropic(
self, use_funcs: list[entities.LLMFunction]
) -> list:
async def generate_tools_for_anthropic(self, use_funcs: list[entities.LLMFunction]) -> list:
"""为anthropic生成函数列表
e.g.
@@ -95,9 +89,7 @@ class ToolManager:
return tools
async def execute_func_call(
self, query: core_entities.Query, name: str, parameters: dict
) -> typing.Any:
async def execute_func_call(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
"""执行函数调用"""
for loader in self.loaders: