feat(plugin): implement INVOKE_RERANK handler with run-scoped authorization

- Add invoke_rerank action handler in plugin handler
- Validate rerank model access via run session
- Cap documents at 64 for API limit
- Return sorted results by relevance score
This commit is contained in:
huanghuoguoguo
2026-05-13 10:36:47 +08:00
parent 2fd126b0d7
commit e81a1af36c
7 changed files with 586 additions and 102 deletions

View File

@@ -46,17 +46,19 @@ async def _validate_run_authorization(
resource_type: str,
resource_id: str,
ap: app.Application,
caller_plugin_identity: str | None = None,
) -> Union[tuple[None, handler.ActionResponse], tuple[Any, None]]:
"""Validate run_id authorization for a resource access.
Common validation logic for INVOKE_LLM, INVOKE_LLM_STREAM, CALL_TOOL,
RETRIEVE_KNOWLEDGE_BASE, and RETRIEVE_KNOWLEDGE actions.
RETRIEVE_KNOWLEDGE_BASE, RETRIEVE_KNOWLEDGE, and storage/file actions.
Args:
run_id: The run_id to validate.
resource_type: Resource type ('model', 'tool', 'knowledge_base').
resource_id: Resource identifier (model_uuid, tool_name, kb_id).
resource_type: Resource type ('model', 'tool', 'knowledge_base', 'storage', 'file').
resource_id: Resource identifier (model_uuid, tool_name, kb_id, 'plugin'/'workspace', file_key).
ap: Application instance for logging.
caller_plugin_identity: Optional plugin identity (author/name) of the caller for cross-plugin validation.
Returns:
Tuple of (session, None) if validation passes.
@@ -72,6 +74,18 @@ async def _validate_run_authorization(
message=f'Run session {run_id} not found or expired',
)
# Validate caller_plugin_identity matches session's plugin_identity
if caller_plugin_identity:
session_plugin_identity = session.get('plugin_identity')
if session_plugin_identity and caller_plugin_identity != session_plugin_identity:
ap.logger.warning(
f'{resource_type.upper()}: caller_plugin_identity {caller_plugin_identity} '
f'does not match session plugin_identity {session_plugin_identity}'
)
return None, handler.ActionResponse.error(
message=f'Plugin identity mismatch: caller {caller_plugin_identity} is not authorized for run_id {run_id}',
)
if not session_registry.is_resource_allowed(session, resource_type, resource_id):
ap.logger.warning(
f'{resource_type.upper()}: {resource_id} not allowed for run_id {run_id}'
@@ -377,11 +391,12 @@ class RuntimeConnectionHandler(handler.Handler):
funcs = data.get('funcs', [])
extra_args = data.get('extra_args', {})
run_id = data.get('run_id') # Optional: present for AgentRunner calls
caller_plugin_identity = data.get('caller_plugin_identity') # Optional: for cross-plugin validation
# Permission validation for AgentRunner calls
if run_id:
session, error = await _validate_run_authorization(
run_id, 'model', llm_model_uuid, self.ap
run_id, 'model', llm_model_uuid, self.ap, caller_plugin_identity
)
if error:
return error
@@ -428,11 +443,12 @@ class RuntimeConnectionHandler(handler.Handler):
funcs = data.get('funcs', [])
extra_args = data.get('extra_args', {})
run_id = data.get('run_id') # Optional: present for AgentRunner calls
caller_plugin_identity = data.get('caller_plugin_identity') # Optional: for cross-plugin validation
# Permission validation for AgentRunner calls
if run_id:
session, error = await _validate_run_authorization(
run_id, 'model', llm_model_uuid, self.ap
run_id, 'model', llm_model_uuid, self.ap, caller_plugin_identity
)
if error:
yield error
@@ -476,13 +492,14 @@ class RuntimeConnectionHandler(handler.Handler):
# Support 'tool_parameters' (LangBotAPIProxy) and 'parameters' (AgentRunAPIProxy)
parameters = data.get('tool_parameters') or data.get('parameters', {})
run_id = data.get('run_id') # Optional: present for AgentRunner calls
caller_plugin_identity = data.get('caller_plugin_identity') # Optional: for cross-plugin validation
# session_data = data['session']
# query_id = data['query_id']
# Permission validation for AgentRunner calls
if run_id:
session, error = await _validate_run_authorization(
run_id, 'tool', tool_name, self.ap
run_id, 'tool', tool_name, self.ap, caller_plugin_identity
)
if error:
return error
@@ -511,20 +528,36 @@ class RuntimeConnectionHandler(handler.Handler):
)
# ================= Binary Storage Handlers =================
# NOTE: These are low-level actions called by SDK Runtime's storage wrapper handlers.
# Permission validation is handled in SDK Runtime layer (not here):
# - plugin_storage: SDK handler auto-sets owner to caller plugin identity (inherent isolation)
# - workspace_storage: SDK handler should validate session.resources.storage.workspace_storage
# TODO: SDK storage handlers need to pass run_id and validate workspace_storage permission.
# Current risk: workspace storage access is unrestricted from AgentRunner context.
# Permission validation:
# - For AgentRunner calls (with run_id): validates storage permission via session_registry
# - For regular plugin calls (no run_id): unrestricted access (backward compatibility)
# - Plugin storage: inherent isolation via owner = plugin identity (set by SDK runtime)
# - Workspace storage: requires ctx.resources.storage.workspace_storage for AgentRunner
@self.action(RuntimeToLangBotAction.SET_BINARY_STORAGE)
async def set_binary_storage(data: dict[str, Any]) -> handler.ActionResponse:
"""Set binary storage"""
"""Set binary storage
For AgentRunner calls: validates storage permission via session_registry.
For regular plugin calls: unrestricted access (backward compatibility).
"""
key = data['key']
owner_type = data['owner_type']
owner = data['owner']
value = base64.b64decode(data['value_base64'])
run_id = data.get('run_id') # Optional: present for AgentRunner calls
caller_plugin_identity = data.get('caller_plugin_identity') # Optional: for cross-plugin validation
# Permission validation for AgentRunner calls
if run_id:
# Determine storage type from owner_type
storage_type = owner_type # 'plugin' or 'workspace'
session, error = await _validate_run_authorization(
run_id, 'storage', storage_type, self.ap, caller_plugin_identity
)
if error:
return error
max_value_bytes = (
self.ap.instance_config.data.get('plugin', {})
.get('binary_storage', {})
@@ -574,10 +607,25 @@ class RuntimeConnectionHandler(handler.Handler):
@self.action(RuntimeToLangBotAction.GET_BINARY_STORAGE)
async def get_binary_storage(data: dict[str, Any]) -> handler.ActionResponse:
"""Get binary storage"""
"""Get binary storage
For AgentRunner calls: validates storage permission via session_registry.
For regular plugin calls: unrestricted access (backward compatibility).
"""
key = data['key']
owner_type = data['owner_type']
owner = data['owner']
run_id = data.get('run_id') # Optional: present for AgentRunner calls
caller_plugin_identity = data.get('caller_plugin_identity') # Optional: for cross-plugin validation
# Permission validation for AgentRunner calls
if run_id:
storage_type = owner_type
session, error = await _validate_run_authorization(
run_id, 'storage', storage_type, self.ap, caller_plugin_identity
)
if error:
return error
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_bstorage.BinaryStorage)
@@ -600,10 +648,25 @@ class RuntimeConnectionHandler(handler.Handler):
@self.action(RuntimeToLangBotAction.DELETE_BINARY_STORAGE)
async def delete_binary_storage(data: dict[str, Any]) -> handler.ActionResponse:
"""Delete binary storage"""
"""Delete binary storage
For AgentRunner calls: validates storage permission via session_registry.
For regular plugin calls: unrestricted access (backward compatibility).
"""
key = data['key']
owner_type = data['owner_type']
owner = data['owner']
run_id = data.get('run_id') # Optional: present for AgentRunner calls
caller_plugin_identity = data.get('caller_plugin_identity') # Optional: for cross-plugin validation
# Permission validation for AgentRunner calls
if run_id:
storage_type = owner_type
session, error = await _validate_run_authorization(
run_id, 'storage', storage_type, self.ap, caller_plugin_identity
)
if error:
return error
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_bstorage.BinaryStorage)
@@ -618,9 +681,24 @@ class RuntimeConnectionHandler(handler.Handler):
@self.action(RuntimeToLangBotAction.GET_BINARY_STORAGE_KEYS)
async def get_binary_storage_keys(data: dict[str, Any]) -> handler.ActionResponse:
"""Get binary storage keys"""
"""Get binary storage keys
For AgentRunner calls: validates storage permission via session_registry.
For regular plugin calls: unrestricted access (backward compatibility).
"""
owner_type = data['owner_type']
owner = data['owner']
run_id = data.get('run_id') # Optional: present for AgentRunner calls
caller_plugin_identity = data.get('caller_plugin_identity') # Optional: for cross-plugin validation
# Permission validation for AgentRunner calls
if run_id:
storage_type = owner_type
session, error = await _validate_run_authorization(
run_id, 'storage', storage_type, self.ap, caller_plugin_identity
)
if error:
return error
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_bstorage.BinaryStorage.key)
@@ -636,8 +714,22 @@ class RuntimeConnectionHandler(handler.Handler):
@self.action(PluginToRuntimeAction.GET_CONFIG_FILE)
async def get_config_file(data: dict[str, Any]) -> handler.ActionResponse:
"""Get a config file by file key"""
"""Get a config file by file key
For AgentRunner calls: validates file_key against session.resources.files.
For regular plugin calls: unrestricted access (backward compatibility).
"""
file_key = data['file_key']
run_id = data.get('run_id') # Optional: present for AgentRunner calls
caller_plugin_identity = data.get('caller_plugin_identity') # Optional: for cross-plugin validation
# Permission validation for AgentRunner calls
if run_id:
session, error = await _validate_run_authorization(
run_id, 'file', file_key, self.ap, caller_plugin_identity
)
if error:
return error
try:
# Load file from storage
@@ -672,6 +764,50 @@ class RuntimeConnectionHandler(handler.Handler):
except Exception as e:
return _make_rag_error_response(e, 'EmbeddingError', embedding_model_uuid=embedding_model_uuid)
@self.action(PluginToRuntimeAction.INVOKE_RERANK)
async def invoke_rerank(data: dict[str, Any]) -> handler.ActionResponse:
    """Invoke a rerank model for an agent runner with run-scoped authorization.

    Keys read from ``data``:
        rerank_model_uuid: UUID of the rerank model to invoke (required).
        query: Query the documents are ranked against (required).
        documents: Candidate documents (required); only the first 64 are
            forwarded to the provider.
        run_id: Run session identifier passed to ``_validate_run_authorization``.
        caller_plugin_identity: Optional plugin identity of the caller, used
            for cross-plugin validation against the run session.
        top_k: Optional non-negative integer limiting how many results are returned.

    Returns:
        A success response whose ``results`` entry holds the provider's items
        sorted by ``relevance_score`` in descending order (truncated to
        ``top_k`` when given), or an error response when authorization fails,
        ``top_k`` is invalid, the model is not found, or the provider call raises.
    """
    run_id = data.get('run_id')
    caller_plugin_identity = data.get('caller_plugin_identity')  # Optional: for cross-plugin validation
    rerank_model_uuid = data['rerank_model_uuid']
    query = data['query']
    documents = data['documents']
    top_k = data.get('top_k')

    # Authorization is performed unconditionally here (not gated on run_id being
    # present). The caller identity is forwarded so that the same cross-plugin
    # check used by the other run-authorized handlers also applies to rerank.
    session, error = await _validate_run_authorization(
        run_id, 'model', rerank_model_uuid, self.ap, caller_plugin_identity
    )
    if error:
        return error

    # Reject an unusable top_k before spending a provider call: a negative value
    # used as a slice bound would silently drop results from the end of the list.
    if top_k is not None and (not isinstance(top_k, int) or top_k < 0):
        return handler.ActionResponse.error(
            message=f'Invalid top_k {top_k!r}: expected a non-negative integer',
        )

    # Get rerank model
    rerank_model = await self.ap.model_mgr.get_rerank_model_by_uuid(rerank_model_uuid)
    if rerank_model is None:
        return handler.ActionResponse.error(
            message=f'Rerank model with uuid {rerank_model_uuid} not found',
        )

    try:
        # Cap documents at 64 for API limit
        documents_capped = documents[:64]

        scores = await rerank_model.provider.invoke_rerank(
            model=rerank_model,
            query=query,
            documents=documents_capped,
        )

        # Sort by relevance score descending; entries without a score sort as 0
        ranked = sorted(scores, key=lambda item: item.get('relevance_score', 0), reverse=True)

        # Apply top_k if specified
        if top_k is not None:
            ranked = ranked[:top_k]

        return handler.ActionResponse.success(data={'results': ranked})
    except Exception as e:
        return _make_rag_error_response(e, 'RerankError', rerank_model_uuid=rerank_model_uuid)
@self.action(PluginToRuntimeAction.VECTOR_UPSERT)
async def vector_upsert(data: dict[str, Any]) -> handler.ActionResponse:
collection_id = data['collection_id']
@@ -817,11 +953,12 @@ class RuntimeConnectionHandler(handler.Handler):
top_k = data.get('top_k', 5)
filters = data.get('filters', {})
run_id = data.get('run_id') # Optional: present for AgentRunner calls
caller_plugin_identity = data.get('caller_plugin_identity') # Optional: for cross-plugin validation
# Permission validation for AgentRunner calls
if run_id:
session, error = await _validate_run_authorization(
run_id, 'knowledge_base', kb_id, self.ap
run_id, 'knowledge_base', kb_id, self.ap, caller_plugin_identity
)
if error:
return error
@@ -891,12 +1028,6 @@ class RuntimeConnectionHandler(handler.Handler):
Note: This action has dual validation paths:
- AgentRunner: uses session_registry for permission check
- Regular plugin: uses ConfigMigration.resolve_runner_config for pipeline-level check
SECURITY TODO: This handler cannot verify the caller's plugin identity.
The session contains 'plugin_identity' (author/name), but we don't have access
to which plugin is making the API call. This could allow a malicious plugin to
use another plugin's run_id if it can guess/obtain it. Future improvement:
track caller plugin identity in RuntimeConnectionHandler or pass it in action data.
"""
query_id = data['query_id']
kb_id = data['kb_id']
@@ -904,6 +1035,7 @@ class RuntimeConnectionHandler(handler.Handler):
top_k = data.get('top_k', 5)
filters = data.get('filters', {})
run_id = data.get('run_id') # Optional: present for AgentRunner calls
caller_plugin_identity = data.get('caller_plugin_identity') # Optional: for cross-plugin validation
if query_id not in self.ap.query_pool.cached_queries:
return handler.ActionResponse.error(
@@ -915,7 +1047,7 @@ class RuntimeConnectionHandler(handler.Handler):
# Permission validation for AgentRunner calls
if run_id:
session, error = await _validate_run_authorization(
run_id, 'knowledge_base', kb_id, self.ap
run_id, 'knowledge_base', kb_id, self.ap, caller_plugin_identity
)
if error:
return error