Feat/pipeline specified plugins (#1752)

* feat: add persistence field

* feat: add basic extension page in pipeline config

* Merge pull request #1751 from langbot-app/copilot/add-plugin-extension-tab

Implement pipeline-scoped plugin binding system

* fix: i18n keys

---------

Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
This commit is contained in:
Junyan Qin (Chin)
2025-11-06 12:51:33 +08:00
committed by GitHub
parent 2c2a89d9db
commit 4a84bf2355
30 changed files with 525 additions and 41 deletions

View File

@@ -46,3 +46,28 @@ class PipelinesRouterGroup(group.RouterGroup):
await self.ap.pipeline_service.delete_pipeline(pipeline_uuid)
return self.success()
@self.route('/<pipeline_uuid>/extensions', methods=['GET', 'PUT'])
async def _(pipeline_uuid: str) -> str:
if quart.request.method == 'GET':
# Get current extensions and available plugins
pipeline = await self.ap.pipeline_service.get_pipeline(pipeline_uuid)
if pipeline is None:
return self.http_status(404, -1, 'pipeline not found')
plugins = await self.ap.plugin_connector.list_plugins()
return self.success(
data={
'bound_plugins': pipeline.get('extensions_preferences', {}).get('plugins', []),
'available_plugins': plugins,
}
)
elif quart.request.method == 'PUT':
# Update bound plugins for this pipeline
json_data = await quart.request.json
bound_plugins = json_data.get('bound_plugins', [])
await self.ap.pipeline_service.update_pipeline_extensions(pipeline_uuid, bound_plugins)
return self.success()

View File

@@ -136,3 +136,31 @@ class PipelineService:
)
)
await self.ap.pipeline_mgr.remove_pipeline(pipeline_uuid)
async def update_pipeline_extensions(self, pipeline_uuid: str, bound_plugins: list[dict]) -> None:
"""Update the bound plugins for a pipeline"""
# Get current pipeline
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid
)
)
pipeline = result.first()
if pipeline is None:
raise ValueError(f'Pipeline {pipeline_uuid} not found')
# Update extensions_preferences
extensions_preferences = pipeline.extensions_preferences or {}
extensions_preferences['plugins'] = bound_plugins
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_pipeline.LegacyPipeline)
.where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid)
.values(extensions_preferences=extensions_preferences)
)
# Reload pipeline to apply changes
await self.ap.pipeline_mgr.remove_pipeline(pipeline_uuid)
pipeline = await self.get_pipeline(pipeline_uuid)
await self.ap.pipeline_mgr.load_pipeline(pipeline)

View File

@@ -59,14 +59,15 @@ class CommandManager:
context: command_context.ExecuteContext,
operator_list: list[operator.CommandOperator],
operator: operator.CommandOperator = None,
bound_plugins: list[str] | None = None,
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
"""执行命令"""
command_list = await self.ap.plugin_connector.list_commands()
command_list = await self.ap.plugin_connector.list_commands(bound_plugins)
for command in command_list:
if command.metadata.name == context.command:
async for ret in self.ap.plugin_connector.execute_command(context):
async for ret in self.ap.plugin_connector.execute_command(context, bound_plugins):
yield ret
break
else:
@@ -102,5 +103,8 @@ class CommandManager:
ctx.shift()
async for ret in self._execute(ctx, self.cmd_list):
# Get bound plugins from query
bound_plugins = query.variables.get('_pipeline_bound_plugins', None)
async for ret in self._execute(ctx, self.cmd_list, bound_plugins=bound_plugins):
yield ret

View File

@@ -1,12 +1,13 @@
import sqlalchemy
from .base import Base
from ...utils import constants
initial_metadata = [
{
'key': 'database_version',
'value': '0',
'value': str(constants.required_database_version),
},
]

View File

@@ -22,6 +22,7 @@ class LegacyPipeline(Base):
is_default = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=False)
stages = sqlalchemy.Column(sqlalchemy.JSON, nullable=False)
config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False)
extensions_preferences = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
class PipelineRunRecord(Base):

View File

@@ -78,6 +78,8 @@ class PersistenceManager:
self.ap.logger.info(f'Successfully upgraded database to version {last_migration_number}.')
await self.write_default_pipeline()
async def create_tables(self):
# create tables
async with self.get_db_engine().connect() as conn:
@@ -98,6 +100,7 @@ class PersistenceManager:
if row is None:
await self.execute_async(sqlalchemy.insert(metadata.Metadata).values(item))
async def write_default_pipeline(self):
# write default pipeline
result = await self.execute_async(sqlalchemy.select(pipeline.LegacyPipeline))
default_pipeline_uuid = None

View File

@@ -0,0 +1,20 @@
import sqlalchemy
from .. import migration
@migration.migration_class(9)
class DBMigratePipelineExtensionPreferences(migration.DBMigration):
"""Pipeline extension preferences"""
async def upgrade(self):
"""Upgrade"""
sql_text = sqlalchemy.text(
"ALTER TABLE legacy_pipelines ADD COLUMN extensions_preferences JSON NOT NULL DEFAULT '{}'"
)
await self.ap.persistence_mgr.execute_async(sql_text)
async def downgrade(self):
"""Downgrade"""
sql_text = sqlalchemy.text('ALTER TABLE legacy_pipelines DROP COLUMN extensions_preferences')
await self.ap.persistence_mgr.execute_async(sql_text)

View File

@@ -68,6 +68,9 @@ class RuntimePipeline:
stage_containers: list[StageInstContainer]
"""阶段实例容器"""
bound_plugins: list[str]
"""绑定到此流水线的插件列表格式author/plugin_name"""
def __init__(
self,
@@ -78,9 +81,16 @@ class RuntimePipeline:
self.ap = ap
self.pipeline_entity = pipeline_entity
self.stage_containers = stage_containers
# Extract bound plugins from extensions_preferences
extensions_prefs = pipeline_entity.extensions_preferences or {}
plugin_list = extensions_prefs.get('plugins', [])
self.bound_plugins = [f"{p['author']}/{p['name']}" for p in plugin_list] if plugin_list else []
async def run(self, query: pipeline_query.Query):
query.pipeline_config = self.pipeline_entity.config
# Store bound plugins in query for filtering
query.variables['_pipeline_bound_plugins'] = self.bound_plugins
await self.process_query(query)
async def _check_output(self, query: pipeline_query.Query, result: pipeline_entities.StageProcessResult):
@@ -188,6 +198,9 @@ class RuntimePipeline:
async def process_query(self, query: pipeline_query.Query):
"""处理请求"""
try:
# Get bound plugins for this pipeline
bound_plugins = query.variables.get('_pipeline_bound_plugins', None)
# ======== 触发 MessageReceived 事件 ========
event_type = (
events.PersonMessageReceived
@@ -203,7 +216,7 @@ class RuntimePipeline:
message_chain=query.message_chain,
)
event_ctx = await self.ap.plugin_connector.emit_event(event_obj)
event_ctx = await self.ap.plugin_connector.emit_event(event_obj, bound_plugins)
if event_ctx.is_prevented_default():
return

View File

@@ -65,7 +65,9 @@ class PreProcessor(stage.PipelineStage):
query.use_llm_model_uuid = llm_model.model_entity.uuid
if llm_model.model_entity.abilities.__contains__('func_call'):
query.use_funcs = await self.ap.tool_mgr.get_all_tools()
# Get bound plugins for filtering tools
bound_plugins = query.variables.get('_pipeline_bound_plugins', None)
query.use_funcs = await self.ap.tool_mgr.get_all_tools(bound_plugins)
variables = {
'session_id': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
@@ -130,7 +132,9 @@ class PreProcessor(stage.PipelineStage):
query=query,
)
event_ctx = await self.ap.plugin_connector.emit_event(event)
# Get bound plugins for filtering
bound_plugins = query.variables.get('_pipeline_bound_plugins', None)
event_ctx = await self.ap.plugin_connector.emit_event(event, bound_plugins)
query.prompt.messages = event_ctx.event.default_prompt
query.messages = event_ctx.event.prompt

View File

@@ -43,7 +43,9 @@ class ChatMessageHandler(handler.MessageHandler):
query=query,
)
event_ctx = await self.ap.plugin_connector.emit_event(event)
# Get bound plugins for filtering
bound_plugins = query.variables.get('_pipeline_bound_plugins', None)
event_ctx = await self.ap.plugin_connector.emit_event(event, bound_plugins)
is_create_card = False # 判断下是否需要创建流式卡片

View File

@@ -45,7 +45,9 @@ class CommandHandler(handler.MessageHandler):
query=query,
)
event_ctx = await self.ap.plugin_connector.emit_event(event)
# Get bound plugins for filtering
bound_plugins = query.variables.get('_pipeline_bound_plugins', None)
event_ctx = await self.ap.plugin_connector.emit_event(event, bound_plugins)
if event_ctx.is_prevented_default():
if event_ctx.event.reply_message_chain is not None:

View File

@@ -72,7 +72,9 @@ class ResponseWrapper(stage.PipelineStage):
query=query,
)
event_ctx = await self.ap.plugin_connector.emit_event(event)
# Get bound plugins for filtering
bound_plugins = query.variables.get('_pipeline_bound_plugins', None)
event_ctx = await self.ap.plugin_connector.emit_event(event, bound_plugins)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
@@ -115,7 +117,9 @@ class ResponseWrapper(stage.PipelineStage):
query=query,
)
event_ctx = await self.ap.plugin_connector.emit_event(event)
# Get bound plugins for filtering
bound_plugins = query.variables.get('_pipeline_bound_plugins', None)
event_ctx = await self.ap.plugin_connector.emit_event(event, bound_plugins)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(

View File

@@ -249,47 +249,62 @@ class PluginRuntimeConnector:
async def emit_event(
self,
event: events.BaseEventModel,
bound_plugins: list[str] | None = None,
) -> context.EventContext:
event_ctx = context.EventContext.from_event(event)
if not self.is_enable_plugin:
return event_ctx
event_ctx_result = await self.handler.emit_event(event_ctx.model_dump(serialize_as_any=False))
# Pass include_plugins to runtime for filtering
event_ctx_result = await self.handler.emit_event(
event_ctx.model_dump(serialize_as_any=False), include_plugins=bound_plugins
)
event_ctx = context.EventContext.model_validate(event_ctx_result['event_context'])
return event_ctx
async def list_tools(self) -> list[ComponentManifest]:
async def list_tools(self, bound_plugins: list[str] | None = None) -> list[ComponentManifest]:
if not self.is_enable_plugin:
return []
list_tools_data = await self.handler.list_tools()
# Pass include_plugins to runtime for filtering
list_tools_data = await self.handler.list_tools(include_plugins=bound_plugins)
return [ComponentManifest.model_validate(tool) for tool in list_tools_data]
tools = [ComponentManifest.model_validate(tool) for tool in list_tools_data]
async def call_tool(self, tool_name: str, parameters: dict[str, Any]) -> dict[str, Any]:
return tools
async def call_tool(
self, tool_name: str, parameters: dict[str, Any], bound_plugins: list[str] | None = None
) -> dict[str, Any]:
if not self.is_enable_plugin:
return {'error': 'Tool not found: plugin system is disabled'}
return await self.handler.call_tool(tool_name, parameters)
# Pass include_plugins to runtime for validation
return await self.handler.call_tool(tool_name, parameters, include_plugins=bound_plugins)
async def list_commands(self) -> list[ComponentManifest]:
async def list_commands(self, bound_plugins: list[str] | None = None) -> list[ComponentManifest]:
if not self.is_enable_plugin:
return []
list_commands_data = await self.handler.list_commands()
# Pass include_plugins to runtime for filtering
list_commands_data = await self.handler.list_commands(include_plugins=bound_plugins)
return [ComponentManifest.model_validate(command) for command in list_commands_data]
commands = [ComponentManifest.model_validate(command) for command in list_commands_data]
return commands
async def execute_command(
self, command_ctx: command_context.ExecuteContext
self, command_ctx: command_context.ExecuteContext, bound_plugins: list[str] | None = None
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
if not self.is_enable_plugin:
yield command_context.CommandReturn(error=command_errors.CommandNotFoundError(command_ctx.command))
return
gen = self.handler.execute_command(command_ctx.model_dump(serialize_as_any=True))
# Pass include_plugins to runtime for validation
gen = self.handler.execute_command(command_ctx.model_dump(serialize_as_any=True), include_plugins=bound_plugins)
async for ret in gen:
cmd_ret = command_context.CommandReturn.model_validate(ret)

View File

@@ -554,23 +554,27 @@ class RuntimeConnectionHandler(handler.Handler):
async def emit_event(
self,
event_context: dict[str, Any],
include_plugins: list[str] | None = None,
) -> dict[str, Any]:
"""Emit event"""
result = await self.call_action(
LangBotToRuntimeAction.EMIT_EVENT,
{
'event_context': event_context,
'include_plugins': include_plugins,
},
timeout=60,
)
return result
async def list_tools(self) -> list[dict[str, Any]]:
async def list_tools(self, include_plugins: list[str] | None = None) -> list[dict[str, Any]]:
"""List tools"""
result = await self.call_action(
LangBotToRuntimeAction.LIST_TOOLS,
{},
{
'include_plugins': include_plugins,
},
timeout=20,
)
@@ -615,34 +619,42 @@ class RuntimeConnectionHandler(handler.Handler):
.where(persistence_bstorage.BinaryStorage.owner == owner)
)
async def call_tool(self, tool_name: str, parameters: dict[str, Any]) -> dict[str, Any]:
async def call_tool(
self, tool_name: str, parameters: dict[str, Any], include_plugins: list[str] | None = None
) -> dict[str, Any]:
"""Call tool"""
result = await self.call_action(
LangBotToRuntimeAction.CALL_TOOL,
{
'tool_name': tool_name,
'tool_parameters': parameters,
'include_plugins': include_plugins,
},
timeout=60,
)
return result['tool_response']
async def list_commands(self) -> list[dict[str, Any]]:
async def list_commands(self, include_plugins: list[str] | None = None) -> list[dict[str, Any]]:
"""List commands"""
result = await self.call_action(
LangBotToRuntimeAction.LIST_COMMANDS,
{},
{
'include_plugins': include_plugins,
},
timeout=10,
)
return result['commands']
async def execute_command(self, command_context: dict[str, Any]) -> typing.AsyncGenerator[dict[str, Any], None]:
async def execute_command(
self, command_context: dict[str, Any], include_plugins: list[str] | None = None
) -> typing.AsyncGenerator[dict[str, Any], None]:
"""Execute command"""
gen = self.call_action_generator(
LangBotToRuntimeAction.EXECUTE_COMMAND,
{
'command_context': command_context,
'include_plugins': include_plugins,
},
timeout=60,
)

View File

@@ -35,7 +35,7 @@ class ToolLoader(abc.ABC):
pass
@abc.abstractmethod
async def get_tools(self) -> list[resource_tool.LLMTool]:
async def get_tools(self, bound_plugins: list[str] | None = None) -> list[resource_tool.LLMTool]:
"""获取所有工具"""
pass

View File

@@ -301,7 +301,7 @@ class MCPLoader(loader.ToolLoader):
return session
async def get_tools(self) -> list[resource_tool.LLMTool]:
async def get_tools(self, bound_plugins: list[str] | None = None) -> list[resource_tool.LLMTool]:
all_functions = []
for session in self.sessions.values():

View File

@@ -14,11 +14,11 @@ class PluginToolLoader(loader.ToolLoader):
本加载器中不存储工具信息,仅负责从插件系统中获取工具信息。
"""
async def get_tools(self) -> list[resource_tool.LLMTool]:
async def get_tools(self, bound_plugins: list[str] | None = None) -> list[resource_tool.LLMTool]:
# 从插件系统获取工具(内容函数)
all_functions: list[resource_tool.LLMTool] = []
for tool in await self.ap.plugin_connector.list_tools():
for tool in await self.ap.plugin_connector.list_tools(bound_plugins):
tool_obj = resource_tool.LLMTool(
name=tool.metadata.name,
human_desc=tool.metadata.description.en_US,

View File

@@ -28,12 +28,12 @@ class ToolManager:
self.mcp_tool_loader = mcp_loader.MCPLoader(self.ap)
await self.mcp_tool_loader.initialize()
async def get_all_tools(self) -> list[resource_tool.LLMTool]:
async def get_all_tools(self, bound_plugins: list[str] | None = None) -> list[resource_tool.LLMTool]:
"""获取所有函数"""
all_functions: list[resource_tool.LLMTool] = []
all_functions.extend(await self.plugin_tool_loader.get_tools())
all_functions.extend(await self.mcp_tool_loader.get_tools())
all_functions.extend(await self.plugin_tool_loader.get_tools(bound_plugins))
all_functions.extend(await self.mcp_tool_loader.get_tools(bound_plugins))
return all_functions

View File

@@ -1,6 +1,6 @@
semantic_version = 'v4.4.1'
required_database_version = 8
required_database_version = 9
"""Tag the version of the database schema, used to check if the database needs to be migrated"""
debug_mode = False