mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-08 06:46:02 +00:00
perf(agent-runner): improve session registry and orchestrator efficiency
- Add pre-computed _authorized_ids (frozenset) at session registration for O(1) lookup
- Refactor is_resource_allowed() from linear search to set membership check
- Add thread-safe locking to get_session_registry() singleton
- Cache _session_registry and _state_store references in orchestrator __init__
- Add asyncio.gather() for parallel resource building in AgentResourceBuilder
- Create shared test fixtures in tests/unit_tests/agent/conftest.py
- Update test files to import from shared conftest.py

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from typing import Any
|
||||
from typing import Any, Union
|
||||
import base64
|
||||
import traceback
|
||||
|
||||
@@ -24,6 +24,7 @@ from ..entity.persistence import bstorage as persistence_bstorage
|
||||
|
||||
from ..core import app
|
||||
from ..utils import constants
|
||||
from ..agent.runner.session_registry import get_session_registry
|
||||
|
||||
|
||||
def _make_rag_error_response(error: Exception, error_type: str, **extra_context) -> handler.ActionResponse:
|
||||
@@ -40,6 +41,48 @@ def _make_rag_error_response(error: Exception, error_type: str, **extra_context)
|
||||
return handler.ActionResponse.error(message=message)
|
||||
|
||||
|
||||
async def _validate_run_authorization(
    run_id: str,
    resource_type: str,
    resource_id: str,
    ap: app.Application,
) -> Union[tuple[None, handler.ActionResponse], tuple[Any, None]]:
    """Check whether an agent run is permitted to use a specific resource.

    Shared authorization step for the INVOKE_LLM, INVOKE_LLM_STREAM,
    CALL_TOOL, RETRIEVE_KNOWLEDGE_BASE, and RETRIEVE_KNOWLEDGE actions.
    The run's session is looked up in the session registry, then the
    registry is asked whether that session may access the resource.

    Args:
        run_id: Identifier of the agent run whose session is looked up.
        resource_type: Resource type ('model', 'tool', 'knowledge_base').
        resource_id: Resource identifier (model_uuid, tool_name, kb_id).
        ap: Application instance; only its logger is used here.

    Returns:
        ``(session, None)`` when the session exists and the resource is
        allowed for it.
        ``(None, error_response)`` when the session is missing (or falsy)
        or the resource is not allowed; a warning is logged in both cases.
    """
    registry = get_session_registry()
    run_session = await registry.get(run_id)

    if run_session:
        if registry.is_resource_allowed(run_session, resource_type, resource_id):
            # Happy path: session found and resource is within its grants.
            return run_session, None
        log_line = f'{resource_type.upper()}: {resource_id} not allowed for run_id {run_id}'
        client_message = f'{resource_type} {resource_id} is not authorized for this agent run'
    else:
        log_line = f'{resource_type.upper()}: run_id {run_id} not found in session registry'
        client_message = f'Run session {run_id} not found or expired'

    # Single failure exit: record the reason, then hand the caller an error response.
    ap.logger.warning(log_line)
    return None, handler.ActionResponse.error(message=client_message)
|
||||
|
||||
|
||||
class RuntimeConnectionHandler(handler.Handler):
|
||||
"""Runtime connection handler"""
|
||||
|
||||
@@ -324,11 +367,24 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
|
||||
@self.action(PluginToRuntimeAction.INVOKE_LLM)
|
||||
async def invoke_llm(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""Invoke llm"""
|
||||
"""Invoke llm
|
||||
|
||||
For AgentRunner calls: requires run_id and validates model_uuid against session.resources.models.
|
||||
For regular plugin calls: no run_id, unrestricted access (backward compatibility).
|
||||
"""
|
||||
llm_model_uuid = data['llm_model_uuid']
|
||||
messages = data['messages']
|
||||
funcs = data.get('funcs', [])
|
||||
extra_args = data.get('extra_args', {})
|
||||
run_id = data.get('run_id') # Optional: present for AgentRunner calls
|
||||
|
||||
# Permission validation for AgentRunner calls
|
||||
if run_id:
|
||||
session, error = await _validate_run_authorization(
|
||||
run_id, 'model', llm_model_uuid, self.ap
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
|
||||
llm_model = await self.ap.model_mgr.get_model_by_uuid(llm_model_uuid)
|
||||
if llm_model is None:
|
||||
@@ -362,11 +418,25 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
|
||||
@self.action(PluginToRuntimeAction.INVOKE_LLM_STREAM)
|
||||
async def invoke_llm_stream(data: dict[str, Any]):
|
||||
"""Invoke llm with streaming response"""
|
||||
"""Invoke llm with streaming response
|
||||
|
||||
For AgentRunner calls: requires run_id and validates model_uuid against session.resources.models.
|
||||
For regular plugin calls: no run_id, unrestricted access (backward compatibility).
|
||||
"""
|
||||
llm_model_uuid = data['llm_model_uuid']
|
||||
messages = data['messages']
|
||||
funcs = data.get('funcs', [])
|
||||
extra_args = data.get('extra_args', {})
|
||||
run_id = data.get('run_id') # Optional: present for AgentRunner calls
|
||||
|
||||
# Permission validation for AgentRunner calls
|
||||
if run_id:
|
||||
session, error = await _validate_run_authorization(
|
||||
run_id, 'model', llm_model_uuid, self.ap
|
||||
)
|
||||
if error:
|
||||
yield error
|
||||
return
|
||||
|
||||
llm_model = await self.ap.model_mgr.get_model_by_uuid(llm_model_uuid)
|
||||
if llm_model is None:
|
||||
@@ -393,12 +463,30 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
|
||||
@self.action(PluginToRuntimeAction.CALL_TOOL)
|
||||
async def call_tool(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""Call a tool"""
|
||||
"""Call a tool
|
||||
|
||||
For AgentRunner calls: requires run_id and validates tool_name against session.resources.tools.
|
||||
For regular plugin calls: no run_id, unrestricted access (backward compatibility).
|
||||
|
||||
Note: SDK LangBotAPIProxy (legacy) sends 'tool_parameters' and expects 'tool_response'.
|
||||
SDK AgentRunAPIProxy sends 'parameters' and expects 'result'.
|
||||
Handler returns both for backward compatibility.
|
||||
"""
|
||||
tool_name = data['tool_name']
|
||||
parameters = data['parameters']
|
||||
# Support 'tool_parameters' (LangBotAPIProxy) and 'parameters' (AgentRunAPIProxy)
|
||||
parameters = data.get('tool_parameters') or data.get('parameters', {})
|
||||
run_id = data.get('run_id') # Optional: present for AgentRunner calls
|
||||
# session_data = data['session']
|
||||
# query_id = data['query_id']
|
||||
|
||||
# Permission validation for AgentRunner calls
|
||||
if run_id:
|
||||
session, error = await _validate_run_authorization(
|
||||
run_id, 'tool', tool_name, self.ap
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
|
||||
# Convert session_data to Session object (simplified)
|
||||
# In real implementation, you would reconstruct the full session
|
||||
# For now, we'll call the tool manager's execute method
|
||||
@@ -408,9 +496,12 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
parameters=parameters,
|
||||
query=None, # TODO: reconstruct query from session_data if needed
|
||||
)
|
||||
# Return both 'tool_response' (LangBotAPIProxy) and 'result' (AgentRunAPIProxy)
|
||||
# LangBotAPIProxy expects 'tool_response', AgentRunAPIProxy expects 'result'
|
||||
return handler.ActionResponse.success(
|
||||
data={
|
||||
'result': result,
|
||||
'tool_response': result,
|
||||
'result': result, # backward compatibility
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -419,6 +510,14 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
message=f'Failed to execute tool {tool_name}: {e}',
|
||||
)
|
||||
|
||||
# ================= Binary Storage Handlers =================
|
||||
# NOTE: These are low-level actions called by SDK Runtime's storage wrapper handlers.
|
||||
# Permission validation is handled in SDK Runtime layer (not here):
|
||||
# - plugin_storage: SDK handler auto-sets owner to caller plugin identity (inherent isolation)
|
||||
# - workspace_storage: SDK handler should validate session.resources.storage.workspace_storage
|
||||
# TODO: SDK storage handlers need to pass run_id and validate workspace_storage permission.
|
||||
# Current risk: workspace storage access is unrestricted from AgentRunner context.
|
||||
|
||||
@self.action(RuntimeToLangBotAction.SET_BINARY_STORAGE)
|
||||
async def set_binary_storage(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""Set binary storage"""
|
||||
@@ -706,11 +805,26 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
|
||||
@self.action(PluginToRuntimeAction.RETRIEVE_KNOWLEDGE)
|
||||
async def retrieve_knowledge(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""Retrieve documents from any knowledge base (unrestricted)."""
|
||||
"""Retrieve documents from any knowledge base.
|
||||
|
||||
For AgentRunner calls: requires run_id and validates kb_id against session.resources.knowledge_bases.
|
||||
For regular plugin calls: no run_id, unrestricted access (backward compatibility).
|
||||
|
||||
Note: SDK AgentRunAPIProxy.retrieve_knowledge calls this action with run_id.
|
||||
"""
|
||||
kb_id = data['kb_id']
|
||||
query_text = data['query_text']
|
||||
top_k = data.get('top_k', 5)
|
||||
filters = data.get('filters', {})
|
||||
run_id = data.get('run_id') # Optional: present for AgentRunner calls
|
||||
|
||||
# Permission validation for AgentRunner calls
|
||||
if run_id:
|
||||
session, error = await _validate_run_authorization(
|
||||
run_id, 'knowledge_base', kb_id, self.ap
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
|
||||
kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_id)
|
||||
if not kb:
|
||||
@@ -769,12 +883,27 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
|
||||
@self.action(PluginToRuntimeAction.RETRIEVE_KNOWLEDGE_BASE)
|
||||
async def retrieve_knowledge_base(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""Retrieve documents from a knowledge base within the pipeline's scope."""
|
||||
"""Retrieve documents from a knowledge base within the pipeline's scope.
|
||||
|
||||
For AgentRunner calls: requires run_id and validates kb_id against session.resources.knowledge_bases.
|
||||
For regular plugin calls: no run_id, validates against pipeline's configured knowledge bases.
|
||||
|
||||
Note: This action has dual validation paths:
|
||||
- AgentRunner: uses session_registry for permission check
|
||||
- Regular plugin: uses ConfigMigration.resolve_runner_config for pipeline-level check
|
||||
|
||||
SECURITY TODO: This handler cannot verify the caller's plugin identity.
|
||||
The session contains 'plugin_identity' (author/name), but we don't have access
|
||||
to which plugin is making the API call. This could allow a malicious plugin to
|
||||
use another plugin's run_id if it can guess/obtain it. Future improvement:
|
||||
track caller plugin identity in RuntimeConnectionHandler or pass it in action data.
|
||||
"""
|
||||
query_id = data['query_id']
|
||||
kb_id = data['kb_id']
|
||||
query_text = data['query_text']
|
||||
top_k = data.get('top_k', 5)
|
||||
filters = data.get('filters', {})
|
||||
run_id = data.get('run_id') # Optional: present for AgentRunner calls
|
||||
|
||||
if query_id not in self.ap.query_pool.cached_queries:
|
||||
return handler.ActionResponse.error(
|
||||
@@ -783,21 +912,32 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
|
||||
query = self.ap.query_pool.cached_queries[query_id]
|
||||
|
||||
# Validate kb_id is in pipeline's allowed list
|
||||
allowed_kb_uuids = []
|
||||
if query.pipeline_config:
|
||||
from langbot.pkg.agent.runner.config_migration import ConfigMigration
|
||||
runner_config = ConfigMigration.resolve_runner_config(query.pipeline_config, None)
|
||||
allowed_kb_uuids = runner_config.get('knowledge-bases', [])
|
||||
if not allowed_kb_uuids:
|
||||
old_kb_uuid = runner_config.get('knowledge-base', '')
|
||||
if old_kb_uuid and old_kb_uuid != '__none__':
|
||||
allowed_kb_uuids = [old_kb_uuid]
|
||||
|
||||
if kb_id not in allowed_kb_uuids:
|
||||
return handler.ActionResponse.error(
|
||||
message=f'Knowledge base {kb_id} is not configured for this pipeline',
|
||||
# Permission validation for AgentRunner calls
|
||||
if run_id:
|
||||
session, error = await _validate_run_authorization(
|
||||
run_id, 'knowledge_base', kb_id, self.ap
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
else:
|
||||
# Regular plugin call: validate against pipeline's configured knowledge bases
|
||||
# FIX: First resolve runner_id, then resolve runner_config
|
||||
allowed_kb_uuids = []
|
||||
if query.pipeline_config:
|
||||
from langbot.pkg.agent.runner.config_migration import ConfigMigration
|
||||
runner_id = ConfigMigration.resolve_runner_id(query.pipeline_config)
|
||||
if runner_id:
|
||||
runner_config = ConfigMigration.resolve_runner_config(query.pipeline_config, runner_id)
|
||||
allowed_kb_uuids = runner_config.get('knowledge-bases', [])
|
||||
if not allowed_kb_uuids:
|
||||
old_kb_uuid = runner_config.get('knowledge-base', '')
|
||||
if old_kb_uuid and old_kb_uuid != '__none__':
|
||||
allowed_kb_uuids = [old_kb_uuid]
|
||||
|
||||
if kb_id not in allowed_kb_uuids:
|
||||
return handler.ActionResponse.error(
|
||||
message=f'Knowledge base {kb_id} is not configured for this pipeline',
|
||||
)
|
||||
|
||||
kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_id)
|
||||
if not kb:
|
||||
|
||||
Reference in New Issue
Block a user