refactor: 重构部分插件管理逻辑

This commit is contained in:
Junyan Qin
2024-11-16 16:13:02 +08:00
parent bb219889e5
commit 658eb278c4
13 changed files with 219 additions and 158 deletions

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
from ...core import app, entities as core_entities
from ...plugin import context as plugin_context
class SessionManager:
@@ -51,7 +52,10 @@ class SessionManager:
prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name),
messages=[],
use_model=await self.ap.model_mgr.get_model_by_name(self.ap.provider_cfg.data['model']),
use_funcs=await self.ap.tool_mgr.get_all_functions(),
use_funcs=await self.ap.tool_mgr.get_all_functions(
plugin_enabled=True,
plugin_status=plugin_context.RuntimeContainerStatus.INITIALIZED,
),
)
session.conversations.append(conversation)
session.using_conversation = conversation

View File

@@ -20,8 +20,6 @@ class LLMFunction(pydantic.BaseModel):
description: str
"""给LLM识别的函数描述"""
enable: typing.Optional[bool] = True
parameters: dict
func: typing.Callable

View File

@@ -20,28 +20,25 @@ class ToolManager:
async def initialize(self):
pass
async def get_function(self, name: str) -> entities.LLMFunction:
"""获取函数"""
for function in await self.get_all_functions():
if function.name == name:
return function
return None
async def get_function_and_plugin(
self, name: str
) -> typing.Tuple[entities.LLMFunction, plugin_context.BasePlugin]:
"""获取函数和插件"""
for plugin in self.ap.plugin_mgr.plugins:
"""获取函数和插件实例"""
for plugin in self.ap.plugin_mgr.plugins(
enabled=True, status=plugin_context.RuntimeContainerStatus.INITIALIZED
):
for function in plugin.content_functions:
if function.name == name:
return function, plugin.plugin_inst
return None, None
async def get_all_functions(self) -> list[entities.LLMFunction]:
async def get_all_functions(self, plugin_enabled: bool=None, plugin_status: plugin_context.RuntimeContainerStatus=None) -> list[entities.LLMFunction]:
"""获取所有函数"""
all_functions: list[entities.LLMFunction] = []
for plugin in self.ap.plugin_mgr.plugins:
for plugin in self.ap.plugin_mgr.plugins(
enabled=plugin_enabled, status=plugin_status
):
all_functions.extend(plugin.content_functions)
return all_functions
@@ -51,16 +48,15 @@ class ToolManager:
tools = []
for function in use_funcs:
if function.enable:
function_schema = {
"type": "function",
"function": {
"name": function.name,
"description": function.description,
"parameters": function.parameters,
},
}
tools.append(function_schema)
function_schema = {
"type": "function",
"function": {
"name": function.name,
"description": function.description,
"parameters": function.parameters,
},
}
tools.append(function_schema)
return tools
@@ -92,13 +88,12 @@ class ToolManager:
tools = []
for function in use_funcs:
if function.enable:
function_schema = {
"name": function.name,
"description": function.description,
"input_schema": function.parameters,
}
tools.append(function_schema)
function_schema = {
"name": function.name,
"description": function.description,
"input_schema": function.parameters,
}
tools.append(function_schema)
return tools