mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-07 14:26:03 +00:00
feat: make agent runner config schema driven
This commit is contained in:
@@ -25,6 +25,8 @@ from ..entity.persistence import bstorage as persistence_bstorage
|
||||
from ..core import app
|
||||
from ..utils import constants
|
||||
from ..agent.runner.session_registry import get_session_registry
|
||||
from ..agent.runner.config_migration import ConfigMigration
|
||||
from ..agent.runner import config_schema
|
||||
|
||||
|
||||
def _make_rag_error_response(error: Exception, error_type: str, **extra_context) -> handler.ActionResponse:
|
||||
@@ -98,6 +100,46 @@ def _build_tool_detail(tool: Any, requested_tool_name: str | None = None) -> dic
|
||||
}
|
||||
|
||||
|
||||
def _normalize_uuid_list(values: Any) -> list[str]:
|
||||
"""Normalize a user/config supplied UUID list while preserving order."""
|
||||
if not isinstance(values, list):
|
||||
return []
|
||||
return list(
|
||||
dict.fromkeys(
|
||||
value for value in values if isinstance(value, str) and value not in config_schema.NONE_SENTINELS
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def _get_pipeline_knowledge_base_uuids(ap: app.Application, query: Any) -> list[str]:
|
||||
"""Resolve pipeline-scoped KBs from preprocessed variables or runner schema."""
|
||||
variables = getattr(query, 'variables', {}) or {}
|
||||
if '_knowledge_base_uuids' in variables:
|
||||
return _normalize_uuid_list(variables.get('_knowledge_base_uuids'))
|
||||
|
||||
pipeline_config = getattr(query, 'pipeline_config', None)
|
||||
if not pipeline_config:
|
||||
return []
|
||||
|
||||
runner_id = ConfigMigration.resolve_runner_id(pipeline_config)
|
||||
if not runner_id:
|
||||
return []
|
||||
|
||||
runner_config = ConfigMigration.resolve_runner_config(pipeline_config, runner_id)
|
||||
registry = getattr(ap, 'agent_runner_registry', None)
|
||||
if registry is None:
|
||||
return []
|
||||
|
||||
bound_plugins = variables.get('_pipeline_bound_plugins')
|
||||
try:
|
||||
descriptor = await registry.get(runner_id, bound_plugins)
|
||||
except Exception as e:
|
||||
ap.logger.warning(f'Failed to load AgentRunner descriptor for pipeline knowledge-base scope: {e}')
|
||||
return []
|
||||
|
||||
return config_schema.extract_knowledge_base_uuids(descriptor, runner_config)
|
||||
|
||||
|
||||
async def _validate_run_authorization(
|
||||
run_id: str,
|
||||
resource_type: str,
|
||||
@@ -1155,15 +1197,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
|
||||
query = self.ap.query_pool.cached_queries[query_id]
|
||||
|
||||
kb_uuids = []
|
||||
if query.pipeline_config:
|
||||
local_agent_config = query.pipeline_config.get('ai', {}).get('local-agent', {})
|
||||
kb_uuids = local_agent_config.get('knowledge-bases', [])
|
||||
# Backward compatibility
|
||||
if not kb_uuids:
|
||||
old_kb_uuid = local_agent_config.get('knowledge-base', '')
|
||||
if old_kb_uuid and old_kb_uuid != '__none__':
|
||||
kb_uuids = [old_kb_uuid]
|
||||
kb_uuids = await _get_pipeline_knowledge_base_uuids(self.ap, query)
|
||||
|
||||
knowledge_bases = []
|
||||
for kb_uuid in kb_uuids:
|
||||
@@ -1213,19 +1247,9 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
if error:
|
||||
return error
|
||||
else:
|
||||
# Regular plugin call: validate against pipeline's configured knowledge bases
|
||||
# FIX: First resolve runner_id, then resolve runner_config
|
||||
allowed_kb_uuids = []
|
||||
if query.pipeline_config:
|
||||
from langbot.pkg.agent.runner.config_migration import ConfigMigration
|
||||
runner_id = ConfigMigration.resolve_runner_id(query.pipeline_config)
|
||||
if runner_id:
|
||||
runner_config = ConfigMigration.resolve_runner_config(query.pipeline_config, runner_id)
|
||||
allowed_kb_uuids = runner_config.get('knowledge-bases', [])
|
||||
if not allowed_kb_uuids:
|
||||
old_kb_uuid = runner_config.get('knowledge-base', '')
|
||||
if old_kb_uuid and old_kb_uuid != '__none__':
|
||||
allowed_kb_uuids = [old_kb_uuid]
|
||||
# Regular plugin call: validate against the runner binding's
|
||||
# schema-defined KB selectors or the preprocessed query scope.
|
||||
allowed_kb_uuids = await _get_pipeline_knowledge_base_uuids(self.ap, query)
|
||||
|
||||
if kb_id not in allowed_kb_uuids:
|
||||
return handler.ActionResponse.error(
|
||||
@@ -1434,6 +1458,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
|
||||
Yields AgentRunResult dicts.
|
||||
"""
|
||||
timeout = self._get_runner_action_timeout(context)
|
||||
gen = self.call_action_generator(
|
||||
LangBotToRuntimeAction.RUN_AGENT,
|
||||
{
|
||||
@@ -1442,12 +1467,27 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
'runner_name': runner_name,
|
||||
'context': context,
|
||||
},
|
||||
timeout=300,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
async for ret in gen:
|
||||
yield ret
|
||||
|
||||
def _get_runner_action_timeout(self, context: dict[str, Any]) -> float:
|
||||
"""Use the run deadline as the transport idle timeout when available."""
|
||||
try:
|
||||
import time
|
||||
|
||||
deadline_at = (context.get('runtime') or {}).get('deadline_at')
|
||||
if deadline_at is None:
|
||||
return 300
|
||||
remaining = float(deadline_at) - time.time()
|
||||
if remaining <= 0:
|
||||
return 0.001
|
||||
return max(remaining + 1.0, 0.001)
|
||||
except (TypeError, ValueError):
|
||||
return 300
|
||||
|
||||
async def get_plugin_icon(self, plugin_author: str, plugin_name: str) -> dict[str, Any]:
|
||||
"""Get plugin icon"""
|
||||
result = await self.call_action(
|
||||
|
||||
Reference in New Issue
Block a user