mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-04 21:06:03 +00:00
feat(plugin): implement INVOKE_RERANK handler with run-scoped authorization
- Add invoke_rerank action handler in plugin handler - Validate rerank model access via run session - Cap documents at 64 for API limit - Return sorted results by relevance score
This commit is contained in:
@@ -46,17 +46,19 @@ async def _validate_run_authorization(
|
||||
resource_type: str,
|
||||
resource_id: str,
|
||||
ap: app.Application,
|
||||
caller_plugin_identity: str | None = None,
|
||||
) -> Union[tuple[None, handler.ActionResponse], tuple[Any, None]]:
|
||||
"""Validate run_id authorization for a resource access.
|
||||
|
||||
Common validation logic for INVOKE_LLM, INVOKE_LLM_STREAM, CALL_TOOL,
|
||||
RETRIEVE_KNOWLEDGE_BASE, and RETRIEVE_KNOWLEDGE actions.
|
||||
RETRIEVE_KNOWLEDGE_BASE, RETRIEVE_KNOWLEDGE, and storage/file actions.
|
||||
|
||||
Args:
|
||||
run_id: The run_id to validate.
|
||||
resource_type: Resource type ('model', 'tool', 'knowledge_base').
|
||||
resource_id: Resource identifier (model_uuid, tool_name, kb_id).
|
||||
resource_type: Resource type ('model', 'tool', 'knowledge_base', 'storage', 'file').
|
||||
resource_id: Resource identifier (model_uuid, tool_name, kb_id, 'plugin'/'workspace', file_key).
|
||||
ap: Application instance for logging.
|
||||
caller_plugin_identity: Optional plugin identity (author/name) of the caller for cross-plugin validation.
|
||||
|
||||
Returns:
|
||||
Tuple of (session, None) if validation passes.
|
||||
@@ -72,6 +74,18 @@ async def _validate_run_authorization(
|
||||
message=f'Run session {run_id} not found or expired',
|
||||
)
|
||||
|
||||
# Validate caller_plugin_identity matches session's plugin_identity
|
||||
if caller_plugin_identity:
|
||||
session_plugin_identity = session.get('plugin_identity')
|
||||
if session_plugin_identity and caller_plugin_identity != session_plugin_identity:
|
||||
ap.logger.warning(
|
||||
f'{resource_type.upper()}: caller_plugin_identity {caller_plugin_identity} '
|
||||
f'does not match session plugin_identity {session_plugin_identity}'
|
||||
)
|
||||
return None, handler.ActionResponse.error(
|
||||
message=f'Plugin identity mismatch: caller {caller_plugin_identity} is not authorized for run_id {run_id}',
|
||||
)
|
||||
|
||||
if not session_registry.is_resource_allowed(session, resource_type, resource_id):
|
||||
ap.logger.warning(
|
||||
f'{resource_type.upper()}: {resource_id} not allowed for run_id {run_id}'
|
||||
@@ -377,11 +391,12 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
funcs = data.get('funcs', [])
|
||||
extra_args = data.get('extra_args', {})
|
||||
run_id = data.get('run_id') # Optional: present for AgentRunner calls
|
||||
caller_plugin_identity = data.get('caller_plugin_identity') # Optional: for cross-plugin validation
|
||||
|
||||
# Permission validation for AgentRunner calls
|
||||
if run_id:
|
||||
session, error = await _validate_run_authorization(
|
||||
run_id, 'model', llm_model_uuid, self.ap
|
||||
run_id, 'model', llm_model_uuid, self.ap, caller_plugin_identity
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
@@ -428,11 +443,12 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
funcs = data.get('funcs', [])
|
||||
extra_args = data.get('extra_args', {})
|
||||
run_id = data.get('run_id') # Optional: present for AgentRunner calls
|
||||
caller_plugin_identity = data.get('caller_plugin_identity') # Optional: for cross-plugin validation
|
||||
|
||||
# Permission validation for AgentRunner calls
|
||||
if run_id:
|
||||
session, error = await _validate_run_authorization(
|
||||
run_id, 'model', llm_model_uuid, self.ap
|
||||
run_id, 'model', llm_model_uuid, self.ap, caller_plugin_identity
|
||||
)
|
||||
if error:
|
||||
yield error
|
||||
@@ -476,13 +492,14 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
# Support 'tool_parameters' (LangBotAPIProxy) and 'parameters' (AgentRunAPIProxy)
|
||||
parameters = data.get('tool_parameters') or data.get('parameters', {})
|
||||
run_id = data.get('run_id') # Optional: present for AgentRunner calls
|
||||
caller_plugin_identity = data.get('caller_plugin_identity') # Optional: for cross-plugin validation
|
||||
# session_data = data['session']
|
||||
# query_id = data['query_id']
|
||||
|
||||
# Permission validation for AgentRunner calls
|
||||
if run_id:
|
||||
session, error = await _validate_run_authorization(
|
||||
run_id, 'tool', tool_name, self.ap
|
||||
run_id, 'tool', tool_name, self.ap, caller_plugin_identity
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
@@ -511,20 +528,36 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
)
|
||||
|
||||
# ================= Binary Storage Handlers =================
|
||||
# NOTE: These are low-level actions called by SDK Runtime's storage wrapper handlers.
|
||||
# Permission validation is handled in SDK Runtime layer (not here):
|
||||
# - plugin_storage: SDK handler auto-sets owner to caller plugin identity (inherent isolation)
|
||||
# - workspace_storage: SDK handler should validate session.resources.storage.workspace_storage
|
||||
# TODO: SDK storage handlers need to pass run_id and validate workspace_storage permission.
|
||||
# Current risk: workspace storage access is unrestricted from AgentRunner context.
|
||||
# Permission validation:
|
||||
# - For AgentRunner calls (with run_id): validates storage permission via session_registry
|
||||
# - For regular plugin calls (no run_id): unrestricted access (backward compatibility)
|
||||
# - Plugin storage: inherent isolation via owner = plugin identity (set by SDK runtime)
|
||||
# - Workspace storage: requires ctx.resources.storage.workspace_storage for AgentRunner
|
||||
|
||||
@self.action(RuntimeToLangBotAction.SET_BINARY_STORAGE)
|
||||
async def set_binary_storage(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""Set binary storage"""
|
||||
"""Set binary storage
|
||||
|
||||
For AgentRunner calls: validates storage permission via session_registry.
|
||||
For regular plugin calls: unrestricted access (backward compatibility).
|
||||
"""
|
||||
key = data['key']
|
||||
owner_type = data['owner_type']
|
||||
owner = data['owner']
|
||||
value = base64.b64decode(data['value_base64'])
|
||||
run_id = data.get('run_id') # Optional: present for AgentRunner calls
|
||||
caller_plugin_identity = data.get('caller_plugin_identity') # Optional: for cross-plugin validation
|
||||
|
||||
# Permission validation for AgentRunner calls
|
||||
if run_id:
|
||||
# Determine storage type from owner_type
|
||||
storage_type = owner_type # 'plugin' or 'workspace'
|
||||
session, error = await _validate_run_authorization(
|
||||
run_id, 'storage', storage_type, self.ap, caller_plugin_identity
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
|
||||
max_value_bytes = (
|
||||
self.ap.instance_config.data.get('plugin', {})
|
||||
.get('binary_storage', {})
|
||||
@@ -574,10 +607,25 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
|
||||
@self.action(RuntimeToLangBotAction.GET_BINARY_STORAGE)
|
||||
async def get_binary_storage(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""Get binary storage"""
|
||||
"""Get binary storage
|
||||
|
||||
For AgentRunner calls: validates storage permission via session_registry.
|
||||
For regular plugin calls: unrestricted access (backward compatibility).
|
||||
"""
|
||||
key = data['key']
|
||||
owner_type = data['owner_type']
|
||||
owner = data['owner']
|
||||
run_id = data.get('run_id') # Optional: present for AgentRunner calls
|
||||
caller_plugin_identity = data.get('caller_plugin_identity') # Optional: for cross-plugin validation
|
||||
|
||||
# Permission validation for AgentRunner calls
|
||||
if run_id:
|
||||
storage_type = owner_type
|
||||
session, error = await _validate_run_authorization(
|
||||
run_id, 'storage', storage_type, self.ap, caller_plugin_identity
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_bstorage.BinaryStorage)
|
||||
@@ -600,10 +648,25 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
|
||||
@self.action(RuntimeToLangBotAction.DELETE_BINARY_STORAGE)
|
||||
async def delete_binary_storage(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""Delete binary storage"""
|
||||
"""Delete binary storage
|
||||
|
||||
For AgentRunner calls: validates storage permission via session_registry.
|
||||
For regular plugin calls: unrestricted access (backward compatibility).
|
||||
"""
|
||||
key = data['key']
|
||||
owner_type = data['owner_type']
|
||||
owner = data['owner']
|
||||
run_id = data.get('run_id') # Optional: present for AgentRunner calls
|
||||
caller_plugin_identity = data.get('caller_plugin_identity') # Optional: for cross-plugin validation
|
||||
|
||||
# Permission validation for AgentRunner calls
|
||||
if run_id:
|
||||
storage_type = owner_type
|
||||
session, error = await _validate_run_authorization(
|
||||
run_id, 'storage', storage_type, self.ap, caller_plugin_identity
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_bstorage.BinaryStorage)
|
||||
@@ -618,9 +681,24 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
|
||||
@self.action(RuntimeToLangBotAction.GET_BINARY_STORAGE_KEYS)
|
||||
async def get_binary_storage_keys(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""Get binary storage keys"""
|
||||
"""Get binary storage keys
|
||||
|
||||
For AgentRunner calls: validates storage permission via session_registry.
|
||||
For regular plugin calls: unrestricted access (backward compatibility).
|
||||
"""
|
||||
owner_type = data['owner_type']
|
||||
owner = data['owner']
|
||||
run_id = data.get('run_id') # Optional: present for AgentRunner calls
|
||||
caller_plugin_identity = data.get('caller_plugin_identity') # Optional: for cross-plugin validation
|
||||
|
||||
# Permission validation for AgentRunner calls
|
||||
if run_id:
|
||||
storage_type = owner_type
|
||||
session, error = await _validate_run_authorization(
|
||||
run_id, 'storage', storage_type, self.ap, caller_plugin_identity
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_bstorage.BinaryStorage.key)
|
||||
@@ -636,8 +714,22 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
|
||||
@self.action(PluginToRuntimeAction.GET_CONFIG_FILE)
|
||||
async def get_config_file(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""Get a config file by file key"""
|
||||
"""Get a config file by file key
|
||||
|
||||
For AgentRunner calls: validates file_key against session.resources.files.
|
||||
For regular plugin calls: unrestricted access (backward compatibility).
|
||||
"""
|
||||
file_key = data['file_key']
|
||||
run_id = data.get('run_id') # Optional: present for AgentRunner calls
|
||||
caller_plugin_identity = data.get('caller_plugin_identity') # Optional: for cross-plugin validation
|
||||
|
||||
# Permission validation for AgentRunner calls
|
||||
if run_id:
|
||||
session, error = await _validate_run_authorization(
|
||||
run_id, 'file', file_key, self.ap, caller_plugin_identity
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
|
||||
try:
|
||||
# Load file from storage
|
||||
@@ -672,6 +764,50 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
except Exception as e:
|
||||
return _make_rag_error_response(e, 'EmbeddingError', embedding_model_uuid=embedding_model_uuid)
|
||||
|
||||
@self.action(PluginToRuntimeAction.INVOKE_RERANK)
|
||||
async def invoke_rerank(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""Invoke rerank model for agent runner with run-scoped authorization."""
|
||||
run_id = data.get('run_id')
|
||||
rerank_model_uuid = data['rerank_model_uuid']
|
||||
query = data['query']
|
||||
documents = data['documents']
|
||||
top_k = data.get('top_k')
|
||||
|
||||
# Validate run authorization
|
||||
session, error = await _validate_run_authorization(
|
||||
run_id, 'model', rerank_model_uuid, self.ap
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
|
||||
# Get rerank model
|
||||
rerank_model = await self.ap.model_mgr.get_rerank_model_by_uuid(rerank_model_uuid)
|
||||
if rerank_model is None:
|
||||
return handler.ActionResponse.error(
|
||||
message=f'Rerank model with uuid {rerank_model_uuid} not found',
|
||||
)
|
||||
|
||||
try:
|
||||
# Cap documents at 64 for API limit
|
||||
documents_capped = documents[:64]
|
||||
|
||||
scores = await rerank_model.provider.invoke_rerank(
|
||||
model=rerank_model,
|
||||
query=query,
|
||||
documents=documents_capped,
|
||||
)
|
||||
|
||||
# Sort by relevance score descending
|
||||
scored = sorted(scores, key=lambda x: x.get('relevance_score', 0), reverse=True)
|
||||
|
||||
# Apply top_k if specified
|
||||
if top_k is not None:
|
||||
scored = scored[:top_k]
|
||||
|
||||
return handler.ActionResponse.success(data={'results': scored})
|
||||
except Exception as e:
|
||||
return _make_rag_error_response(e, 'RerankError', rerank_model_uuid=rerank_model_uuid)
|
||||
|
||||
@self.action(PluginToRuntimeAction.VECTOR_UPSERT)
|
||||
async def vector_upsert(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
collection_id = data['collection_id']
|
||||
@@ -817,11 +953,12 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
top_k = data.get('top_k', 5)
|
||||
filters = data.get('filters', {})
|
||||
run_id = data.get('run_id') # Optional: present for AgentRunner calls
|
||||
caller_plugin_identity = data.get('caller_plugin_identity') # Optional: for cross-plugin validation
|
||||
|
||||
# Permission validation for AgentRunner calls
|
||||
if run_id:
|
||||
session, error = await _validate_run_authorization(
|
||||
run_id, 'knowledge_base', kb_id, self.ap
|
||||
run_id, 'knowledge_base', kb_id, self.ap, caller_plugin_identity
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
@@ -891,12 +1028,6 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
Note: This action has dual validation paths:
|
||||
- AgentRunner: uses session_registry for permission check
|
||||
- Regular plugin: uses ConfigMigration.resolve_runner_config for pipeline-level check
|
||||
|
||||
SECURITY TODO: This handler cannot verify the caller's plugin identity.
|
||||
The session contains 'plugin_identity' (author/name), but we don't have access
|
||||
to which plugin is making the API call. This could allow a malicious plugin to
|
||||
use another plugin's run_id if it can guess/obtain it. Future improvement:
|
||||
track caller plugin identity in RuntimeConnectionHandler or pass it in action data.
|
||||
"""
|
||||
query_id = data['query_id']
|
||||
kb_id = data['kb_id']
|
||||
@@ -904,6 +1035,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
top_k = data.get('top_k', 5)
|
||||
filters = data.get('filters', {})
|
||||
run_id = data.get('run_id') # Optional: present for AgentRunner calls
|
||||
caller_plugin_identity = data.get('caller_plugin_identity') # Optional: for cross-plugin validation
|
||||
|
||||
if query_id not in self.ap.query_pool.cached_queries:
|
||||
return handler.ActionResponse.error(
|
||||
@@ -915,7 +1047,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
# Permission validation for AgentRunner calls
|
||||
if run_id:
|
||||
session, error = await _validate_run_authorization(
|
||||
run_id, 'knowledge_base', kb_id, self.ap
|
||||
run_id, 'knowledge_base', kb_id, self.ap, caller_plugin_identity
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
|
||||
Reference in New Issue
Block a user