feat: make agent runner config schema driven

This commit is contained in:
huanghuoguoguo
2026-05-19 12:20:28 +08:00
parent 146694539e
commit 26923c66c0
20 changed files with 905 additions and 239 deletions

View File

@@ -24,7 +24,8 @@ class ConfigMigration:
Responsibilities: Responsibilities:
- Resolve runner ID from new ai.runner.id or old ai.runner.runner - Resolve runner ID from new ai.runner.id or old ai.runner.runner
- Map old built-in runner names to official plugin runner IDs - Map old built-in runner names to official plugin runner IDs
- Extract runner config from ai.runner_config or old ai.<runner-name> - Extract runtime runner config from ai.runner_config
- Migrate old ai.<runner-name> blocks into ai.runner_config
""" """
@staticmethod @staticmethod
@@ -74,9 +75,9 @@ class ConfigMigration:
) -> dict[str, typing.Any]: ) -> dict[str, typing.Any]:
"""Resolve runner binding configuration from pipeline configuration. """Resolve runner binding configuration from pipeline configuration.
Priority: Runtime code should only read the migrated format. Legacy
1. New format: ai.runner_config[runner_id] ai.<runner-name> blocks are handled by migration helpers, not by the
2. Old format: ai.<runner-name> (mapped from runner_id if applicable) hot path.
Args: Args:
pipeline_config: Pipeline configuration dict pipeline_config: Pipeline configuration dict
@@ -92,7 +93,16 @@ class ConfigMigration:
if runner_id in runner_configs: if runner_id in runner_configs:
return runner_configs[runner_id] return runner_configs[runner_id]
# Check old format: ai.<old_runner_name> return {}
@staticmethod
def resolve_legacy_runner_config(
pipeline_config: dict[str, typing.Any],
runner_id: str,
) -> dict[str, typing.Any]:
"""Resolve old ai.<runner-name> config for migration only."""
ai_config = pipeline_config.get('ai', {})
# Try to find old runner name from runner_id # Try to find old runner name from runner_id
old_runner_name = None old_runner_name = None
for old_name, mapped_id in OLD_RUNNER_TO_PLUGIN_RUNNER_ID.items(): for old_name, mapped_id in OLD_RUNNER_TO_PLUGIN_RUNNER_ID.items():
@@ -105,12 +115,6 @@ class ConfigMigration:
if old_config: if old_config:
return old_config return old_config
# If runner_id is plugin:* format, try extracting runner_name as config key
if is_plugin_runner_id(runner_id):
# Some configs might use just the runner_name component as key
# But this is legacy behavior - prefer ai.runner_config[id]
pass
return {} return {}
@staticmethod @staticmethod
@@ -181,6 +185,8 @@ class ConfigMigration:
# Migrate runner config # Migrate runner config
resolved_config = ConfigMigration.resolve_runner_config(pipeline_config, runner_id) resolved_config = ConfigMigration.resolve_runner_config(pipeline_config, runner_id)
if not resolved_config:
resolved_config = ConfigMigration.resolve_legacy_runner_config(pipeline_config, runner_id)
if resolved_config: if resolved_config:
runner_configs[runner_id] = resolved_config runner_configs[runner_id] = resolved_config
# Remove old runner config block # Remove old runner config block
@@ -193,4 +199,4 @@ class ConfigMigration:
ai_config['runner_config'] = runner_configs ai_config['runner_config'] = runner_configs
new_config['ai'] = ai_config new_config['ai'] = ai_config
return new_config return new_config

View File

@@ -0,0 +1,208 @@
"""Helpers for interpreting AgentRunner DynamicForm configuration."""
from __future__ import annotations
import typing
from .descriptor import AgentRunnerDescriptor
LLM_MODEL_SELECTOR_TYPES = {'model-fallback-selector', 'llm-model-selector'}
KB_SELECTOR_TYPES = {'knowledge-base-multi-selector'}
PROMPT_EDITOR_TYPES = {'prompt-editor'}
NONE_SENTINELS = {'', '__none__', '__none'}
def iter_schema_items(
descriptor: AgentRunnerDescriptor | None,
field_types: set[str],
) -> typing.Iterator[dict[str, typing.Any]]:
"""Yield descriptor config schema items whose type is in field_types."""
if descriptor is None:
return
for item in descriptor.config_schema or []:
if not isinstance(item, dict):
continue
if item.get('type') in field_types:
yield item
def has_permission(
descriptor: AgentRunnerDescriptor | None,
name: str,
actions: set[str],
) -> bool:
"""Return whether a runner descriptor requests one of the given actions."""
if descriptor is None:
return False
configured_actions = descriptor.permissions.get(name, [])
return any(action in configured_actions for action in actions)
def uses_host_models(descriptor: AgentRunnerDescriptor | None) -> bool:
"""Return whether LangBot should resolve model resources for this runner."""
return (
has_permission(descriptor, 'models', {'invoke', 'stream', 'list'})
and any(True for _ in iter_schema_items(descriptor, LLM_MODEL_SELECTOR_TYPES))
)
def uses_host_tools(descriptor: AgentRunnerDescriptor | None) -> bool:
"""Return whether LangBot should expose tool resources to this runner."""
return (
descriptor is not None
and descriptor.supports_tool_calling()
and has_permission(descriptor, 'tools', {'list', 'detail', 'call'})
)
def uses_host_knowledge_bases(descriptor: AgentRunnerDescriptor | None) -> bool:
"""Return whether LangBot should expose knowledge-base resources to this runner."""
return (
descriptor is not None
and descriptor.supports_knowledge_retrieval()
and has_permission(descriptor, 'knowledge_bases', {'list', 'retrieve'})
)
def extract_prompt_config(
descriptor: AgentRunnerDescriptor | None,
runner_config: dict[str, typing.Any],
default_prompt: list[dict[str, typing.Any]],
) -> list[dict[str, typing.Any]]:
"""Extract the prompt-editor value selected by the runner schema."""
for item in iter_schema_items(descriptor, PROMPT_EDITOR_TYPES):
field_name = item.get('name')
if field_name and field_name in runner_config:
configured_prompt = runner_config[field_name]
if isinstance(configured_prompt, list):
return configured_prompt
default_value = item.get('default')
if isinstance(default_value, list):
return default_value
return default_prompt
def extract_model_selection(
descriptor: AgentRunnerDescriptor | None,
runner_config: dict[str, typing.Any],
) -> tuple[str, list[str]]:
"""Extract primary/fallback LLM selections from schema-defined fields."""
primary_uuid = ''
fallback_uuids: list[str] = []
for item in iter_schema_items(descriptor, LLM_MODEL_SELECTOR_TYPES):
field_name = item.get('name')
if not field_name:
continue
value = runner_config.get(field_name, item.get('default'))
if item.get('type') == 'model-fallback-selector':
if isinstance(value, str):
primary_uuid = value
elif isinstance(value, dict):
primary_uuid = value.get('primary') or ''
fallbacks = value.get('fallbacks', [])
if isinstance(fallbacks, list):
fallback_uuids = [fallback for fallback in fallbacks if isinstance(fallback, str)]
break
if item.get('type') == 'llm-model-selector' and isinstance(value, str):
primary_uuid = value
break
return primary_uuid, fallback_uuids
def extract_knowledge_base_uuids(
descriptor: AgentRunnerDescriptor | None,
runner_config: dict[str, typing.Any],
) -> list[str]:
"""Extract configured knowledge-base UUIDs from schema-defined fields."""
if not uses_host_knowledge_bases(descriptor):
return []
kb_uuids: list[str] = []
for item in iter_schema_items(descriptor, KB_SELECTOR_TYPES):
field_name = item.get('name')
if not field_name:
continue
value = runner_config.get(field_name, item.get('default', []))
if isinstance(value, list):
kb_uuids.extend(
kb_uuid for kb_uuid in value if isinstance(kb_uuid, str) and kb_uuid not in NONE_SENTINELS
)
return list(dict.fromkeys(kb_uuids))
def iter_config_model_refs(
descriptor: AgentRunnerDescriptor,
runner_config: dict[str, typing.Any],
) -> typing.Iterator[tuple[str, str]]:
"""Yield model references declared by schema-defined model selector fields."""
for item in descriptor.config_schema or []:
if not isinstance(item, dict):
continue
field_name = item.get('name')
field_type = item.get('type')
if not field_name or field_name not in runner_config:
continue
value = runner_config.get(field_name)
if field_type == 'model-fallback-selector':
if isinstance(value, str) and value not in NONE_SENTINELS:
yield 'llm', value
elif isinstance(value, dict):
primary = value.get('primary')
if isinstance(primary, str) and primary not in NONE_SENTINELS:
yield 'llm', primary
fallbacks = value.get('fallbacks', [])
if isinstance(fallbacks, list):
for fallback_uuid in fallbacks:
if isinstance(fallback_uuid, str) and fallback_uuid not in NONE_SENTINELS:
yield 'llm', fallback_uuid
elif field_type == 'llm-model-selector':
if isinstance(value, str) and value not in NONE_SENTINELS:
yield 'llm', value
elif field_type == 'rerank-model-selector':
if isinstance(value, str) and value not in NONE_SENTINELS:
yield 'rerank', value
def set_empty_llm_model_selection(
descriptor: AgentRunnerDescriptor,
runner_config: dict[str, typing.Any],
model_uuid: str,
) -> bool:
"""Set the first empty schema-defined LLM selector to model_uuid."""
for item in iter_schema_items(descriptor, LLM_MODEL_SELECTOR_TYPES):
field_name = item.get('name')
field_type = item.get('type')
if not field_name:
continue
value = runner_config.get(field_name, item.get('default'))
if field_type == 'model-fallback-selector':
if isinstance(value, dict):
primary = value.get('primary') or ''
if primary not in NONE_SENTINELS:
return False
fallbacks = value.get('fallbacks', [])
runner_config[field_name] = {
'primary': model_uuid,
'fallbacks': fallbacks if isinstance(fallbacks, list) else [],
}
return True
if isinstance(value, str) and value not in NONE_SENTINELS:
return False
runner_config[field_name] = {'primary': model_uuid, 'fallbacks': []}
return True
if field_type == 'llm-model-selector':
if isinstance(value, str) and value not in NONE_SENTINELS:
return False
runner_config[field_name] = model_uuid
return True
return False

View File

@@ -15,6 +15,9 @@ from .state_store import get_state_store
from . import events as runner_events from . import events as runner_events
DEFAULT_RUNNER_TIMEOUT_SECONDS = 300
# Internal models for the agent runner context protocol. # Internal models for the agent runner context protocol.
@@ -106,7 +109,7 @@ class AgentRuntimeContext(typing.TypedDict):
sdk_protocol_version: str sdk_protocol_version: str
query_id: int | None query_id: int | None
trace_id: str | None trace_id: str | None
deadline_at: int | None deadline_at: float | None
metadata: dict[str, typing.Any] metadata: dict[str, typing.Any]
@@ -480,9 +483,13 @@ class AgentRunContextBuilder:
}, },
} }
def _build_deadline(self, runner_config: dict[str, typing.Any]) -> int | None: def _build_deadline(self, runner_config: dict[str, typing.Any]) -> float | None:
"""Build deadline timestamp from runner timeout config if present.""" """Build deadline timestamp from runner timeout config.
timeout = runner_config.get('timeout')
A missing timeout uses the host default. Explicit null, zero, or negative
values disable the total run deadline for advanced deployments.
"""
timeout = runner_config.get('timeout', DEFAULT_RUNNER_TIMEOUT_SECONDS)
if timeout is None: if timeout is None:
return None return None
@@ -494,7 +501,7 @@ class AgentRunContextBuilder:
if timeout_seconds <= 0: if timeout_seconds <= 0:
return None return None
return int(time.time() + timeout_seconds) return time.time() + timeout_seconds
async def _is_stream_output_supported(self, query: pipeline_query.Query) -> bool: async def _is_stream_output_supported(self, query: pipeline_query.Query) -> bool:
"""Check whether the current adapter can consume streaming chunks.""" """Check whether the current adapter can consume streaming chunks."""

View File

@@ -3,9 +3,12 @@ from __future__ import annotations
import typing import typing
import traceback import traceback
import asyncio
import time
from langbot_plugin.api.entities.builtin.provider import message as provider_message from langbot_plugin.api.entities.builtin.provider import message as provider_message
from langbot_plugin.api.entities.builtin.pipeline import query as pipeline_query from langbot_plugin.api.entities.builtin.pipeline import query as pipeline_query
from langbot_plugin.entities.io.errors import ActionCallTimeoutError
from ...core import app from ...core import app
from .descriptor import AgentRunnerDescriptor from .descriptor import AgentRunnerDescriptor
@@ -155,14 +158,32 @@ class AgentRunOrchestrator:
) )
try: try:
async for result_dict in self.ap.plugin_connector.run_agent( gen = self.ap.plugin_connector.run_agent(
plugin_author=descriptor.plugin_author, plugin_author=descriptor.plugin_author,
plugin_name=descriptor.plugin_name, plugin_name=descriptor.plugin_name,
runner_name=descriptor.runner_name, runner_name=descriptor.runner_name,
context=context, context=context,
): )
while True:
try:
result_dict = await self._next_with_deadline(gen, descriptor, context)
except StopAsyncIteration:
break
yield result_dict yield result_dict
except asyncio.TimeoutError as e:
raise RunnerExecutionError(
descriptor.id,
'Runner timed out (code: runner.timeout)',
retryable=True,
) from e
except ActionCallTimeoutError as e:
raise RunnerExecutionError(
descriptor.id,
f'{e} (code: runner.timeout)',
retryable=True,
) from e
except RunnerExecutionError: except RunnerExecutionError:
raise raise
except Exception as e: except Exception as e:
@@ -176,6 +197,57 @@ class AgentRunOrchestrator:
retryable=False, retryable=False,
) )
async def _next_with_deadline(
self,
gen: typing.AsyncGenerator[dict[str, typing.Any], None],
descriptor: AgentRunnerDescriptor,
context: AgentRunContextPayload,
) -> dict[str, typing.Any]:
"""Read the next runner result while enforcing the run deadline."""
remaining = self._remaining_deadline_seconds(context)
if remaining is not None and remaining <= 0:
await self._close_generator(gen, descriptor)
raise asyncio.TimeoutError
try:
if remaining is None:
return await anext(gen)
return await asyncio.wait_for(anext(gen), timeout=remaining)
except StopAsyncIteration:
if self._is_deadline_exhausted(context):
raise asyncio.TimeoutError
raise
except asyncio.TimeoutError:
await self._close_generator(gen, descriptor)
raise
def _remaining_deadline_seconds(
self,
context: AgentRunContextPayload,
) -> float | None:
runtime = context.get('runtime') or {}
deadline_at = runtime.get('deadline_at')
if deadline_at is None:
return None
try:
return float(deadline_at) - time.time()
except (TypeError, ValueError):
return None
def _is_deadline_exhausted(self, context: AgentRunContextPayload) -> bool:
remaining = self._remaining_deadline_seconds(context)
return remaining is not None and remaining <= 0
async def _close_generator(
self,
gen: typing.AsyncGenerator[dict[str, typing.Any], None],
descriptor: AgentRunnerDescriptor,
) -> None:
try:
await gen.aclose()
except Exception as e:
self.ap.logger.warning(f'Failed to close timed-out runner {descriptor.id}: {e}')
def resolve_runner_id_for_telemetry(self, query: pipeline_query.Query) -> str | None: def resolve_runner_id_for_telemetry(self, query: pipeline_query.Query) -> str | None:
"""Resolve runner ID for telemetry/logging without full execution. """Resolve runner ID for telemetry/logging without full execution.

View File

@@ -13,6 +13,7 @@ from .context_builder import (
KnowledgeBaseResource, KnowledgeBaseResource,
StorageResource, StorageResource,
) )
from . import config_schema
class AgentResourceBuilder: class AgentResourceBuilder:
@@ -73,7 +74,7 @@ class AgentResourceBuilder:
models, tools, knowledge_bases = await asyncio.gather( models, tools, knowledge_bases = await asyncio.gather(
self._build_models(manifest_perms, runner_config, descriptor, query), self._build_models(manifest_perms, runner_config, descriptor, query),
self._build_tools(manifest_perms, bound_plugins, bound_mcp_servers, query), self._build_tools(manifest_perms, bound_plugins, bound_mcp_servers, query),
self._build_knowledge_bases(manifest_perms, runner_config, query), self._build_knowledge_bases(manifest_perms, runner_config, descriptor, query),
) )
storage = self._build_storage(manifest_perms) storage = self._build_storage(manifest_perms)
@@ -132,34 +133,11 @@ class AgentResourceBuilder:
runner_config: dict[str, typing.Any], runner_config: dict[str, typing.Any],
) -> None: ) -> None:
"""Authorize model-like values selected through DynamicForm fields.""" """Authorize model-like values selected through DynamicForm fields."""
for item in descriptor.config_schema or []: for model_type, model_uuid in config_schema.iter_config_model_refs(descriptor, runner_config):
if not isinstance(item, dict): if model_type == 'llm':
continue await self._append_llm_model_resource(models, seen_model_ids, model_uuid)
elif model_type == 'rerank':
field_name = item.get('name') await self._append_rerank_model_resource(models, seen_model_ids, model_uuid)
field_type = item.get('type')
if not field_name or field_name not in runner_config:
continue
value = runner_config.get(field_name)
if field_type == 'model-fallback-selector':
if isinstance(value, str):
await self._append_llm_model_resource(models, seen_model_ids, value)
elif isinstance(value, dict):
primary = value.get('primary')
if isinstance(primary, str):
await self._append_llm_model_resource(models, seen_model_ids, primary)
fallbacks = value.get('fallbacks', [])
if isinstance(fallbacks, list):
for fallback_uuid in fallbacks:
if isinstance(fallback_uuid, str):
await self._append_llm_model_resource(models, seen_model_ids, fallback_uuid)
elif field_type == 'llm-model-selector':
if isinstance(value, str):
await self._append_llm_model_resource(models, seen_model_ids, value)
elif field_type == 'rerank-model-selector':
if isinstance(value, str):
await self._append_rerank_model_resource(models, seen_model_ids, value)
async def _append_llm_model_resource( async def _append_llm_model_resource(
self, self,
@@ -236,6 +214,7 @@ class AgentResourceBuilder:
self, self,
manifest_perms: dict[str, list[str]], manifest_perms: dict[str, list[str]],
runner_config: dict[str, typing.Any], runner_config: dict[str, typing.Any],
descriptor: AgentRunnerDescriptor,
query: typing.Any, query: typing.Any,
) -> list[KnowledgeBaseResource]: ) -> list[KnowledgeBaseResource]:
"""Build knowledge bases list with plugin SDK field names.""" """Build knowledge bases list with plugin SDK field names."""
@@ -246,13 +225,8 @@ class AgentResourceBuilder:
if 'list' not in kb_perms and 'retrieve' not in kb_perms: if 'list' not in kb_perms and 'retrieve' not in kb_perms:
return kb_resources return kb_resources
# Get knowledge base UUIDs from config # Get knowledge base UUIDs from schema-defined config fields.
kb_uuids = runner_config.get('knowledge-bases', []) kb_uuids = config_schema.extract_knowledge_base_uuids(descriptor, runner_config)
if not kb_uuids:
# Old single KB config
old_kb_uuid = runner_config.get('knowledge-base', '')
if old_kb_uuid and old_kb_uuid != '__none__':
kb_uuids = [old_kb_uuid]
# Also check query variables (may be modified by plugin PromptPreProcessing) # Also check query variables (may be modified by plugin PromptPreProcessing)
kb_uuids_from_vars = query.variables.get('_knowledge_base_uuids', []) kb_uuids_from_vars = query.variables.get('_knowledge_base_uuids', [])

View File

@@ -9,6 +9,8 @@ from ....core import app
from ....entity.persistence import model as persistence_model from ....entity.persistence import model as persistence_model
from ....entity.persistence import pipeline as persistence_pipeline from ....entity.persistence import pipeline as persistence_pipeline
from ....provider.modelmgr import requester as model_requester from ....provider.modelmgr import requester as model_requester
from ....agent.runner.config_migration import ConfigMigration
from ....agent.runner import config_schema
def _parse_provider_api_keys(provider_dict: dict) -> dict: def _parse_provider_api_keys(provider_dict: dict) -> dict:
@@ -40,6 +42,40 @@ class LLMModelsService:
def __init__(self, ap: app.Application) -> None: def __init__(self, ap: app.Application) -> None:
self.ap = ap self.ap = ap
async def _get_runner_descriptor(self, runner_id: str):
registry = getattr(self.ap, 'agent_runner_registry', None)
if registry is None:
return None
try:
return await registry.get(runner_id, bound_plugins=None)
except Exception as e:
logger = getattr(self.ap, 'logger', None)
if logger:
logger.warning(f'Failed to load AgentRunner descriptor while setting default model: {e}')
return None
async def _auto_set_default_pipeline_llm_model(self, pipeline: persistence_pipeline.LegacyPipeline, model_uuid: str):
pipeline_config = pipeline.config
if not isinstance(pipeline_config, dict):
return
runner_id = ConfigMigration.resolve_runner_id(pipeline_config)
if not runner_id:
return
descriptor = await self._get_runner_descriptor(runner_id)
if descriptor is None:
return
ai_config = pipeline_config.setdefault('ai', {})
runner_configs = ai_config.setdefault('runner_config', {})
runner_config = runner_configs.setdefault(runner_id, {})
if not config_schema.set_empty_llm_model_selection(descriptor, runner_config, model_uuid):
return
await self.ap.pipeline_service.update_pipeline(pipeline.uuid, {'config': pipeline_config})
async def get_llm_models(self, include_secret: bool = True) -> list[dict]: async def get_llm_models(self, include_secret: bool = True) -> list[dict]:
"""Get all LLM models with provider info""" """Get all LLM models with provider info"""
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel)) result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel))
@@ -109,7 +145,6 @@ class LLMModelsService:
self.ap.model_mgr.llm_models.append(runtime_llm_model) self.ap.model_mgr.llm_models.append(runtime_llm_model)
if auto_set_to_default_pipeline: if auto_set_to_default_pipeline:
# set the default pipeline model to this model
result = await self.ap.persistence_mgr.execute_async( result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where( sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
persistence_pipeline.LegacyPipeline.is_default == True persistence_pipeline.LegacyPipeline.is_default == True
@@ -117,15 +152,7 @@ class LLMModelsService:
) )
pipeline = result.first() pipeline = result.first()
if pipeline is not None: if pipeline is not None:
model_config = pipeline.config.get('ai', {}).get('local-agent', {}).get('model', {}) await self._auto_set_default_pipeline_llm_model(pipeline, model_data['uuid'])
if not model_config.get('primary', ''):
pipeline_config = pipeline.config
pipeline_config['ai']['local-agent']['model'] = {
'primary': model_data['uuid'],
'fallbacks': [],
}
pipeline_data = {'config': pipeline_config}
await self.ap.pipeline_service.update_pipeline(pipeline.uuid, pipeline_data)
return model_data['uuid'] return model_data['uuid']

View File

@@ -11,7 +11,8 @@ class RoundTruncator(truncator.Truncator):
async def truncate(self, query: pipeline_query.Query) -> pipeline_query.Query: async def truncate(self, query: pipeline_query.Query) -> pipeline_query.Query:
"""截断""" """截断"""
# Get max-round from runner config (new or old format) # max-round remains a pipeline-side trimming knob until token-budget
# based compaction replaces this stage.
runner_id = ConfigMigration.resolve_runner_id(query.pipeline_config) runner_id = ConfigMigration.resolve_runner_id(query.pipeline_config)
runner_config = ConfigMigration.resolve_runner_config(query.pipeline_config, runner_id) if runner_id else {} runner_config = ConfigMigration.resolve_runner_config(query.pipeline_config, runner_id) if runner_id else {}
max_round = runner_config.get('max-round', 10) max_round = runner_config.get('max-round', 10)

View File

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import datetime import datetime
import typing
from .. import stage, entities from .. import stage, entities
from langbot_plugin.api.entities.builtin.provider import message as provider_message from langbot_plugin.api.entities.builtin.provider import message as provider_message
@@ -9,12 +10,14 @@ import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.platform.events as platform_events import langbot_plugin.api.entities.builtin.platform.events as platform_events
from ...agent.runner.descriptor import AgentRunnerDescriptor
from ...agent.runner.config_migration import ConfigMigration from ...agent.runner.config_migration import ConfigMigration
from ...agent.runner import config_schema
# Official local-agent runner ID DEFAULT_PROMPT_CONFIG = [
LOCAL_AGENT_RUNNER_ID = 'plugin:langbot/local-agent/default' {'role': 'system', 'content': 'You are a helpful assistant.'},
]
@stage.stage_class('PreProcessor') @stage.stage_class('PreProcessor')
class PreProcessor(stage.PipelineStage): class PreProcessor(stage.PipelineStage):
@@ -31,6 +34,76 @@ class PreProcessor(stage.PipelineStage):
- use_funcs - use_funcs
""" """
async def _get_runner_descriptor(
self,
runner_id: str | None,
bound_plugins: list[str] | None,
) -> AgentRunnerDescriptor | None:
if not runner_id:
return None
registry = getattr(self.ap, 'agent_runner_registry', None)
if registry is None:
return None
try:
return await registry.get(runner_id, bound_plugins)
except Exception as e:
self.ap.logger.debug(f'Unable to load AgentRunner descriptor for {runner_id}: {e}')
return None
async def _resolve_llm_model(
self,
primary_uuid: str,
) -> typing.Any | None:
if primary_uuid in config_schema.NONE_SENTINELS:
return None
try:
return await self.ap.model_mgr.get_model_by_uuid(primary_uuid)
except ValueError:
self.ap.logger.warning(f'LLM model {primary_uuid} not found or not configured')
return None
async def _resolve_fallback_models(self, fallback_uuids: list[str]) -> list[str]:
valid_fallbacks = []
for fallback_uuid in fallback_uuids:
if fallback_uuid in config_schema.NONE_SENTINELS:
continue
try:
await self.ap.model_mgr.get_model_by_uuid(fallback_uuid)
valid_fallbacks.append(fallback_uuid)
except ValueError:
self.ap.logger.warning(f'Fallback model {fallback_uuid} not found, skipping')
return valid_fallbacks
def _runner_accepts_multimodal_input(self, descriptor: AgentRunnerDescriptor | None) -> bool:
if descriptor is None:
return True
return descriptor.capabilities.get('multimodal_input', False)
def _model_supports_vision(self, llm_model: typing.Any | None) -> bool:
if not llm_model:
return False
abilities = getattr(getattr(llm_model, 'model_entity', None), 'abilities', [])
return 'vision' in abilities
def _should_keep_image_inputs(
self,
descriptor: AgentRunnerDescriptor | None,
uses_host_models: bool,
llm_model: typing.Any | None,
) -> bool:
if not self._runner_accepts_multimodal_input(descriptor):
return False
if uses_host_models:
return self._model_supports_vision(llm_model)
return True
def _strip_images_from_history(self, query: pipeline_query.Query) -> None:
for msg in query.messages:
if isinstance(msg.content, list):
msg.content = [elem for elem in msg.content if elem.type != 'image_url']
async def process( async def process(
self, self,
query: pipeline_query.Query, query: pipeline_query.Query,
@@ -40,56 +113,25 @@ class PreProcessor(stage.PipelineStage):
# Resolve runner ID using ConfigMigration (supports both new and old formats) # Resolve runner ID using ConfigMigration (supports both new and old formats)
runner_id = ConfigMigration.resolve_runner_id(query.pipeline_config) runner_id = ConfigMigration.resolve_runner_id(query.pipeline_config)
# Get runner config (from new ai.runner_config or old ai.<runner-name>) # Get runner config from ai.runner_config[runner_id].
runner_config = ConfigMigration.resolve_runner_config(query.pipeline_config, runner_id) if runner_id else {} runner_config = ConfigMigration.resolve_runner_config(query.pipeline_config, runner_id) if runner_id else {}
query.variables = query.variables or {}
bound_plugins = query.variables.get('_pipeline_bound_plugins', None)
bound_mcp_servers = query.variables.get('_pipeline_bound_mcp_servers', None)
descriptor = await self._get_runner_descriptor(runner_id, bound_plugins)
session = await self.ap.sess_mgr.get_session(query) session = await self.ap.sess_mgr.get_session(query)
# Determine if this is a local-agent runner (built-in LLM capabilities) uses_host_models = config_schema.uses_host_models(descriptor)
# Check by runner_id OR by legacy runner field for backward compatibility
is_local_agent = runner_id == LOCAL_AGENT_RUNNER_ID or (
runner_id is None and
query.pipeline_config.get('ai', {}).get('runner', {}).get('runner') == 'local-agent'
)
# When not local-agent, llm_model is None
llm_model = None llm_model = None
if is_local_agent: if uses_host_models:
# Read model config — new format is { primary: str, fallbacks: [str] }, primary_uuid, fallback_uuids = config_schema.extract_model_selection(descriptor, runner_config)
# but handle legacy plain string for backward compatibility llm_model = await self._resolve_llm_model(primary_uuid)
model_config = runner_config.get('model', {}) valid_fallbacks = await self._resolve_fallback_models(fallback_uuids)
if isinstance(model_config, str): if valid_fallbacks:
# Legacy format: plain UUID string query.variables['_fallback_model_uuids'] = valid_fallbacks
primary_uuid = model_config
fallback_uuids = []
else:
primary_uuid = model_config.get('primary', '')
fallback_uuids = model_config.get('fallbacks', [])
if primary_uuid: prompt_config = config_schema.extract_prompt_config(descriptor, runner_config, DEFAULT_PROMPT_CONFIG)
try:
llm_model = await self.ap.model_mgr.get_model_by_uuid(primary_uuid)
except ValueError:
self.ap.logger.warning(f'LLM model {primary_uuid} not found or not configured')
# Resolve fallback model UUIDs
if fallback_uuids:
valid_fallbacks = []
for fb_uuid in fallback_uuids:
try:
await self.ap.model_mgr.get_model_by_uuid(fb_uuid)
valid_fallbacks.append(fb_uuid)
except ValueError:
self.ap.logger.warning(f'Fallback model {fb_uuid} not found, skipping')
if valid_fallbacks:
query.variables['_fallback_model_uuids'] = valid_fallbacks
# Get prompt config - for local-agent, use runner_config; for others, use default prompt
prompt_config = runner_config.get('prompt', [
{'role': 'system', 'content': 'You are a helpful assistant.'}
]) if is_local_agent else [
{'role': 'system', 'content': 'You are a helpful assistant.'}
]
conversation = await self.ap.sess_mgr.get_conversation( conversation = await self.ap.sess_mgr.get_conversation(
query, query,
@@ -125,15 +167,14 @@ class PreProcessor(stage.PipelineStage):
query.prompt = conversation.prompt.copy() query.prompt = conversation.prompt.copy()
query.messages = conversation.messages.copy() query.messages = conversation.messages.copy()
if is_local_agent: if uses_host_models:
query.use_funcs = [] query.use_funcs = []
if llm_model: if llm_model:
query.use_llm_model_uuid = llm_model.model_entity.uuid query.use_llm_model_uuid = llm_model.model_entity.uuid
if llm_model.model_entity.abilities.__contains__('func_call'): if config_schema.uses_host_tools(descriptor) and llm_model.model_entity.abilities.__contains__(
# Get bound plugins and MCP servers for filtering tools 'func_call'
bound_plugins = query.variables.get('_pipeline_bound_plugins', None) ):
bound_mcp_servers = query.variables.get('_pipeline_bound_mcp_servers', None)
query.use_funcs = await self.ap.tool_mgr.get_all_tools(bound_plugins, bound_mcp_servers) query.use_funcs = await self.ap.tool_mgr.get_all_tools(bound_plugins, bound_mcp_servers)
self.ap.logger.debug(f'Bound plugins: {bound_plugins}') self.ap.logger.debug(f'Bound plugins: {bound_plugins}')
@@ -142,9 +183,11 @@ class PreProcessor(stage.PipelineStage):
# If primary model doesn't support func_call but fallback models exist, # If primary model doesn't support func_call but fallback models exist,
# load tools anyway since fallback models may support them # load tools anyway since fallback models may support them
if not query.use_funcs and query.variables.get('_fallback_model_uuids'): if (
bound_plugins = query.variables.get('_pipeline_bound_plugins', None) config_schema.uses_host_tools(descriptor)
bound_mcp_servers = query.variables.get('_pipeline_bound_mcp_servers', None) and not query.use_funcs
and query.variables.get('_fallback_model_uuids')
):
query.use_funcs = await self.ap.tool_mgr.get_all_tools(bound_plugins, bound_mcp_servers) query.use_funcs = await self.ap.tool_mgr.get_all_tools(bound_plugins, bound_mcp_servers)
sender_name = '' sender_name = ''
@@ -170,18 +213,9 @@ class PreProcessor(stage.PipelineStage):
} }
query.variables.update(variables) query.variables.update(variables)
# Check if this model supports vision, if not, remove all images keep_image_inputs = self._should_keep_image_inputs(descriptor, uses_host_models, llm_model)
# TODO this checking should be performed in runner, and in this stage, the image should be reserved if not keep_image_inputs:
if ( self._strip_images_from_history(query)
is_local_agent
and llm_model
and not llm_model.model_entity.abilities.__contains__('vision')
):
for msg in query.messages:
if isinstance(msg.content, list):
for me in msg.content:
if me.type == 'image_url':
msg.content.remove(me)
content_list: list[provider_message.ContentElement] = [] content_list: list[provider_message.ContentElement] = []
@@ -193,10 +227,7 @@ class PreProcessor(stage.PipelineStage):
content_list.append(provider_message.ContentElement.from_text(me.text)) content_list.append(provider_message.ContentElement.from_text(me.text))
plain_text += me.text plain_text += me.text
elif isinstance(me, platform_message.Image): elif isinstance(me, platform_message.Image):
# Allow images for non-local-agent runners or if local-agent has vision if keep_image_inputs:
if not is_local_agent or (
llm_model and llm_model.model_entity.abilities.__contains__('vision')
):
if me.base64 is not None: if me.base64 is not None:
content_list.append(provider_message.ContentElement.from_image_base64(me.base64)) content_list.append(provider_message.ContentElement.from_image_base64(me.base64))
elif isinstance(me, platform_message.Voice): elif isinstance(me, platform_message.Voice):
@@ -215,9 +246,7 @@ class PreProcessor(stage.PipelineStage):
if isinstance(msg, platform_message.Plain): if isinstance(msg, platform_message.Plain):
content_list.append(provider_message.ContentElement.from_text(msg.text)) content_list.append(provider_message.ContentElement.from_text(msg.text))
elif isinstance(msg, platform_message.Image): elif isinstance(msg, platform_message.Image):
if not is_local_agent or ( if keep_image_inputs:
llm_model and llm_model.model_entity.abilities.__contains__('vision')
):
if msg.base64 is not None: if msg.base64 is not None:
content_list.append(provider_message.ContentElement.from_image_base64(msg.base64)) content_list.append(provider_message.ContentElement.from_image_base64(msg.base64))
elif isinstance(msg, platform_message.File): elif isinstance(msg, platform_message.File):
@@ -237,15 +266,12 @@ class PreProcessor(stage.PipelineStage):
query.user_message = provider_message.Message(role='user', content=content_list) query.user_message = provider_message.Message(role='user', content=content_list)
# Extract knowledge base UUIDs into query variables so plugins can modify them # Extract configured KB UUIDs into query variables so PromptPreProcessing
# during PromptPreProcessing before the runner performs retrieval. # plugins can still adjust the authorized retrieval set before run_agent.
# Only for local-agent runner query.variables['_knowledge_base_uuids'] = config_schema.extract_knowledge_base_uuids(
kb_uuids = runner_config.get('knowledge-bases', []) if is_local_agent else [] descriptor,
if not kb_uuids: runner_config,
old_kb_uuid = runner_config.get('knowledge-base', '') if is_local_agent else '' )
if old_kb_uuid and old_kb_uuid != '__none__':
kb_uuids = [old_kb_uuid]
query.variables['_knowledge_base_uuids'] = list(kb_uuids)
# =========== 触发事件 PromptPreProcessing # =========== 触发事件 PromptPreProcessing
@@ -263,4 +289,4 @@ class PreProcessor(stage.PipelineStage):
query.prompt.messages = event_ctx.event.default_prompt query.prompt.messages = event_ctx.event.default_prompt
query.messages = event_ctx.event.prompt query.messages = event_ctx.event.prompt
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)

View File

@@ -25,6 +25,8 @@ from ..entity.persistence import bstorage as persistence_bstorage
from ..core import app from ..core import app
from ..utils import constants from ..utils import constants
from ..agent.runner.session_registry import get_session_registry from ..agent.runner.session_registry import get_session_registry
from ..agent.runner.config_migration import ConfigMigration
from ..agent.runner import config_schema
def _make_rag_error_response(error: Exception, error_type: str, **extra_context) -> handler.ActionResponse: def _make_rag_error_response(error: Exception, error_type: str, **extra_context) -> handler.ActionResponse:
@@ -98,6 +100,46 @@ def _build_tool_detail(tool: Any, requested_tool_name: str | None = None) -> dic
} }
def _normalize_uuid_list(values: Any) -> list[str]:
"""Normalize a user/config supplied UUID list while preserving order."""
if not isinstance(values, list):
return []
return list(
dict.fromkeys(
value for value in values if isinstance(value, str) and value not in config_schema.NONE_SENTINELS
)
)
async def _get_pipeline_knowledge_base_uuids(ap: app.Application, query: Any) -> list[str]:
"""Resolve pipeline-scoped KBs from preprocessed variables or runner schema."""
variables = getattr(query, 'variables', {}) or {}
if '_knowledge_base_uuids' in variables:
return _normalize_uuid_list(variables.get('_knowledge_base_uuids'))
pipeline_config = getattr(query, 'pipeline_config', None)
if not pipeline_config:
return []
runner_id = ConfigMigration.resolve_runner_id(pipeline_config)
if not runner_id:
return []
runner_config = ConfigMigration.resolve_runner_config(pipeline_config, runner_id)
registry = getattr(ap, 'agent_runner_registry', None)
if registry is None:
return []
bound_plugins = variables.get('_pipeline_bound_plugins')
try:
descriptor = await registry.get(runner_id, bound_plugins)
except Exception as e:
ap.logger.warning(f'Failed to load AgentRunner descriptor for pipeline knowledge-base scope: {e}')
return []
return config_schema.extract_knowledge_base_uuids(descriptor, runner_config)
async def _validate_run_authorization( async def _validate_run_authorization(
run_id: str, run_id: str,
resource_type: str, resource_type: str,
@@ -1155,15 +1197,7 @@ class RuntimeConnectionHandler(handler.Handler):
query = self.ap.query_pool.cached_queries[query_id] query = self.ap.query_pool.cached_queries[query_id]
kb_uuids = [] kb_uuids = await _get_pipeline_knowledge_base_uuids(self.ap, query)
if query.pipeline_config:
local_agent_config = query.pipeline_config.get('ai', {}).get('local-agent', {})
kb_uuids = local_agent_config.get('knowledge-bases', [])
# Backward compatibility
if not kb_uuids:
old_kb_uuid = local_agent_config.get('knowledge-base', '')
if old_kb_uuid and old_kb_uuid != '__none__':
kb_uuids = [old_kb_uuid]
knowledge_bases = [] knowledge_bases = []
for kb_uuid in kb_uuids: for kb_uuid in kb_uuids:
@@ -1213,19 +1247,9 @@ class RuntimeConnectionHandler(handler.Handler):
if error: if error:
return error return error
else: else:
# Regular plugin call: validate against pipeline's configured knowledge bases # Regular plugin call: validate against the runner binding's
# FIX: First resolve runner_id, then resolve runner_config # schema-defined KB selectors or the preprocessed query scope.
allowed_kb_uuids = [] allowed_kb_uuids = await _get_pipeline_knowledge_base_uuids(self.ap, query)
if query.pipeline_config:
from langbot.pkg.agent.runner.config_migration import ConfigMigration
runner_id = ConfigMigration.resolve_runner_id(query.pipeline_config)
if runner_id:
runner_config = ConfigMigration.resolve_runner_config(query.pipeline_config, runner_id)
allowed_kb_uuids = runner_config.get('knowledge-bases', [])
if not allowed_kb_uuids:
old_kb_uuid = runner_config.get('knowledge-base', '')
if old_kb_uuid and old_kb_uuid != '__none__':
allowed_kb_uuids = [old_kb_uuid]
if kb_id not in allowed_kb_uuids: if kb_id not in allowed_kb_uuids:
return handler.ActionResponse.error( return handler.ActionResponse.error(
@@ -1424,6 +1448,7 @@ class RuntimeConnectionHandler(handler.Handler):
Yields AgentRunResult dicts. Yields AgentRunResult dicts.
""" """
timeout = self._get_runner_action_timeout(context)
gen = self.call_action_generator( gen = self.call_action_generator(
LangBotToRuntimeAction.RUN_AGENT, LangBotToRuntimeAction.RUN_AGENT,
{ {
@@ -1432,12 +1457,27 @@ class RuntimeConnectionHandler(handler.Handler):
'runner_name': runner_name, 'runner_name': runner_name,
'context': context, 'context': context,
}, },
timeout=300, timeout=timeout,
) )
async for ret in gen: async for ret in gen:
yield ret yield ret
def _get_runner_action_timeout(self, context: dict[str, Any]) -> float:
"""Use the run deadline as the transport idle timeout when available."""
try:
import time
deadline_at = (context.get('runtime') or {}).get('deadline_at')
if deadline_at is None:
return 300
remaining = float(deadline_at) - time.time()
if remaining <= 0:
return 0.001
return max(remaining + 1.0, 0.001)
except (TypeError, ValueError):
return 300
async def get_plugin_icon(self, plugin_author: str, plugin_name: str) -> dict[str, Any]: async def get_plugin_icon(self, plugin_author: str, plugin_name: str) -> dict[str, Any]:
"""Get plugin icon""" """Get plugin icon"""
result = await self.call_action( result = await self.call_action(

View File

@@ -18,6 +18,7 @@ import langbot_plugin.api.entities.builtin.provider.session as provider_session
# Counter for generating unique IDs # Counter for generating unique IDs
_query_counter = 0 _query_counter = 0
DEFAULT_RUNNER_ID = "plugin:langbot/local-agent/default"
def _next_query_id() -> int: def _next_query_id() -> int:
@@ -163,10 +164,12 @@ def _base_query(
"bot_uuid": "test-bot-uuid", "bot_uuid": "test-bot-uuid",
"pipeline_config": { "pipeline_config": {
"ai": { "ai": {
"runner": {"runner": "local-agent"}, "runner": {"id": DEFAULT_RUNNER_ID},
"local-agent": { "runner_config": {
"model": {"primary": "test-model-uuid", "fallbacks": []}, DEFAULT_RUNNER_ID: {
"prompt": "test-prompt", "model": {"primary": "test-model-uuid", "fallbacks": []},
"prompt": [{"role": "system", "content": "test-prompt"}],
},
}, },
}, },
"output": {"misc": {"at-sender": False, "quote-origin": False}}, "output": {"misc": {"at-sender": False, "quote-origin": False}},
@@ -469,4 +472,4 @@ def at_all_query(
sender_id=sender_id, sender_id=sender_id,
adapter=adapter, adapter=adapter,
**overrides, **overrides,
) )

View File

@@ -132,7 +132,7 @@ class TestResolveRunnerConfig:
assert config == {'model': 'uuid-123', 'max_round': 10} assert config == {'model': 'uuid-123', 'max_round': 10}
def test_resolve_old_format_config(self): def test_resolve_old_format_config(self):
"""Resolve runner config from old format.""" """Runtime config resolver should not read old format."""
pipeline_config = { pipeline_config = {
'ai': { 'ai': {
'local-agent': { 'local-agent': {
@@ -146,6 +146,23 @@ class TestResolveRunnerConfig:
pipeline_config, pipeline_config,
'plugin:langbot/local-agent/default', 'plugin:langbot/local-agent/default',
) )
assert config == {}
def test_resolve_legacy_config_for_migration(self):
"""Migration helper should read old format."""
pipeline_config = {
'ai': {
'local-agent': {
'model': 'uuid-123',
'max_round': 10,
},
},
}
config = ConfigMigration.resolve_legacy_runner_config(
pipeline_config,
'plugin:langbot/local-agent/default',
)
assert config == {'model': 'uuid-123', 'max_round': 10} assert config == {'model': 'uuid-123', 'max_round': 10}
def test_resolve_no_config(self): def test_resolve_no_config(self):
@@ -228,4 +245,4 @@ class TestGetOldRunnerName:
def test_get_old_runner_name_not_mapped(self): def test_get_old_runner_name_not_mapped(self):
"""Get old runner name for unmapped runner ID.""" """Get old runner name for unmapped runner ID."""
old_name = ConfigMigration.get_old_runner_name('plugin:alice/my-agent/custom') old_name = ConfigMigration.get_old_runner_name('plugin:alice/my-agent/custom')
assert old_name is None assert old_name is None

View File

@@ -229,8 +229,8 @@ class TestResolveRunnerIdBackwardCompat:
assert runner_id == 'plugin:new-runner/default' assert runner_id == 'plugin:new-runner/default'
class TestResolveRunnerConfigBackwardCompat: class TestResolveRunnerConfig:
"""Tests for backward compatibility in resolve_runner_config.""" """Tests for runtime runner config resolution."""
def test_resolve_new_format_config(self): def test_resolve_new_format_config(self):
"""resolve_runner_config should read from runner_config.""" """resolve_runner_config should read from runner_config."""
@@ -245,13 +245,23 @@ class TestResolveRunnerConfigBackwardCompat:
assert runner_config['max-round'] == 20 assert runner_config['max-round'] == 20
def test_resolve_old_format_config(self): def test_resolve_old_format_config(self):
"""resolve_runner_config should read from old ai.local-agent.""" """resolve_runner_config should not read old ai.local-agent at runtime."""
config = { config = {
'ai': { 'ai': {
'local-agent': {'max-round': 15}, 'local-agent': {'max-round': 15},
}, },
} }
runner_config = ConfigMigration.resolve_runner_config(config, 'plugin:langbot/local-agent/default') runner_config = ConfigMigration.resolve_runner_config(config, 'plugin:langbot/local-agent/default')
assert runner_config == {}
def test_resolve_legacy_runner_config_for_migration(self):
"""resolve_legacy_runner_config should read old ai.local-agent for migration."""
config = {
'ai': {
'local-agent': {'max-round': 15},
},
}
runner_config = ConfigMigration.resolve_legacy_runner_config(config, 'plugin:langbot/local-agent/default')
assert runner_config['max-round'] == 15 assert runner_config['max-round'] == 15
def test_resolve_new_format_priority(self): def test_resolve_new_format_priority(self):

View File

@@ -16,8 +16,9 @@ import pytest
import types import types
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
from langbot.pkg.agent.runner.descriptor import AgentRunnerDescriptor
from langbot.pkg.agent.runner.session_registry import AgentRunSessionRegistry from langbot.pkg.agent.runner.session_registry import AgentRunSessionRegistry
from langbot.pkg.plugin.handler import _build_tool_detail from langbot.pkg.plugin.handler import _build_tool_detail, _get_pipeline_knowledge_base_uuids
# Import shared test fixtures from conftest.py # Import shared test fixtures from conftest.py
from .conftest import make_resources from .conftest import make_resources
@@ -105,11 +106,53 @@ class MockApplication:
self.persistence_mgr.execute_async = AsyncMock(return_value=MagicMock(first=lambda: None)) self.persistence_mgr.execute_async = AsyncMock(return_value=MagicMock(first=lambda: None))
class FakeAgentRunnerRegistry:
async def get(self, runner_id, bound_plugins=None):
return AgentRunnerDescriptor(
id=runner_id,
source='plugin',
label={'en_US': 'Test Runner'},
plugin_author='test',
plugin_name='runner',
runner_name='default',
config_schema=[
{'name': 'knowledge-bases', 'type': 'knowledge-base-multi-selector', 'default': []},
],
capabilities={'knowledge_retrieval': True},
permissions={'knowledge_bases': ['list', 'retrieve']},
)
class MockConnection: class MockConnection:
"""Mock connection for testing.""" """Mock connection for testing."""
pass pass
class TestPipelineKnowledgeBaseScope:
"""Tests for schema-driven pipeline KB scope resolution."""
@pytest.mark.asyncio
async def test_uses_preprocessed_query_scope(self):
app = MockApplication()
query = MockQuery()
query.variables = {'_knowledge_base_uuids': ['kb_var', '__none__', 'kb_var']}
kb_uuids = await _get_pipeline_knowledge_base_uuids(app, query)
assert kb_uuids == ['kb_var']
@pytest.mark.asyncio
async def test_uses_runner_schema_when_query_scope_not_preprocessed(self):
app = MockApplication()
app.agent_runner_registry = FakeAgentRunnerRegistry()
query = MockQuery()
query.variables = {}
kb_uuids = await _get_pipeline_knowledge_base_uuids(app, query)
assert kb_uuids == ['kb_001', 'kb_002']
class MockDisconnectCallback: class MockDisconnectCallback:
"""Mock disconnect callback for testing.""" """Mock disconnect callback for testing."""
async def __call__(self): async def __call__(self):

View File

@@ -1,6 +1,7 @@
"""Integration-style tests for AgentRunOrchestrator with a fake plugin runner.""" """Integration-style tests for AgentRunOrchestrator with a fake plugin runner."""
from __future__ import annotations from __future__ import annotations
import asyncio
import datetime import datetime
import types import types
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
@@ -61,9 +62,10 @@ class FakeKnowledgeBase:
class FakePluginConnector: class FakePluginConnector:
is_enable_plugin = True is_enable_plugin = True
def __init__(self, results=None, error: Exception | None = None): def __init__(self, results=None, error: Exception | None = None, delay: float = 0):
self.results = results or [] self.results = results or []
self.error = error self.error = error
self.delay = delay
self.calls: list[dict] = [] self.calls: list[dict] = []
self.contexts: list[dict] = [] self.contexts: list[dict] = []
self.sessions_during_run: list[dict | None] = [] self.sessions_during_run: list[dict | None] = []
@@ -83,6 +85,8 @@ class FakePluginConnector:
raise self.error raise self.error
for result in self.results: for result in self.results:
if self.delay:
await asyncio.sleep(self.delay)
yield result yield result
@@ -125,7 +129,11 @@ def make_descriptor() -> AgentRunnerDescriptor:
plugin_name="local-agent", plugin_name="local-agent",
runner_name="default", runner_name="default",
protocol_version="1", protocol_version="1",
capabilities={"streaming": True, "tool_calling": True}, capabilities={"streaming": True, "tool_calling": True, "knowledge_retrieval": True},
config_schema=[
{"name": "model", "type": "model-fallback-selector"},
{"name": "knowledge-bases", "type": "knowledge-base-multi-selector", "default": []},
],
permissions={ permissions={
"models": ["invoke", "stream"], "models": ["invoke", "stream"],
"tools": ["list", "detail", "call"], "tools": ["list", "detail", "call"],
@@ -367,3 +375,27 @@ async def test_orchestrator_unregisters_session_after_runner_failure():
context = plugin_connector.contexts[0] context = plugin_connector.contexts[0]
assert plugin_connector.sessions_during_run[0] is not None assert plugin_connector.sessions_during_run[0] is not None
assert await get_session_registry().get(context["run_id"]) is None assert await get_session_registry().get(context["run_id"]) is None
@pytest.mark.asyncio
async def test_orchestrator_enforces_total_runner_deadline():
descriptor = make_descriptor()
plugin_connector = FakePluginConnector(
results=[
{
"type": "message.completed",
"data": {"message": {"role": "assistant", "content": "too late"}},
}
],
delay=0.05,
)
orchestrator = AgentRunOrchestrator(FakeApplication(plugin_connector), FakeRegistry(descriptor))
query = make_query()
query.pipeline_config["ai"]["runner_config"][RUNNER_ID]["timeout"] = 0.01
with pytest.raises(RunnerExecutionError) as exc_info:
[message async for message in orchestrator.run_from_query(query)]
assert exc_info.value.retryable is True
assert "runner.timeout" in str(exc_info.value)
assert await get_session_registry().get(plugin_connector.contexts[0]["run_id"]) is None

View File

@@ -13,10 +13,12 @@ Source: src/langbot/pkg/api/http/service/model.py
from __future__ import annotations from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, Mock
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import AsyncMock, Mock
import pytest
from langbot.pkg.agent.runner.descriptor import AgentRunnerDescriptor
from langbot.pkg.api.http.service.model import ( from langbot.pkg.api.http.service.model import (
LLMModelsService, LLMModelsService,
EmbeddingModelsService, EmbeddingModelsService,
@@ -28,6 +30,7 @@ from langbot.pkg.entity.persistence.model import LLMModel, EmbeddingModel, Reran
pytestmark = pytest.mark.asyncio pytestmark = pytest.mark.asyncio
RUNNER_ID = 'plugin:test/runner/default'
def _create_mock_llm_model( def _create_mock_llm_model(
@@ -98,6 +101,22 @@ def _create_mock_result(items: list = None, first_item=None):
return result return result
class FakeAgentRunnerRegistry:
async def get(self, runner_id, bound_plugins=None):
return AgentRunnerDescriptor(
id=runner_id,
source='plugin',
label={'en_US': 'Test Runner'},
plugin_author='test',
plugin_name='runner',
runner_name='default',
config_schema=[
{'name': 'model', 'type': 'model-fallback-selector', 'default': {'primary': '', 'fallbacks': []}},
],
permissions={'models': ['invoke']},
)
class TestParseProviderApiKeys: class TestParseProviderApiKeys:
"""Tests for _parse_provider_api_keys helper function.""" """Tests for _parse_provider_api_keys helper function."""
@@ -402,6 +421,51 @@ class TestLLMModelsServiceCreateLLMModel:
# Verify # Verify
assert model_uuid == 'preserved-uuid' assert model_uuid == 'preserved-uuid'
async def test_create_llm_model_auto_sets_schema_defined_default_pipeline_model(self):
"""Auto-default model selection should use runner schema, not legacy field names."""
ap = SimpleNamespace()
ap.logger = Mock()
ap.persistence_mgr = SimpleNamespace()
ap.model_mgr = SimpleNamespace()
ap.model_mgr.provider_dict = {'provider-uuid': Mock()}
ap.model_mgr.llm_models = []
ap.model_mgr.load_llm_model_with_provider = AsyncMock(return_value=Mock())
ap.pipeline_service = SimpleNamespace(update_pipeline=AsyncMock())
ap.agent_runner_registry = FakeAgentRunnerRegistry()
pipeline = SimpleNamespace(
uuid='pipeline-uuid',
config={
'ai': {
'runner': {'id': RUNNER_ID},
'runner_config': {
RUNNER_ID: {
'model': {'primary': '', 'fallbacks': []},
},
},
},
},
)
ap.persistence_mgr.execute_async = AsyncMock(return_value=_create_mock_result(first_item=pipeline))
service = LLMModelsService(ap)
model_uuid = await service.create_llm_model({
'uuid': 'new-model-uuid',
'name': 'New LLM',
'provider_uuid': 'provider-uuid',
'abilities': [],
'extra_args': {},
}, preserve_uuid=True)
assert model_uuid == 'new-model-uuid'
ap.pipeline_service.update_pipeline.assert_awaited_once()
updated_config = ap.pipeline_service.update_pipeline.await_args.args[1]['config']
assert updated_config['ai']['runner_config'][RUNNER_ID]['model'] == {
'primary': 'new-model-uuid',
'fallbacks': [],
}
async def test_create_llm_model_provider_not_found_raises_error(self): async def test_create_llm_model_provider_not_found_raises_error(self):
"""Raises Exception when provider not found in runtime.""" """Raises Exception when provider not found in runtime."""
# Setup # Setup
@@ -961,4 +1025,4 @@ class TestRerankModelsServiceGetRerankModelsByProvider:
result = await service.get_rerank_models_by_provider('provider-uuid') result = await service.get_rerank_models_by_provider('provider-uuid')
# Verify # Verify
assert len(result) == 2 assert len(result) == 2

View File

@@ -21,6 +21,9 @@ import langbot_plugin.api.entities.builtin.provider.session as provider_session
from langbot.pkg.pipeline import entities as pipeline_entities from langbot.pkg.pipeline import entities as pipeline_entities
DEFAULT_RUNNER_ID = 'plugin:langbot/local-agent/default'
class MockApplication: class MockApplication:
"""Mock Application object providing all basic dependencies needed by stages""" """Mock Application object providing all basic dependencies needed by stages"""
@@ -193,8 +196,13 @@ def sample_query(sample_message_chain, sample_message_event, mock_adapter):
bot_uuid='test-bot-uuid', bot_uuid='test-bot-uuid',
pipeline_config={ pipeline_config={
'ai': { 'ai': {
'runner': {'runner': 'local-agent'}, 'runner': {'id': DEFAULT_RUNNER_ID},
'local-agent': {'model': {'primary': 'test-model-uuid', 'fallbacks': []}, 'prompt': 'test-prompt'}, 'runner_config': {
DEFAULT_RUNNER_ID: {
'model': {'primary': 'test-model-uuid', 'fallbacks': []},
'prompt': [{'role': 'system', 'content': 'test-prompt'}],
},
},
}, },
'output': {'misc': {'at-sender': False, 'quote-origin': False}}, 'output': {'misc': {'at-sender': False, 'quote-origin': False}},
'trigger': {'misc': {'combine-quote-message': False}}, 'trigger': {'misc': {'combine-quote-message': False}},
@@ -218,8 +226,13 @@ def sample_pipeline_config():
"""Provides sample pipeline configuration""" """Provides sample pipeline configuration"""
return { return {
'ai': { 'ai': {
'runner': {'runner': 'local-agent'}, 'runner': {'id': DEFAULT_RUNNER_ID},
'local-agent': {'model': {'primary': 'test-model-uuid', 'fallbacks': []}, 'prompt': 'test-prompt'}, 'runner_config': {
DEFAULT_RUNNER_ID: {
'model': {'primary': 'test-model-uuid', 'fallbacks': []},
'prompt': [{'role': 'system', 'content': 'test-prompt'}],
},
},
}, },
'output': {'misc': {'at-sender': False, 'quote-origin': False}}, 'output': {'misc': {'at-sender': False, 'quote-origin': False}},
'trigger': {'misc': {'combine-quote-message': False}}, 'trigger': {'misc': {'combine-quote-message': False}},

View File

@@ -13,6 +13,24 @@ from unittest.mock import AsyncMock, Mock
from tests.factories import FakeApp from tests.factories import FakeApp
DEFAULT_RUNNER_ID = 'plugin:langbot/local-agent/default'
def runner_pipeline_config(output_misc: dict) -> dict:
return {
'output': {'misc': output_misc},
'ai': {
'runner': {'id': DEFAULT_RUNNER_ID},
'runner_config': {
DEFAULT_RUNNER_ID: {
'prompt': [{'role': 'system', 'content': 'default'}],
'model': {'primary': 'test', 'fallbacks': []},
},
},
},
}
# ============== FIXTURE USING IMPORT ISOLATION UTILITY ============== # ============== FIXTURE USING IMPORT ISOLATION UTILITY ==============
@pytest.fixture(scope='module') @pytest.fixture(scope='module')
@@ -53,7 +71,22 @@ def mock_circular_import_chain():
@pytest.fixture @pytest.fixture
def fake_app(): def fake_app():
"""Create FakeApp instance.""" """Create FakeApp instance."""
return FakeApp() app = FakeApp()
class ProviderRunnerBackedOrchestrator:
async def run_from_query(self, query):
import sys
runner_class = sys.modules['langbot.pkg.provider.runner'].preregistered_runners[0]
runner = runner_class(app, {})
async for result in runner.run(query):
yield result
def resolve_runner_id_for_telemetry(self, query):
return DEFAULT_RUNNER_ID
app.agent_run_orchestrator = ProviderRunnerBackedOrchestrator()
return app
@pytest.fixture @pytest.fixture
@@ -301,10 +334,9 @@ class TestChatHandlerExceptions:
query.adapter.is_stream_output_supported = AsyncMock(return_value=False) query.adapter.is_stream_output_supported = AsyncMock(return_value=False)
query.user_message = Message(role='user', content=[]) query.user_message = Message(role='user', content=[])
query.pipeline_config = { query.pipeline_config = runner_pipeline_config(
'output': {'misc': {'exception-handling': 'show-hint', 'failure-hint': 'Request failed.'}}, {'exception-handling': 'show-hint', 'failure-hint': 'Request failed.'}
'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}}, )
}
class FailingRunner: class FailingRunner:
name = 'local-agent' name = 'local-agent'
@@ -344,10 +376,7 @@ class TestChatHandlerExceptions:
query.adapter.is_stream_output_supported = AsyncMock(return_value=False) query.adapter.is_stream_output_supported = AsyncMock(return_value=False)
query.user_message = Message(role='user', content=[]) query.user_message = Message(role='user', content=[])
query.pipeline_config = { query.pipeline_config = runner_pipeline_config({'exception-handling': 'show-error'})
'output': {'misc': {'exception-handling': 'show-error'}},
'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
}
class ErrorRunner: class ErrorRunner:
name = 'local-agent' name = 'local-agent'
@@ -384,10 +413,7 @@ class TestChatHandlerExceptions:
query.adapter.is_stream_output_supported = AsyncMock(return_value=False) query.adapter.is_stream_output_supported = AsyncMock(return_value=False)
query.user_message = Message(role='user', content=[]) query.user_message = Message(role='user', content=[])
query.pipeline_config = { query.pipeline_config = runner_pipeline_config({'exception-handling': 'hide'})
'output': {'misc': {'exception-handling': 'hide'}},
'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
}
class HideErrorRunner: class HideErrorRunner:
name = 'local-agent' name = 'local-agent'
@@ -433,4 +459,4 @@ class TestChatHandlerHelper:
chat = get_chat_handler() chat = get_chat_handler()
handler = chat.ChatMessageHandler(fake_app) handler = chat.ChatMessageHandler(fake_app)
result = handler.cut_str('first line\nsecond line') result = handler.cut_str('first line\nsecond line')
assert '...' in result assert '...' in result

View File

@@ -21,6 +21,9 @@ from tests.factories import (
import langbot_plugin.api.entities.builtin.provider.message as provider_message import langbot_plugin.api.entities.builtin.provider.message as provider_message
RUNNER_ID = 'plugin:langbot/local-agent/default'
def get_msgtrun_module(): def get_msgtrun_module():
"""Lazy import to avoid circular import issues.""" """Lazy import to avoid circular import issues."""
# Import pipelinemgr first to trigger stage registration # Import pipelinemgr first to trigger stage registration
@@ -47,9 +50,12 @@ def make_truncate_config(max_round: int = 5):
"""Create a pipeline config with max-round setting.""" """Create a pipeline config with max-round setting."""
return { return {
'ai': { 'ai': {
'local-agent': { 'runner': {'id': RUNNER_ID},
'max-round': max_round, 'runner_config': {
} RUNNER_ID: {
'max-round': max_round,
},
},
} }
} }

View File

@@ -24,6 +24,9 @@ from tests.factories import (
) )
RUNNER_ID = 'plugin:langbot/local-agent/default'
def get_preproc_module(): def get_preproc_module():
"""Lazy import to avoid circular import issues.""" """Lazy import to avoid circular import issues."""
return import_module('langbot.pkg.pipeline.preproc.preproc') return import_module('langbot.pkg.pipeline.preproc.preproc')
@@ -34,6 +37,76 @@ def get_entities_module():
return import_module('langbot.pkg.pipeline.entities') return import_module('langbot.pkg.pipeline.entities')
class FakeAgentRunnerRegistry:
def __init__(self, descriptor):
self.descriptor = descriptor
async def get(self, runner_id, bound_plugins=None):
return self.descriptor
def make_host_model_runner_descriptor(
*,
multimodal_input: bool = True,
tool_calling: bool = True,
knowledge_retrieval: bool = True,
):
from langbot.pkg.agent.runner.descriptor import AgentRunnerDescriptor
return AgentRunnerDescriptor(
id=RUNNER_ID,
source='plugin',
label={'en_US': 'Local Agent'},
plugin_author='langbot',
plugin_name='local-agent',
runner_name='default',
config_schema=[
{'name': 'model', 'type': 'model-fallback-selector'},
{'name': 'prompt', 'type': 'prompt-editor', 'default': []},
{'name': 'knowledge-bases', 'type': 'knowledge-base-multi-selector', 'default': []},
],
capabilities={
'tool_calling': tool_calling,
'knowledge_retrieval': knowledge_retrieval,
'multimodal_input': multimodal_input,
},
permissions={
'models': ['list', 'invoke', 'stream'],
'tools': ['list', 'detail', 'call'],
'knowledge_bases': ['list', 'retrieve'],
},
)
def set_runner_descriptor(app, descriptor=None):
app.agent_runner_registry = FakeAgentRunnerRegistry(
descriptor or make_host_model_runner_descriptor()
)
def make_runner_config(
*,
primary: str = 'test-model-uuid',
fallbacks: list[str] | None = None,
prompt: list[dict] | None = None,
knowledge_bases: list[str] | None = None,
):
return {
'ai': {
'runner': {'id': RUNNER_ID},
'runner_config': {
RUNNER_ID: {
'model': {'primary': primary, 'fallbacks': fallbacks or []},
'prompt': prompt if prompt is not None else [],
'knowledge-bases': knowledge_bases or [],
},
},
},
'output': {'misc': {'at-sender': False}},
'trigger': {'misc': {}},
}
class TestPreProcessorNormalText: class TestPreProcessorNormalText:
"""Tests for normal text message preprocessing.""" """Tests for normal text message preprocessing."""
@@ -107,6 +180,7 @@ class TestPreProcessorNormalText:
mock_model.model_entity = Mock(uuid='test-model', abilities=['func_call']) mock_model.model_entity = Mock(uuid='test-model', abilities=['func_call'])
app.model_mgr.get_model_by_uuid = AsyncMock(return_value=mock_model) app.model_mgr.get_model_by_uuid = AsyncMock(return_value=mock_model)
app.tool_mgr.get_all_tools = AsyncMock(return_value=[]) app.tool_mgr.get_all_tools = AsyncMock(return_value=[])
set_runner_descriptor(app)
mock_event_ctx = Mock() mock_event_ctx = Mock()
mock_event_ctx.event = Mock(default_prompt=[], prompt=[]) mock_event_ctx.event = Mock(default_prompt=[], prompt=[])
@@ -195,6 +269,7 @@ class TestPreProcessorImageSegment:
stage = preproc.PreProcessor(app) stage = preproc.PreProcessor(app)
# Image query with base64 # Image query with base64
query = image_query(text="look at this", url=None) query = image_query(text="look at this", url=None)
query.pipeline_config = make_runner_config(primary='vision-model')
# Set base64 on the image component # Set base64 on the image component
import langbot_plugin.api.entities.builtin.platform.message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
chain = platform_message.MessageChain([ chain = platform_message.MessageChain([
@@ -206,8 +281,8 @@ class TestPreProcessorImageSegment:
result = await stage.process(query, 'PreProcessor') result = await stage.process(query, 'PreProcessor')
assert result.result_type == preproc.entities.ResultType.CONTINUE assert result.result_type == preproc.entities.ResultType.CONTINUE
# User message should have content content_types = [elem.type for elem in result.new_query.user_message.content]
assert result.new_query.user_message.content is not None assert 'image_base64' in content_types
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_image_without_vision_model(self): async def test_image_without_vision_model(self):
@@ -232,6 +307,7 @@ class TestPreProcessorImageSegment:
mock_model.model_entity = Mock(uuid='text-only-model', abilities=['func_call']) mock_model.model_entity = Mock(uuid='text-only-model', abilities=['func_call'])
app.model_mgr.get_model_by_uuid = AsyncMock(return_value=mock_model) app.model_mgr.get_model_by_uuid = AsyncMock(return_value=mock_model)
app.tool_mgr.get_all_tools = AsyncMock(return_value=[]) app.tool_mgr.get_all_tools = AsyncMock(return_value=[])
set_runner_descriptor(app)
mock_event_ctx = Mock() mock_event_ctx = Mock()
mock_event_ctx.event = Mock(default_prompt=[], prompt=[]) mock_event_ctx.event = Mock(default_prompt=[], prompt=[])
@@ -239,10 +315,13 @@ class TestPreProcessorImageSegment:
stage = preproc.PreProcessor(app) stage = preproc.PreProcessor(app)
query = image_query(text="describe this") query = image_query(text="describe this")
query.pipeline_config = make_runner_config(primary='text-only-model')
result = await stage.process(query, 'PreProcessor') result = await stage.process(query, 'PreProcessor')
assert result.result_type == preproc.entities.ResultType.CONTINUE assert result.result_type == preproc.entities.ResultType.CONTINUE
content_types = [elem.type for elem in result.new_query.user_message.content]
assert 'image_url' not in content_types
class TestPreProcessorModelSelection: class TestPreProcessorModelSelection:
@@ -270,6 +349,7 @@ class TestPreProcessorModelSelection:
mock_model.model_entity = Mock(uuid='primary-model-uuid', abilities=['func_call']) mock_model.model_entity = Mock(uuid='primary-model-uuid', abilities=['func_call'])
app.model_mgr.get_model_by_uuid = AsyncMock(return_value=mock_model) app.model_mgr.get_model_by_uuid = AsyncMock(return_value=mock_model)
app.tool_mgr.get_all_tools = AsyncMock(return_value=[]) app.tool_mgr.get_all_tools = AsyncMock(return_value=[])
set_runner_descriptor(app)
mock_event_ctx = Mock() mock_event_ctx = Mock()
mock_event_ctx.event = Mock(default_prompt=[], prompt=[]) mock_event_ctx.event = Mock(default_prompt=[], prompt=[])
@@ -279,17 +359,7 @@ class TestPreProcessorModelSelection:
query = text_query("hello") query = text_query("hello")
# Set pipeline config with primary model # Set pipeline config with primary model
query.pipeline_config = { query.pipeline_config = make_runner_config(primary='primary-model-uuid')
'ai': {
'runner': {'runner': 'local-agent'},
'local-agent': {
'model': {'primary': 'primary-model-uuid', 'fallbacks': []},
'prompt': 'default',
},
},
'output': {'misc': {'at-sender': False}},
'trigger': {'misc': {}},
}
result = await stage.process(query, 'PreProcessor') result = await stage.process(query, 'PreProcessor')
@@ -329,6 +399,7 @@ class TestPreProcessorModelSelection:
app.model_mgr.get_model_by_uuid = AsyncMock(side_effect=mock_get_model) app.model_mgr.get_model_by_uuid = AsyncMock(side_effect=mock_get_model)
app.tool_mgr.get_all_tools = AsyncMock(return_value=[]) app.tool_mgr.get_all_tools = AsyncMock(return_value=[])
set_runner_descriptor(app)
mock_event_ctx = Mock() mock_event_ctx = Mock()
mock_event_ctx.event = Mock(default_prompt=[], prompt=[]) mock_event_ctx.event = Mock(default_prompt=[], prompt=[])
@@ -337,17 +408,7 @@ class TestPreProcessorModelSelection:
stage = preproc.PreProcessor(app) stage = preproc.PreProcessor(app)
query = text_query("hello") query = text_query("hello")
query.pipeline_config = { query.pipeline_config = make_runner_config(primary='primary-uuid', fallbacks=['fallback-uuid'])
'ai': {
'runner': {'runner': 'local-agent'},
'local-agent': {
'model': {'primary': 'primary-uuid', 'fallbacks': ['fallback-uuid']},
'prompt': 'default',
},
},
'output': {'misc': {'at-sender': False}},
'trigger': {'misc': {}},
}
result = await stage.process(query, 'PreProcessor') result = await stage.process(query, 'PreProcessor')

View File

@@ -12,6 +12,7 @@ import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.provider.session as provider_session import langbot_plugin.api.entities.builtin.provider.session as provider_session
from langbot.pkg.api.http.service.model import _runtime_model_data from langbot.pkg.api.http.service.model import _runtime_model_data
from langbot.pkg.agent.runner.descriptor import AgentRunnerDescriptor
from langbot.pkg.api.http.service.provider import ModelProviderService from langbot.pkg.api.http.service.provider import ModelProviderService
from langbot.pkg.entity.persistence import model as persistence_model from langbot.pkg.entity.persistence import model as persistence_model
from langbot.pkg.pipeline.preproc.preproc import PreProcessor from langbot.pkg.pipeline.preproc.preproc import PreProcessor
@@ -23,6 +24,32 @@ from langbot.pkg.provider.modelmgr.token import TokenManager
from langbot.pkg.provider.runners.localagent import LocalAgentRunner from langbot.pkg.provider.runners.localagent import LocalAgentRunner
DEFAULT_RUNNER_ID = 'plugin:langbot/local-agent/default'
class FakeAgentRunnerRegistry:
async def get(self, runner_id, bound_plugins=None):
return AgentRunnerDescriptor(
id=runner_id,
source='plugin',
label={'en_US': 'Local Agent'},
plugin_author='langbot',
plugin_name='local-agent',
runner_name='default',
config_schema=[
{'name': 'model', 'type': 'model-fallback-selector'},
{'name': 'prompt', 'type': 'prompt-editor', 'default': []},
{'name': 'knowledge-bases', 'type': 'knowledge-base-multi-selector', 'default': []},
],
capabilities={'tool_calling': True, 'knowledge_retrieval': True, 'multimodal_input': True},
permissions={
'models': ['list', 'invoke', 'stream'],
'tools': ['list', 'detail', 'call'],
'knowledge_bases': ['list', 'retrieve'],
},
)
def test_runtime_llm_model_data_preserves_uuid_after_update_payload_uuid_removed(): def test_runtime_llm_model_data_preserves_uuid_after_update_payload_uuid_removed():
update_payload = { update_payload = {
'name': 'Qwen3.5-27B', 'name': 'Qwen3.5-27B',
@@ -190,6 +217,7 @@ async def test_updated_llm_model_is_immediately_usable_by_local_agent_pipeline()
ap = SimpleNamespace() ap = SimpleNamespace()
ap.logger = Mock() ap.logger = Mock()
ap.agent_runner_registry = FakeAgentRunnerRegistry()
ap.persistence_mgr = SimpleNamespace(execute_async=AsyncMock()) ap.persistence_mgr = SimpleNamespace(execute_async=AsyncMock())
ap.tool_mgr = SimpleNamespace(get_all_tools=AsyncMock(return_value=[])) ap.tool_mgr = SimpleNamespace(get_all_tools=AsyncMock(return_value=[]))
ap.plugin_connector = SimpleNamespace( ap.plugin_connector = SimpleNamespace(
@@ -252,11 +280,13 @@ async def test_updated_llm_model_is_immediately_usable_by_local_agent_pipeline()
) )
pipeline_config = { pipeline_config = {
'ai': { 'ai': {
'runner': {'runner': 'local-agent'}, 'runner': {'id': DEFAULT_RUNNER_ID},
'local-agent': { 'runner_config': {
'model': {'primary': model_uuid, 'fallbacks': []}, DEFAULT_RUNNER_ID: {
'prompt': [], 'model': {'primary': model_uuid, 'fallbacks': []},
'knowledge-bases': [], 'prompt': [],
'knowledge-bases': [],
},
}, },
}, },
'trigger': {'misc': {'combine-quote-message': False}}, 'trigger': {'misc': {'combine-quote-message': False}},