diff --git a/src/langbot/pkg/agent/runner/config_migration.py b/src/langbot/pkg/agent/runner/config_migration.py index fab878ce..0dac8cf0 100644 --- a/src/langbot/pkg/agent/runner/config_migration.py +++ b/src/langbot/pkg/agent/runner/config_migration.py @@ -24,7 +24,8 @@ class ConfigMigration: Responsibilities: - Resolve runner ID from new ai.runner.id or old ai.runner.runner - Map old built-in runner names to official plugin runner IDs - - Extract runner config from ai.runner_config or old ai. + - Extract runtime runner config from ai.runner_config + - Migrate old ai. blocks into ai.runner_config """ @staticmethod @@ -74,9 +75,9 @@ class ConfigMigration: ) -> dict[str, typing.Any]: """Resolve runner binding configuration from pipeline configuration. - Priority: - 1. New format: ai.runner_config[runner_id] - 2. Old format: ai. (mapped from runner_id if applicable) + Runtime code should only read the migrated format. Legacy + ai. blocks are handled by migration helpers, not by the + hot path. Args: pipeline_config: Pipeline configuration dict @@ -92,7 +93,16 @@ class ConfigMigration: if runner_id in runner_configs: return runner_configs[runner_id] - # Check old format: ai. + return {} + + @staticmethod + def resolve_legacy_runner_config( + pipeline_config: dict[str, typing.Any], + runner_id: str, + ) -> dict[str, typing.Any]: + """Resolve old ai. config for migration only.""" + ai_config = pipeline_config.get('ai', {}) + # Try to find old runner name from runner_id old_runner_name = None for old_name, mapped_id in OLD_RUNNER_TO_PLUGIN_RUNNER_ID.items(): @@ -105,12 +115,6 @@ class ConfigMigration: if old_config: return old_config - # If runner_id is plugin:* format, try extracting runner_name as config key - if is_plugin_runner_id(runner_id): - # Some configs might use just the runner_name component as key - # But this is legacy behavior - prefer ai.runner_config[id] - pass - return {} @staticmethod @@ -181,6 +185,8 @@ class ConfigMigration: # Migrate runner config resolved_config = ConfigMigration.resolve_runner_config(pipeline_config, runner_id) + if not resolved_config: + resolved_config = ConfigMigration.resolve_legacy_runner_config(pipeline_config, runner_id) if resolved_config: runner_configs[runner_id] = resolved_config # Remove old runner config block @@ -193,4 +199,4 @@ class ConfigMigration: ai_config['runner_config'] = runner_configs new_config['ai'] = ai_config - return new_config \ No newline at end of file + return new_config diff --git a/src/langbot/pkg/agent/runner/config_schema.py b/src/langbot/pkg/agent/runner/config_schema.py new file mode 100644 index 00000000..430d2d5e --- /dev/null +++ b/src/langbot/pkg/agent/runner/config_schema.py @@ -0,0 +1,208 @@ +"""Helpers for interpreting AgentRunner DynamicForm configuration.""" +from __future__ import annotations + +import typing + +from .descriptor import AgentRunnerDescriptor + + +LLM_MODEL_SELECTOR_TYPES = {'model-fallback-selector', 'llm-model-selector'} +KB_SELECTOR_TYPES = {'knowledge-base-multi-selector'} +PROMPT_EDITOR_TYPES = {'prompt-editor'} +NONE_SENTINELS = {'', '__none__', '__none'} + + +def iter_schema_items( + descriptor: AgentRunnerDescriptor | None, + field_types: set[str], +) -> typing.Iterator[dict[str, typing.Any]]: + """Yield descriptor config schema items whose type is in field_types.""" + if descriptor is None: + return + for item in descriptor.config_schema or []: + if not isinstance(item, dict): + continue + if item.get('type') in field_types: + yield item + + +def has_permission( + descriptor: AgentRunnerDescriptor | None, + name: str, + actions: set[str], +) -> bool: + """Return whether a runner descriptor requests one of the given actions.""" + if descriptor is None: + return False + configured_actions = descriptor.permissions.get(name, []) + return any(action in configured_actions for action in actions) + + +def uses_host_models(descriptor: AgentRunnerDescriptor | None) -> bool: + """Return whether LangBot should resolve model resources for this runner.""" + return ( + has_permission(descriptor, 'models', {'invoke', 'stream', 'list'}) + and any(True for _ in iter_schema_items(descriptor, LLM_MODEL_SELECTOR_TYPES)) + ) + + +def uses_host_tools(descriptor: AgentRunnerDescriptor | None) -> bool: + """Return whether LangBot should expose tool resources to this runner.""" + return ( + descriptor is not None + and descriptor.supports_tool_calling() + and has_permission(descriptor, 'tools', {'list', 'detail', 'call'}) + ) + + +def uses_host_knowledge_bases(descriptor: AgentRunnerDescriptor | None) -> bool: + """Return whether LangBot should expose knowledge-base resources to this runner.""" + return ( + descriptor is not None + and descriptor.supports_knowledge_retrieval() + and has_permission(descriptor, 'knowledge_bases', {'list', 'retrieve'}) + ) + + +def extract_prompt_config( + descriptor: AgentRunnerDescriptor | None, + runner_config: dict[str, typing.Any], + default_prompt: list[dict[str, typing.Any]], +) -> list[dict[str, typing.Any]]: + """Extract the prompt-editor value selected by the runner schema.""" + for item in iter_schema_items(descriptor, PROMPT_EDITOR_TYPES): + field_name = item.get('name') + if field_name and field_name in runner_config: + configured_prompt = runner_config[field_name] + if isinstance(configured_prompt, list): + return configured_prompt + default_value = item.get('default') + if isinstance(default_value, list): + return default_value + return default_prompt + + +def extract_model_selection( + descriptor: AgentRunnerDescriptor | None, + runner_config: dict[str, typing.Any], +) -> tuple[str, list[str]]: + """Extract primary/fallback LLM selections from schema-defined fields.""" + primary_uuid = '' + fallback_uuids: list[str] = [] + + for item in iter_schema_items(descriptor, LLM_MODEL_SELECTOR_TYPES): + field_name = item.get('name') + if not field_name: + continue + + value = runner_config.get(field_name, item.get('default')) + if item.get('type') == 'model-fallback-selector': + if isinstance(value, str): + primary_uuid = value + elif isinstance(value, dict): + primary_uuid = value.get('primary') or '' + fallbacks = value.get('fallbacks', []) + if isinstance(fallbacks, list): + fallback_uuids = [fallback for fallback in fallbacks if isinstance(fallback, str)] + break + + if item.get('type') == 'llm-model-selector' and isinstance(value, str): + primary_uuid = value + break + + return primary_uuid, fallback_uuids + + +def extract_knowledge_base_uuids( + descriptor: AgentRunnerDescriptor | None, + runner_config: dict[str, typing.Any], +) -> list[str]: + """Extract configured knowledge-base UUIDs from schema-defined fields.""" + if not uses_host_knowledge_bases(descriptor): + return [] + + kb_uuids: list[str] = [] + for item in iter_schema_items(descriptor, KB_SELECTOR_TYPES): + field_name = item.get('name') + if not field_name: + continue + value = runner_config.get(field_name, item.get('default', [])) + if isinstance(value, list): + kb_uuids.extend( + kb_uuid for kb_uuid in value if isinstance(kb_uuid, str) and kb_uuid not in NONE_SENTINELS + ) + + return list(dict.fromkeys(kb_uuids)) + + +def iter_config_model_refs( + descriptor: AgentRunnerDescriptor, + runner_config: dict[str, typing.Any], +) -> typing.Iterator[tuple[str, str]]: + """Yield model references declared by schema-defined model selector fields.""" + for item in descriptor.config_schema or []: + if not isinstance(item, dict): + continue + + field_name = item.get('name') + field_type = item.get('type') + if not field_name or field_name not in runner_config: + continue + + value = runner_config.get(field_name) + if field_type == 'model-fallback-selector': + if isinstance(value, str) and value not in NONE_SENTINELS: + yield 'llm', value + elif isinstance(value, dict): + primary = value.get('primary') + if isinstance(primary, str) and primary not in NONE_SENTINELS: + yield 'llm', primary + fallbacks = value.get('fallbacks', []) + if isinstance(fallbacks, list): + for fallback_uuid in fallbacks: + if isinstance(fallback_uuid, str) and fallback_uuid not in NONE_SENTINELS: + yield 'llm', fallback_uuid + elif field_type == 'llm-model-selector': + if isinstance(value, str) and value not in NONE_SENTINELS: + yield 'llm', value + elif field_type == 'rerank-model-selector': + if isinstance(value, str) and value not in NONE_SENTINELS: + yield 'rerank', value + + +def set_empty_llm_model_selection( + descriptor: AgentRunnerDescriptor, + runner_config: dict[str, typing.Any], + model_uuid: str, +) -> bool: + """Set the first empty schema-defined LLM selector to model_uuid.""" + for item in iter_schema_items(descriptor, LLM_MODEL_SELECTOR_TYPES): + field_name = item.get('name') + field_type = item.get('type') + if not field_name: + continue + + value = runner_config.get(field_name, item.get('default')) + if field_type == 'model-fallback-selector': + if isinstance(value, dict): + primary = value.get('primary') or '' + if primary not in NONE_SENTINELS: + return False + fallbacks = value.get('fallbacks', []) + runner_config[field_name] = { + 'primary': model_uuid, + 'fallbacks': fallbacks if isinstance(fallbacks, list) else [], + } + return True + if isinstance(value, str) and value not in NONE_SENTINELS: + return False + runner_config[field_name] = {'primary': model_uuid, 'fallbacks': []} + return True + + if field_type == 'llm-model-selector': + if isinstance(value, str) and value not in NONE_SENTINELS: + return False + runner_config[field_name] = model_uuid + return True + + return False diff --git a/src/langbot/pkg/agent/runner/context_builder.py b/src/langbot/pkg/agent/runner/context_builder.py index 8cafb67a..a43aac66 100644 --- a/src/langbot/pkg/agent/runner/context_builder.py +++ b/src/langbot/pkg/agent/runner/context_builder.py @@ -15,6 +15,9 @@ from .state_store import get_state_store from . import events as runner_events +DEFAULT_RUNNER_TIMEOUT_SECONDS = 300 + + # Internal models for the agent runner context protocol. @@ -106,7 +109,7 @@ class AgentRuntimeContext(typing.TypedDict): sdk_protocol_version: str query_id: int | None trace_id: str | None - deadline_at: int | None + deadline_at: float | None metadata: dict[str, typing.Any] @@ -480,9 +483,13 @@ class AgentRunContextBuilder: }, } - def _build_deadline(self, runner_config: dict[str, typing.Any]) -> int | None: - """Build deadline timestamp from runner timeout config if present.""" - timeout = runner_config.get('timeout') + def _build_deadline(self, runner_config: dict[str, typing.Any]) -> float | None: + """Build deadline timestamp from runner timeout config. + + A missing timeout uses the host default. Explicit null, zero, or negative + values disable the total run deadline for advanced deployments. + """ + timeout = runner_config.get('timeout', DEFAULT_RUNNER_TIMEOUT_SECONDS) if timeout is None: return None @@ -494,7 +501,7 @@ class AgentRunContextBuilder: if timeout_seconds <= 0: return None - return int(time.time() + timeout_seconds) + return time.time() + timeout_seconds async def _is_stream_output_supported(self, query: pipeline_query.Query) -> bool: """Check whether the current adapter can consume streaming chunks.""" diff --git a/src/langbot/pkg/agent/runner/orchestrator.py b/src/langbot/pkg/agent/runner/orchestrator.py index 8974de9a..2e8f6d25 100644 --- a/src/langbot/pkg/agent/runner/orchestrator.py +++ b/src/langbot/pkg/agent/runner/orchestrator.py @@ -3,9 +3,12 @@ from __future__ import annotations import typing import traceback +import asyncio +import time from langbot_plugin.api.entities.builtin.provider import message as provider_message from langbot_plugin.api.entities.builtin.pipeline import query as pipeline_query +from langbot_plugin.entities.io.errors import ActionCallTimeoutError from ...core import app from .descriptor import AgentRunnerDescriptor @@ -155,14 +158,32 @@ class AgentRunOrchestrator: ) try: - async for result_dict in self.ap.plugin_connector.run_agent( + gen = self.ap.plugin_connector.run_agent( plugin_author=descriptor.plugin_author, plugin_name=descriptor.plugin_name, runner_name=descriptor.runner_name, context=context, - ): + ) + + while True: + try: + result_dict = await self._next_with_deadline(gen, descriptor, context) + except StopAsyncIteration: + break yield result_dict + except asyncio.TimeoutError as e: + raise RunnerExecutionError( + descriptor.id, + 'Runner timed out (code: runner.timeout)', + retryable=True, + ) from e + except ActionCallTimeoutError as e: + raise RunnerExecutionError( + descriptor.id, + f'{e} (code: runner.timeout)', + retryable=True, + ) from e except RunnerExecutionError: raise except Exception as e: @@ -176,6 +197,57 @@ class AgentRunOrchestrator: retryable=False, ) + async def _next_with_deadline( + self, + gen: typing.AsyncGenerator[dict[str, typing.Any], None], + descriptor: AgentRunnerDescriptor, + context: AgentRunContextPayload, + ) -> dict[str, typing.Any]: + """Read the next runner result while enforcing the run deadline.""" + remaining = self._remaining_deadline_seconds(context) + if remaining is not None and remaining <= 0: + await self._close_generator(gen, descriptor) + raise asyncio.TimeoutError + + try: + if remaining is None: + return await anext(gen) + return await asyncio.wait_for(anext(gen), timeout=remaining) + except StopAsyncIteration: + if self._is_deadline_exhausted(context): + raise asyncio.TimeoutError + raise + except asyncio.TimeoutError: + await self._close_generator(gen, descriptor) + raise + + def _remaining_deadline_seconds( + self, + context: AgentRunContextPayload, + ) -> float | None: + runtime = context.get('runtime') or {} + deadline_at = runtime.get('deadline_at') + if deadline_at is None: + return None + try: + return float(deadline_at) - time.time() + except (TypeError, ValueError): + return None + + def _is_deadline_exhausted(self, context: AgentRunContextPayload) -> bool: + remaining = self._remaining_deadline_seconds(context) + return remaining is not None and remaining <= 0 + + async def _close_generator( + self, + gen: typing.AsyncGenerator[dict[str, typing.Any], None], + descriptor: AgentRunnerDescriptor, + ) -> None: + try: + await gen.aclose() + except Exception as e: + self.ap.logger.warning(f'Failed to close timed-out runner {descriptor.id}: {e}') + def resolve_runner_id_for_telemetry(self, query: pipeline_query.Query) -> str | None: """Resolve runner ID for telemetry/logging without full execution. diff --git a/src/langbot/pkg/agent/runner/resource_builder.py b/src/langbot/pkg/agent/runner/resource_builder.py index 82d5e24c..8c925b4e 100644 --- a/src/langbot/pkg/agent/runner/resource_builder.py +++ b/src/langbot/pkg/agent/runner/resource_builder.py @@ -13,6 +13,7 @@ from .context_builder import ( KnowledgeBaseResource, StorageResource, ) +from . import config_schema class AgentResourceBuilder: @@ -73,7 +74,7 @@ class AgentResourceBuilder: models, tools, knowledge_bases = await asyncio.gather( self._build_models(manifest_perms, runner_config, descriptor, query), self._build_tools(manifest_perms, bound_plugins, bound_mcp_servers, query), - self._build_knowledge_bases(manifest_perms, runner_config, query), + self._build_knowledge_bases(manifest_perms, runner_config, descriptor, query), ) storage = self._build_storage(manifest_perms) @@ -132,34 +133,11 @@ class AgentResourceBuilder: runner_config: dict[str, typing.Any], ) -> None: """Authorize model-like values selected through DynamicForm fields.""" - for item in descriptor.config_schema or []: - if not isinstance(item, dict): - continue - - field_name = item.get('name') - field_type = item.get('type') - if not field_name or field_name not in runner_config: - continue - - value = runner_config.get(field_name) - if field_type == 'model-fallback-selector': - if isinstance(value, str): - await self._append_llm_model_resource(models, seen_model_ids, value) - elif isinstance(value, dict): - primary = value.get('primary') - if isinstance(primary, str): - await self._append_llm_model_resource(models, seen_model_ids, primary) - fallbacks = value.get('fallbacks', []) - if isinstance(fallbacks, list): - for fallback_uuid in fallbacks: - if isinstance(fallback_uuid, str): - await self._append_llm_model_resource(models, seen_model_ids, fallback_uuid) - elif field_type == 'llm-model-selector': - if isinstance(value, str): - await self._append_llm_model_resource(models, seen_model_ids, value) - elif field_type == 'rerank-model-selector': - if isinstance(value, str): - await self._append_rerank_model_resource(models, seen_model_ids, value) + for model_type, model_uuid in config_schema.iter_config_model_refs(descriptor, runner_config): + if model_type == 'llm': + await self._append_llm_model_resource(models, seen_model_ids, model_uuid) + elif model_type == 'rerank': + await self._append_rerank_model_resource(models, seen_model_ids, model_uuid) async def _append_llm_model_resource( self, @@ -236,6 +214,7 @@ class AgentResourceBuilder: self, manifest_perms: dict[str, list[str]], runner_config: dict[str, typing.Any], + descriptor: AgentRunnerDescriptor, query: typing.Any, ) -> list[KnowledgeBaseResource]: """Build knowledge bases list with plugin SDK field names.""" @@ -246,13 +225,8 @@ class AgentResourceBuilder: if 'list' not in kb_perms and 'retrieve' not in kb_perms: return kb_resources - # Get knowledge base UUIDs from config - kb_uuids = runner_config.get('knowledge-bases', []) - if not kb_uuids: - # Old single KB config - old_kb_uuid = runner_config.get('knowledge-base', '') - if old_kb_uuid and old_kb_uuid != '__none__': - kb_uuids = [old_kb_uuid] + # Get knowledge base UUIDs from schema-defined config fields. + kb_uuids = config_schema.extract_knowledge_base_uuids(descriptor, runner_config) # Also check query variables (may be modified by plugin PromptPreProcessing) kb_uuids_from_vars = query.variables.get('_knowledge_base_uuids', []) diff --git a/src/langbot/pkg/api/http/service/model.py b/src/langbot/pkg/api/http/service/model.py index 320104d8..3758cbbc 100644 --- a/src/langbot/pkg/api/http/service/model.py +++ b/src/langbot/pkg/api/http/service/model.py @@ -9,6 +9,8 @@ from ....core import app from ....entity.persistence import model as persistence_model from ....entity.persistence import pipeline as persistence_pipeline from ....provider.modelmgr import requester as model_requester +from ....agent.runner.config_migration import ConfigMigration +from ....agent.runner import config_schema def _parse_provider_api_keys(provider_dict: dict) -> dict: @@ -40,6 +42,40 @@ class LLMModelsService: def __init__(self, ap: app.Application) -> None: self.ap = ap + async def _get_runner_descriptor(self, runner_id: str): + registry = getattr(self.ap, 'agent_runner_registry', None) + if registry is None: + return None + try: + return await registry.get(runner_id, bound_plugins=None) + except Exception as e: + logger = getattr(self.ap, 'logger', None) + if logger: + logger.warning(f'Failed to load AgentRunner descriptor while setting default model: {e}') + return None + + async def _auto_set_default_pipeline_llm_model(self, pipeline: persistence_pipeline.LegacyPipeline, model_uuid: str): + pipeline_config = pipeline.config + if not isinstance(pipeline_config, dict): + return + + runner_id = ConfigMigration.resolve_runner_id(pipeline_config) + if not runner_id: + return + + descriptor = await self._get_runner_descriptor(runner_id) + if descriptor is None: + return + + ai_config = pipeline_config.setdefault('ai', {}) + runner_configs = ai_config.setdefault('runner_config', {}) + runner_config = runner_configs.setdefault(runner_id, {}) + + if not config_schema.set_empty_llm_model_selection(descriptor, runner_config, model_uuid): + return + + await self.ap.pipeline_service.update_pipeline(pipeline.uuid, {'config': pipeline_config}) + async def get_llm_models(self, include_secret: bool = True) -> list[dict]: """Get all LLM models with provider info""" result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel)) @@ -109,7 +145,6 @@ class LLMModelsService: self.ap.model_mgr.llm_models.append(runtime_llm_model) if auto_set_to_default_pipeline: - # set the default pipeline model to this model result = await self.ap.persistence_mgr.execute_async( sqlalchemy.select(persistence_pipeline.LegacyPipeline).where( persistence_pipeline.LegacyPipeline.is_default == True @@ -117,15 +152,7 @@ class LLMModelsService: ) pipeline = result.first() if pipeline is not None: - model_config = pipeline.config.get('ai', {}).get('local-agent', {}).get('model', {}) - if not model_config.get('primary', ''): - pipeline_config = pipeline.config - pipeline_config['ai']['local-agent']['model'] = { - 'primary': model_data['uuid'], - 'fallbacks': [], - } - pipeline_data = {'config': pipeline_config} - await self.ap.pipeline_service.update_pipeline(pipeline.uuid, pipeline_data) + await self._auto_set_default_pipeline_llm_model(pipeline, model_data['uuid']) return model_data['uuid'] diff --git a/src/langbot/pkg/pipeline/msgtrun/truncators/round.py b/src/langbot/pkg/pipeline/msgtrun/truncators/round.py index 634f3106..f339f341 100644 --- a/src/langbot/pkg/pipeline/msgtrun/truncators/round.py +++ b/src/langbot/pkg/pipeline/msgtrun/truncators/round.py @@ -11,7 +11,8 @@ class RoundTruncator(truncator.Truncator): async def truncate(self, query: pipeline_query.Query) -> pipeline_query.Query: """截断""" - # Get max-round from runner config (new or old format) + # max-round remains a pipeline-side trimming knob until token-budget + # based compaction replaces this stage. runner_id = ConfigMigration.resolve_runner_id(query.pipeline_config) runner_config = ConfigMigration.resolve_runner_config(query.pipeline_config, runner_id) if runner_id else {} max_round = runner_config.get('max-round', 10) diff --git a/src/langbot/pkg/pipeline/preproc/preproc.py b/src/langbot/pkg/pipeline/preproc/preproc.py index 9bd6730e..749909ba 100644 --- a/src/langbot/pkg/pipeline/preproc/preproc.py +++ b/src/langbot/pkg/pipeline/preproc/preproc.py @@ -1,6 +1,7 @@ from __future__ import annotations import datetime +import typing from .. import stage, entities from langbot_plugin.api.entities.builtin.provider import message as provider_message @@ -9,12 +10,14 @@ import langbot_plugin.api.entities.builtin.platform.message as platform_message import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query import langbot_plugin.api.entities.builtin.platform.events as platform_events +from ...agent.runner.descriptor import AgentRunnerDescriptor from ...agent.runner.config_migration import ConfigMigration +from ...agent.runner import config_schema -# Official local-agent runner ID -LOCAL_AGENT_RUNNER_ID = 'plugin:langbot/local-agent/default' - +DEFAULT_PROMPT_CONFIG = [ + {'role': 'system', 'content': 'You are a helpful assistant.'}, +] @stage.stage_class('PreProcessor') class PreProcessor(stage.PipelineStage): @@ -31,6 +34,76 @@ class PreProcessor(stage.PipelineStage): - use_funcs """ + async def _get_runner_descriptor( + self, + runner_id: str | None, + bound_plugins: list[str] | None, + ) -> AgentRunnerDescriptor | None: + if not runner_id: + return None + + registry = getattr(self.ap, 'agent_runner_registry', None) + if registry is None: + return None + + try: + return await registry.get(runner_id, bound_plugins) + except Exception as e: + self.ap.logger.debug(f'Unable to load AgentRunner descriptor for {runner_id}: {e}') + return None + + async def _resolve_llm_model( + self, + primary_uuid: str, + ) -> typing.Any | None: + if primary_uuid in config_schema.NONE_SENTINELS: + return None + try: + return await self.ap.model_mgr.get_model_by_uuid(primary_uuid) + except ValueError: + self.ap.logger.warning(f'LLM model {primary_uuid} not found or not configured') + return None + + async def _resolve_fallback_models(self, fallback_uuids: list[str]) -> list[str]: + valid_fallbacks = [] + for fallback_uuid in fallback_uuids: + if fallback_uuid in config_schema.NONE_SENTINELS: + continue + try: + await self.ap.model_mgr.get_model_by_uuid(fallback_uuid) + valid_fallbacks.append(fallback_uuid) + except ValueError: + self.ap.logger.warning(f'Fallback model {fallback_uuid} not found, skipping') + return valid_fallbacks + + def _runner_accepts_multimodal_input(self, descriptor: AgentRunnerDescriptor | None) -> bool: + if descriptor is None: + return True + return descriptor.capabilities.get('multimodal_input', False) + + def _model_supports_vision(self, llm_model: typing.Any | None) -> bool: + if not llm_model: + return False + abilities = getattr(getattr(llm_model, 'model_entity', None), 'abilities', []) + return 'vision' in abilities + + def _should_keep_image_inputs( + self, + descriptor: AgentRunnerDescriptor | None, + uses_host_models: bool, + llm_model: typing.Any | None, + ) -> bool: + if not self._runner_accepts_multimodal_input(descriptor): + return False + if uses_host_models: + return self._model_supports_vision(llm_model) + return True + + def _strip_images_from_history(self, query: pipeline_query.Query) -> None: + for msg in query.messages: + if isinstance(msg.content, list): + msg.content = [elem for elem in msg.content if elem.type != 'image_url'] + async def process( self, query: pipeline_query.Query, @@ -40,56 +113,25 @@ class PreProcessor(stage.PipelineStage): # Resolve runner ID using ConfigMigration (supports both new and old formats) runner_id = ConfigMigration.resolve_runner_id(query.pipeline_config) - # Get runner config (from new ai.runner_config or old ai.) + # Get runner config from ai.runner_config[runner_id]. runner_config = ConfigMigration.resolve_runner_config(query.pipeline_config, runner_id) if runner_id else {} + query.variables = query.variables or {} + bound_plugins = query.variables.get('_pipeline_bound_plugins', None) + bound_mcp_servers = query.variables.get('_pipeline_bound_mcp_servers', None) + descriptor = await self._get_runner_descriptor(runner_id, bound_plugins) session = await self.ap.sess_mgr.get_session(query) - # Determine if this is a local-agent runner (built-in LLM capabilities) - # Check by runner_id OR by legacy runner field for backward compatibility - is_local_agent = runner_id == LOCAL_AGENT_RUNNER_ID or ( - runner_id is None and - query.pipeline_config.get('ai', {}).get('runner', {}).get('runner') == 'local-agent' - ) - - # When not local-agent, llm_model is None + uses_host_models = config_schema.uses_host_models(descriptor) llm_model = None - if is_local_agent: - # Read model config — new format is { primary: str, fallbacks: [str] }, - # but handle legacy plain string for backward compatibility - model_config = runner_config.get('model', {}) - if isinstance(model_config, str): - # Legacy format: plain UUID string - primary_uuid = model_config - fallback_uuids = [] - else: - primary_uuid = model_config.get('primary', '') - fallback_uuids = model_config.get('fallbacks', []) + if uses_host_models: + primary_uuid, fallback_uuids = config_schema.extract_model_selection(descriptor, runner_config) + llm_model = await self._resolve_llm_model(primary_uuid) + valid_fallbacks = await self._resolve_fallback_models(fallback_uuids) + if valid_fallbacks: + query.variables['_fallback_model_uuids'] = valid_fallbacks - if primary_uuid: - try: - llm_model = await self.ap.model_mgr.get_model_by_uuid(primary_uuid) - except ValueError: - self.ap.logger.warning(f'LLM model {primary_uuid} not found or not configured') - - # Resolve fallback model UUIDs - if fallback_uuids: - valid_fallbacks = [] - for fb_uuid in fallback_uuids: - try: - await self.ap.model_mgr.get_model_by_uuid(fb_uuid) - valid_fallbacks.append(fb_uuid) - except ValueError: - self.ap.logger.warning(f'Fallback model {fb_uuid} not found, skipping') - if valid_fallbacks: - query.variables['_fallback_model_uuids'] = valid_fallbacks - - # Get prompt config - for local-agent, use runner_config; for others, use default prompt - prompt_config = runner_config.get('prompt', [ - {'role': 'system', 'content': 'You are a helpful assistant.'} - ]) if is_local_agent else [ - {'role': 'system', 'content': 'You are a helpful assistant.'} - ] + prompt_config = config_schema.extract_prompt_config(descriptor, runner_config, DEFAULT_PROMPT_CONFIG) conversation = await self.ap.sess_mgr.get_conversation( query, @@ -125,15 +167,14 @@ class PreProcessor(stage.PipelineStage): query.prompt = conversation.prompt.copy() query.messages = conversation.messages.copy() - if is_local_agent: + if uses_host_models: query.use_funcs = [] if llm_model: query.use_llm_model_uuid = llm_model.model_entity.uuid - if llm_model.model_entity.abilities.__contains__('func_call'): - # Get bound plugins and MCP servers for filtering tools - bound_plugins = query.variables.get('_pipeline_bound_plugins', None) - bound_mcp_servers = query.variables.get('_pipeline_bound_mcp_servers', None) + if config_schema.uses_host_tools(descriptor) and llm_model.model_entity.abilities.__contains__( + 'func_call' + ): query.use_funcs = await self.ap.tool_mgr.get_all_tools(bound_plugins, bound_mcp_servers) self.ap.logger.debug(f'Bound plugins: {bound_plugins}') @@ -142,9 +183,11 @@ class PreProcessor(stage.PipelineStage): # If primary model doesn't support func_call but fallback models exist, # load tools anyway since fallback models may support them - if not query.use_funcs and query.variables.get('_fallback_model_uuids'): - bound_plugins = query.variables.get('_pipeline_bound_plugins', None) - bound_mcp_servers = query.variables.get('_pipeline_bound_mcp_servers', None) + if ( + config_schema.uses_host_tools(descriptor) + and not query.use_funcs + and query.variables.get('_fallback_model_uuids') + ): query.use_funcs = await self.ap.tool_mgr.get_all_tools(bound_plugins, bound_mcp_servers) sender_name = '' @@ -170,18 +213,9 @@ class PreProcessor(stage.PipelineStage): } query.variables.update(variables) - # Check if this model supports vision, if not, remove all images - # TODO this checking should be performed in runner, and in this stage, the image should be reserved - if ( - is_local_agent - and llm_model - and not llm_model.model_entity.abilities.__contains__('vision') - ): - for msg in query.messages: - if isinstance(msg.content, list): - for me in msg.content: - if me.type == 'image_url': - msg.content.remove(me) + keep_image_inputs = self._should_keep_image_inputs(descriptor, uses_host_models, llm_model) + if not keep_image_inputs: + self._strip_images_from_history(query) content_list: list[provider_message.ContentElement] = [] @@ -193,10 +227,7 @@ class PreProcessor(stage.PipelineStage): content_list.append(provider_message.ContentElement.from_text(me.text)) plain_text += me.text elif isinstance(me, platform_message.Image): - # Allow images for non-local-agent runners or if local-agent has vision - if not is_local_agent or ( - llm_model and llm_model.model_entity.abilities.__contains__('vision') - ): + if keep_image_inputs: if me.base64 is not None: content_list.append(provider_message.ContentElement.from_image_base64(me.base64)) elif isinstance(me, platform_message.Voice): @@ -215,9 +246,7 @@ class PreProcessor(stage.PipelineStage): if isinstance(msg, platform_message.Plain): content_list.append(provider_message.ContentElement.from_text(msg.text)) elif isinstance(msg, platform_message.Image): - if not is_local_agent or ( - llm_model and llm_model.model_entity.abilities.__contains__('vision') - ): + if keep_image_inputs: if msg.base64 is not None: content_list.append(provider_message.ContentElement.from_image_base64(msg.base64)) elif isinstance(msg, platform_message.File): @@ -237,15 +266,12 @@ class PreProcessor(stage.PipelineStage): query.user_message = provider_message.Message(role='user', content=content_list) - # Extract knowledge base UUIDs into query variables so plugins can modify them - # during PromptPreProcessing before the runner performs retrieval. - # Only for local-agent runner - kb_uuids = runner_config.get('knowledge-bases', []) if is_local_agent else [] - if not kb_uuids: - old_kb_uuid = runner_config.get('knowledge-base', '') if is_local_agent else '' - if old_kb_uuid and old_kb_uuid != '__none__': - kb_uuids = [old_kb_uuid] - query.variables['_knowledge_base_uuids'] = list(kb_uuids) + # Extract configured KB UUIDs into query variables so PromptPreProcessing + # plugins can still adjust the authorized retrieval set before run_agent. + query.variables['_knowledge_base_uuids'] = config_schema.extract_knowledge_base_uuids( + descriptor, + runner_config, + ) # =========== 触发事件 PromptPreProcessing @@ -263,4 +289,4 @@ class PreProcessor(stage.PipelineStage): query.prompt.messages = event_ctx.event.default_prompt query.messages = event_ctx.event.prompt - return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) \ No newline at end of file + return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) diff --git a/src/langbot/pkg/plugin/handler.py b/src/langbot/pkg/plugin/handler.py index dda03b28..0c39d630 100644 --- a/src/langbot/pkg/plugin/handler.py +++ b/src/langbot/pkg/plugin/handler.py @@ -25,6 +25,8 @@ from ..entity.persistence import bstorage as persistence_bstorage from ..core import app from ..utils import constants from ..agent.runner.session_registry import get_session_registry +from ..agent.runner.config_migration import ConfigMigration +from ..agent.runner import config_schema def _make_rag_error_response(error: Exception, error_type: str, **extra_context) -> handler.ActionResponse: @@ -98,6 +100,46 @@ def _build_tool_detail(tool: Any, requested_tool_name: str | None = None) -> dic } +def _normalize_uuid_list(values: Any) -> list[str]: + """Normalize a user/config supplied UUID list while preserving order.""" + if not isinstance(values, list): + return [] + return list( + dict.fromkeys( + value for value in values if isinstance(value, str) and value not in config_schema.NONE_SENTINELS + ) + ) + + +async def _get_pipeline_knowledge_base_uuids(ap: app.Application, query: Any) -> list[str]: + """Resolve pipeline-scoped KBs from preprocessed variables or runner schema.""" + variables = getattr(query, 'variables', {}) or {} + if '_knowledge_base_uuids' in variables: + return _normalize_uuid_list(variables.get('_knowledge_base_uuids')) + + pipeline_config = getattr(query, 'pipeline_config', None) + if not pipeline_config: + return [] + + runner_id = ConfigMigration.resolve_runner_id(pipeline_config) + if not runner_id: + return [] + + runner_config = ConfigMigration.resolve_runner_config(pipeline_config, runner_id) + registry = getattr(ap, 'agent_runner_registry', None) + if registry is None: + return [] + + bound_plugins = variables.get('_pipeline_bound_plugins') + try: + descriptor = await registry.get(runner_id, bound_plugins) + except Exception as e: + ap.logger.warning(f'Failed to load AgentRunner descriptor for pipeline knowledge-base scope: {e}') + return [] + + return config_schema.extract_knowledge_base_uuids(descriptor, runner_config) + + async def _validate_run_authorization( run_id: str, resource_type: str, @@ -1155,15 +1197,7 @@ class RuntimeConnectionHandler(handler.Handler): query = self.ap.query_pool.cached_queries[query_id] - kb_uuids = [] - if query.pipeline_config: - local_agent_config = query.pipeline_config.get('ai', {}).get('local-agent', {}) - kb_uuids = local_agent_config.get('knowledge-bases', []) - # Backward compatibility - if not kb_uuids: - old_kb_uuid = local_agent_config.get('knowledge-base', '') - if old_kb_uuid and old_kb_uuid != '__none__': - kb_uuids = [old_kb_uuid] + kb_uuids = await _get_pipeline_knowledge_base_uuids(self.ap, query) knowledge_bases = [] for kb_uuid in kb_uuids: @@ -1213,19 +1247,9 @@ class RuntimeConnectionHandler(handler.Handler): if error: return error else: - # Regular plugin call: validate against pipeline's configured knowledge bases - # FIX: First resolve runner_id, then resolve runner_config - allowed_kb_uuids = [] - if query.pipeline_config: - from langbot.pkg.agent.runner.config_migration import ConfigMigration - runner_id = ConfigMigration.resolve_runner_id(query.pipeline_config) - if runner_id: - runner_config = ConfigMigration.resolve_runner_config(query.pipeline_config, runner_id) - allowed_kb_uuids = runner_config.get('knowledge-bases', []) - if not allowed_kb_uuids: - old_kb_uuid = runner_config.get('knowledge-base', '') - if old_kb_uuid and old_kb_uuid != '__none__': - allowed_kb_uuids = [old_kb_uuid] + # Regular plugin call: validate against the runner binding's + # schema-defined KB selectors or the preprocessed query scope. + allowed_kb_uuids = await _get_pipeline_knowledge_base_uuids(self.ap, query) if kb_id not in allowed_kb_uuids: return handler.ActionResponse.error( @@ -1424,6 +1448,7 @@ class RuntimeConnectionHandler(handler.Handler): Yields AgentRunResult dicts. """ + timeout = self._get_runner_action_timeout(context) gen = self.call_action_generator( LangBotToRuntimeAction.RUN_AGENT, { @@ -1432,12 +1457,27 @@ class RuntimeConnectionHandler(handler.Handler): 'runner_name': runner_name, 'context': context, }, - timeout=300, + timeout=timeout, ) async for ret in gen: yield ret + def _get_runner_action_timeout(self, context: dict[str, Any]) -> float: + """Use the run deadline as the transport idle timeout when available.""" + try: + import time + + deadline_at = (context.get('runtime') or {}).get('deadline_at') + if deadline_at is None: + return 300 + remaining = float(deadline_at) - time.time() + if remaining <= 0: + return 0.001 + return max(remaining + 1.0, 0.001) + except (TypeError, ValueError): + return 300 + async def get_plugin_icon(self, plugin_author: str, plugin_name: str) -> dict[str, Any]: """Get plugin icon""" result = await self.call_action( diff --git a/tests/factories/message.py b/tests/factories/message.py index 8871c664..66aec7d5 100644 --- a/tests/factories/message.py +++ b/tests/factories/message.py @@ -18,6 +18,7 @@ import langbot_plugin.api.entities.builtin.provider.session as provider_session # Counter for generating unique IDs _query_counter = 0 +DEFAULT_RUNNER_ID = "plugin:langbot/local-agent/default" def _next_query_id() -> int: @@ -163,10 +164,12 @@ def _base_query( "bot_uuid": "test-bot-uuid", "pipeline_config": { "ai": { - "runner": {"runner": "local-agent"}, - "local-agent": { - "model": {"primary": "test-model-uuid", "fallbacks": []}, - "prompt": "test-prompt", + "runner": {"id": DEFAULT_RUNNER_ID}, + "runner_config": { + DEFAULT_RUNNER_ID: { + "model": {"primary": "test-model-uuid", "fallbacks": []}, + "prompt": [{"role": "system", "content": "test-prompt"}], + }, }, }, "output": {"misc": {"at-sender": False, "quote-origin": False}}, @@ -469,4 +472,4 @@ def at_all_query( sender_id=sender_id, adapter=adapter, **overrides, - ) \ No newline at end of file + ) diff --git a/tests/unit_tests/agent/test_config_migration.py b/tests/unit_tests/agent/test_config_migration.py index 202e0eb3..07be608b 100644 --- a/tests/unit_tests/agent/test_config_migration.py +++ b/tests/unit_tests/agent/test_config_migration.py @@ -132,7 +132,7 @@ class TestResolveRunnerConfig: assert config == {'model': 'uuid-123', 'max_round': 10} def test_resolve_old_format_config(self): - """Resolve runner config from old format.""" + """Runtime config resolver should not read old format.""" pipeline_config = { 'ai': { 'local-agent': { @@ -146,6 +146,23 @@ class TestResolveRunnerConfig: pipeline_config, 'plugin:langbot/local-agent/default', ) + assert config == {} + + def test_resolve_legacy_config_for_migration(self): + """Migration helper should read old format.""" + pipeline_config = { + 'ai': { + 'local-agent': { + 'model': 'uuid-123', + 'max_round': 10, + }, + }, + } + + config = ConfigMigration.resolve_legacy_runner_config( + pipeline_config, + 'plugin:langbot/local-agent/default', + ) assert config == {'model': 'uuid-123', 'max_round': 10} def test_resolve_no_config(self): @@ -228,4 +245,4 @@ class TestGetOldRunnerName: def test_get_old_runner_name_not_mapped(self): """Get old runner name for unmapped runner ID.""" old_name = ConfigMigration.get_old_runner_name('plugin:alice/my-agent/custom') - assert old_name is None \ No newline at end of file + assert old_name is None diff --git a/tests/unit_tests/agent/test_config_migration_full.py b/tests/unit_tests/agent/test_config_migration_full.py index 640111b9..39c4a52e 100644 --- a/tests/unit_tests/agent/test_config_migration_full.py +++ b/tests/unit_tests/agent/test_config_migration_full.py @@ -229,8 +229,8 @@ class TestResolveRunnerIdBackwardCompat: assert runner_id == 'plugin:new-runner/default' -class TestResolveRunnerConfigBackwardCompat: - """Tests for backward compatibility in resolve_runner_config.""" +class TestResolveRunnerConfig: + """Tests for runtime runner config resolution.""" def test_resolve_new_format_config(self): """resolve_runner_config should read from runner_config.""" @@ -245,13 +245,23 @@ class TestResolveRunnerConfigBackwardCompat: assert runner_config['max-round'] == 20 def test_resolve_old_format_config(self): - """resolve_runner_config should read from old ai.local-agent.""" + """resolve_runner_config should not read old ai.local-agent at runtime.""" config = { 'ai': { 'local-agent': {'max-round': 15}, }, } runner_config = ConfigMigration.resolve_runner_config(config, 'plugin:langbot/local-agent/default') + assert runner_config == {} + + def test_resolve_legacy_runner_config_for_migration(self): + """resolve_legacy_runner_config should read old ai.local-agent for migration.""" + config = { + 'ai': { + 'local-agent': {'max-round': 15}, + }, + } + runner_config = ConfigMigration.resolve_legacy_runner_config(config, 'plugin:langbot/local-agent/default') assert runner_config['max-round'] == 15 def test_resolve_new_format_priority(self): diff --git a/tests/unit_tests/agent/test_handler_auth.py b/tests/unit_tests/agent/test_handler_auth.py index f4f303c6..7397ac89 100644 --- a/tests/unit_tests/agent/test_handler_auth.py +++ b/tests/unit_tests/agent/test_handler_auth.py @@ -16,8 +16,9 @@ import pytest import types from unittest.mock import AsyncMock, MagicMock +from langbot.pkg.agent.runner.descriptor import AgentRunnerDescriptor from langbot.pkg.agent.runner.session_registry import AgentRunSessionRegistry -from langbot.pkg.plugin.handler import _build_tool_detail +from langbot.pkg.plugin.handler import _build_tool_detail, _get_pipeline_knowledge_base_uuids # Import shared test fixtures from conftest.py from .conftest import make_resources @@ -105,11 +106,53 @@ class MockApplication: self.persistence_mgr.execute_async = AsyncMock(return_value=MagicMock(first=lambda: None)) +class FakeAgentRunnerRegistry: + async def get(self, runner_id, bound_plugins=None): + return AgentRunnerDescriptor( + id=runner_id, + source='plugin', + label={'en_US': 'Test Runner'}, + plugin_author='test', + plugin_name='runner', + runner_name='default', + config_schema=[ + {'name': 'knowledge-bases', 'type': 'knowledge-base-multi-selector', 'default': []}, + ], + capabilities={'knowledge_retrieval': True}, + permissions={'knowledge_bases': ['list', 'retrieve']}, + ) + + class MockConnection: """Mock connection for testing.""" pass +class TestPipelineKnowledgeBaseScope: + """Tests for schema-driven pipeline KB scope resolution.""" + + @pytest.mark.asyncio + async def test_uses_preprocessed_query_scope(self): + app = MockApplication() + query = MockQuery() + query.variables = {'_knowledge_base_uuids': ['kb_var', '__none__', 'kb_var']} + + kb_uuids = await _get_pipeline_knowledge_base_uuids(app, query) + + assert kb_uuids == ['kb_var'] + + @pytest.mark.asyncio + async def test_uses_runner_schema_when_query_scope_not_preprocessed(self): + app = MockApplication() + app.agent_runner_registry = FakeAgentRunnerRegistry() + query = MockQuery() + query.variables = {} + + kb_uuids = await _get_pipeline_knowledge_base_uuids(app, query) + + assert kb_uuids == ['kb_001', 'kb_002'] + + class MockDisconnectCallback: """Mock disconnect callback for testing.""" async def __call__(self): diff --git a/tests/unit_tests/agent/test_orchestrator_integration.py b/tests/unit_tests/agent/test_orchestrator_integration.py index 4e9bea95..f74fa098 100644 --- a/tests/unit_tests/agent/test_orchestrator_integration.py +++ b/tests/unit_tests/agent/test_orchestrator_integration.py @@ -1,6 +1,7 @@ """Integration-style tests for AgentRunOrchestrator with a fake plugin runner.""" from __future__ import annotations +import asyncio import datetime import types from unittest.mock import AsyncMock @@ -61,9 +62,10 @@ class FakeKnowledgeBase: class FakePluginConnector: is_enable_plugin = True - def __init__(self, results=None, error: Exception | None = None): + def __init__(self, results=None, error: Exception | None = None, delay: float = 0): self.results = results or [] self.error = error + self.delay = delay self.calls: list[dict] = [] self.contexts: list[dict] = [] self.sessions_during_run: list[dict | None] = [] @@ -83,6 +85,8 @@ class FakePluginConnector: raise self.error for result in self.results: + if self.delay: + await asyncio.sleep(self.delay) yield result @@ -125,7 +129,11 @@ def make_descriptor() -> AgentRunnerDescriptor: plugin_name="local-agent", runner_name="default", protocol_version="1", - capabilities={"streaming": True, "tool_calling": True}, + capabilities={"streaming": True, "tool_calling": True, "knowledge_retrieval": True}, + config_schema=[ + {"name": "model", "type": "model-fallback-selector"}, + {"name": "knowledge-bases", "type": "knowledge-base-multi-selector", "default": []}, + ], permissions={ "models": ["invoke", "stream"], "tools": ["list", "detail", "call"], @@ -367,3 +375,27 @@ async def test_orchestrator_unregisters_session_after_runner_failure(): context = plugin_connector.contexts[0] assert plugin_connector.sessions_during_run[0] is not None assert await get_session_registry().get(context["run_id"]) is None + + +@pytest.mark.asyncio +async def test_orchestrator_enforces_total_runner_deadline(): + descriptor = make_descriptor() + plugin_connector = FakePluginConnector( + results=[ + { + "type": "message.completed", + "data": {"message": {"role": "assistant", "content": "too late"}}, + } + ], + delay=0.05, + ) + orchestrator = AgentRunOrchestrator(FakeApplication(plugin_connector), FakeRegistry(descriptor)) + query = make_query() + query.pipeline_config["ai"]["runner_config"][RUNNER_ID]["timeout"] = 0.01 + + with pytest.raises(RunnerExecutionError) as exc_info: + [message async for message in orchestrator.run_from_query(query)] + + assert exc_info.value.retryable is True + assert "runner.timeout" in str(exc_info.value) + assert await get_session_registry().get(plugin_connector.contexts[0]["run_id"]) is None diff --git a/tests/unit_tests/api/service/test_model_service.py b/tests/unit_tests/api/service/test_model_service.py index 6e6d2598..fb8670f7 100644 --- a/tests/unit_tests/api/service/test_model_service.py +++ b/tests/unit_tests/api/service/test_model_service.py @@ -13,10 +13,12 @@ Source: src/langbot/pkg/api/http/service/model.py from __future__ import annotations -import pytest -from unittest.mock import AsyncMock, Mock from types import SimpleNamespace +from unittest.mock import AsyncMock, Mock +import pytest + +from langbot.pkg.agent.runner.descriptor import AgentRunnerDescriptor from langbot.pkg.api.http.service.model import ( LLMModelsService, EmbeddingModelsService, @@ -28,6 +30,7 @@ from langbot.pkg.entity.persistence.model import LLMModel, EmbeddingModel, Reran pytestmark = pytest.mark.asyncio +RUNNER_ID = 'plugin:test/runner/default' def _create_mock_llm_model( @@ -98,6 +101,22 @@ def _create_mock_result(items: list = None, first_item=None): return result +class FakeAgentRunnerRegistry: + async def get(self, runner_id, bound_plugins=None): + return AgentRunnerDescriptor( + id=runner_id, + source='plugin', + label={'en_US': 'Test Runner'}, + plugin_author='test', + plugin_name='runner', + runner_name='default', + config_schema=[ + {'name': 'model', 'type': 'model-fallback-selector', 'default': {'primary': '', 'fallbacks': []}}, + ], + permissions={'models': ['invoke']}, + ) + + class TestParseProviderApiKeys: """Tests for _parse_provider_api_keys helper function.""" @@ -402,6 +421,51 @@ class TestLLMModelsServiceCreateLLMModel: # Verify assert model_uuid == 'preserved-uuid' + async def test_create_llm_model_auto_sets_schema_defined_default_pipeline_model(self): + """Auto-default model selection should use runner schema, not legacy field names.""" + ap = SimpleNamespace() + ap.logger = Mock() + ap.persistence_mgr = SimpleNamespace() + ap.model_mgr = SimpleNamespace() + ap.model_mgr.provider_dict = {'provider-uuid': Mock()} + ap.model_mgr.llm_models = [] + ap.model_mgr.load_llm_model_with_provider = AsyncMock(return_value=Mock()) + ap.pipeline_service = SimpleNamespace(update_pipeline=AsyncMock()) + ap.agent_runner_registry = FakeAgentRunnerRegistry() + + pipeline = SimpleNamespace( + uuid='pipeline-uuid', + config={ + 'ai': { + 'runner': {'id': RUNNER_ID}, + 'runner_config': { + RUNNER_ID: { + 'model': {'primary': '', 'fallbacks': []}, + }, + }, + }, + }, + ) + ap.persistence_mgr.execute_async = AsyncMock(return_value=_create_mock_result(first_item=pipeline)) + + service = LLMModelsService(ap) + + model_uuid = await service.create_llm_model({ + 'uuid': 'new-model-uuid', + 'name': 'New LLM', + 'provider_uuid': 'provider-uuid', + 'abilities': [], + 'extra_args': {}, + }, preserve_uuid=True) + + assert model_uuid == 'new-model-uuid' + ap.pipeline_service.update_pipeline.assert_awaited_once() + updated_config = ap.pipeline_service.update_pipeline.await_args.args[1]['config'] + assert updated_config['ai']['runner_config'][RUNNER_ID]['model'] == { + 'primary': 'new-model-uuid', + 'fallbacks': [], + } + async def test_create_llm_model_provider_not_found_raises_error(self): """Raises Exception when provider not found in runtime.""" # Setup @@ -961,4 +1025,4 @@ class TestRerankModelsServiceGetRerankModelsByProvider: result = await service.get_rerank_models_by_provider('provider-uuid') # Verify - assert len(result) == 2 \ No newline at end of file + assert len(result) == 2 diff --git a/tests/unit_tests/pipeline/conftest.py b/tests/unit_tests/pipeline/conftest.py index a10e0aba..2f731610 100644 --- a/tests/unit_tests/pipeline/conftest.py +++ b/tests/unit_tests/pipeline/conftest.py @@ -21,6 +21,9 @@ import langbot_plugin.api.entities.builtin.provider.session as provider_session from langbot.pkg.pipeline import entities as pipeline_entities +DEFAULT_RUNNER_ID = 'plugin:langbot/local-agent/default' + + class MockApplication: """Mock Application object providing all basic dependencies needed by stages""" @@ -193,8 +196,13 @@ def sample_query(sample_message_chain, sample_message_event, mock_adapter): bot_uuid='test-bot-uuid', pipeline_config={ 'ai': { - 'runner': {'runner': 'local-agent'}, - 'local-agent': {'model': {'primary': 'test-model-uuid', 'fallbacks': []}, 'prompt': 'test-prompt'}, + 'runner': {'id': DEFAULT_RUNNER_ID}, + 'runner_config': { + DEFAULT_RUNNER_ID: { + 'model': {'primary': 'test-model-uuid', 'fallbacks': []}, + 'prompt': [{'role': 'system', 'content': 'test-prompt'}], + }, + }, }, 'output': {'misc': {'at-sender': False, 'quote-origin': False}}, 'trigger': {'misc': {'combine-quote-message': False}}, @@ -218,8 +226,13 @@ def sample_pipeline_config(): """Provides sample pipeline configuration""" return { 'ai': { - 'runner': {'runner': 'local-agent'}, - 'local-agent': {'model': {'primary': 'test-model-uuid', 'fallbacks': []}, 'prompt': 'test-prompt'}, + 'runner': {'id': DEFAULT_RUNNER_ID}, + 'runner_config': { + DEFAULT_RUNNER_ID: { + 'model': {'primary': 'test-model-uuid', 'fallbacks': []}, + 'prompt': [{'role': 'system', 'content': 'test-prompt'}], + }, + }, }, 'output': {'misc': {'at-sender': False, 'quote-origin': False}}, 'trigger': {'misc': {'combine-quote-message': False}}, diff --git a/tests/unit_tests/pipeline/test_chat_handler.py b/tests/unit_tests/pipeline/test_chat_handler.py index 097ef2b4..995e3fe5 100644 --- a/tests/unit_tests/pipeline/test_chat_handler.py +++ b/tests/unit_tests/pipeline/test_chat_handler.py @@ -13,6 +13,24 @@ from unittest.mock import AsyncMock, Mock from tests.factories import FakeApp +DEFAULT_RUNNER_ID = 'plugin:langbot/local-agent/default' + + +def runner_pipeline_config(output_misc: dict) -> dict: + return { + 'output': {'misc': output_misc}, + 'ai': { + 'runner': {'id': DEFAULT_RUNNER_ID}, + 'runner_config': { + DEFAULT_RUNNER_ID: { + 'prompt': [{'role': 'system', 'content': 'default'}], + 'model': {'primary': 'test', 'fallbacks': []}, + }, + }, + }, + } + + # ============== FIXTURE USING IMPORT ISOLATION UTILITY ============== @pytest.fixture(scope='module') @@ -53,7 +71,22 @@ def mock_circular_import_chain(): @pytest.fixture def fake_app(): """Create FakeApp instance.""" - return FakeApp() + app = FakeApp() + + class ProviderRunnerBackedOrchestrator: + async def run_from_query(self, query): + import sys + + runner_class = sys.modules['langbot.pkg.provider.runner'].preregistered_runners[0] + runner = runner_class(app, {}) + async for result in runner.run(query): + yield result + + def resolve_runner_id_for_telemetry(self, query): + return DEFAULT_RUNNER_ID + + app.agent_run_orchestrator = ProviderRunnerBackedOrchestrator() + return app @pytest.fixture @@ -301,10 +334,9 @@ class TestChatHandlerExceptions: query.adapter.is_stream_output_supported = AsyncMock(return_value=False) query.user_message = Message(role='user', content=[]) - query.pipeline_config = { - 'output': {'misc': {'exception-handling': 'show-hint', 'failure-hint': 'Request failed.'}}, - 'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}}, - } + query.pipeline_config = runner_pipeline_config( + {'exception-handling': 'show-hint', 'failure-hint': 'Request failed.'} + ) class FailingRunner: name = 'local-agent' @@ -344,10 +376,7 @@ class TestChatHandlerExceptions: query.adapter.is_stream_output_supported = AsyncMock(return_value=False) query.user_message = Message(role='user', content=[]) - query.pipeline_config = { - 'output': {'misc': {'exception-handling': 'show-error'}}, - 'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}}, - } + query.pipeline_config = runner_pipeline_config({'exception-handling': 'show-error'}) class ErrorRunner: name = 'local-agent' @@ -384,10 +413,7 @@ class TestChatHandlerExceptions: query.adapter.is_stream_output_supported = AsyncMock(return_value=False) query.user_message = Message(role='user', content=[]) - query.pipeline_config = { - 'output': {'misc': {'exception-handling': 'hide'}}, - 'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}}, - } + query.pipeline_config = runner_pipeline_config({'exception-handling': 'hide'}) class HideErrorRunner: name = 'local-agent' @@ -433,4 +459,4 @@ class TestChatHandlerHelper: chat = get_chat_handler() handler = chat.ChatMessageHandler(fake_app) result = handler.cut_str('first line\nsecond line') - assert '...' in result \ No newline at end of file + assert '...' in result diff --git a/tests/unit_tests/pipeline/test_msgtrun.py b/tests/unit_tests/pipeline/test_msgtrun.py index 9cfdabab..1fe44ba4 100644 --- a/tests/unit_tests/pipeline/test_msgtrun.py +++ b/tests/unit_tests/pipeline/test_msgtrun.py @@ -21,6 +21,9 @@ from tests.factories import ( import langbot_plugin.api.entities.builtin.provider.message as provider_message +RUNNER_ID = 'plugin:langbot/local-agent/default' + + def get_msgtrun_module(): """Lazy import to avoid circular import issues.""" # Import pipelinemgr first to trigger stage registration @@ -47,9 +50,12 @@ def make_truncate_config(max_round: int = 5): """Create a pipeline config with max-round setting.""" return { 'ai': { - 'local-agent': { - 'max-round': max_round, - } + 'runner': {'id': RUNNER_ID}, + 'runner_config': { + RUNNER_ID: { + 'max-round': max_round, + }, + }, } } diff --git a/tests/unit_tests/pipeline/test_preproc.py b/tests/unit_tests/pipeline/test_preproc.py index 1413f5f7..9620a1c1 100644 --- a/tests/unit_tests/pipeline/test_preproc.py +++ b/tests/unit_tests/pipeline/test_preproc.py @@ -24,6 +24,9 @@ from tests.factories import ( ) +RUNNER_ID = 'plugin:langbot/local-agent/default' + + def get_preproc_module(): """Lazy import to avoid circular import issues.""" return import_module('langbot.pkg.pipeline.preproc.preproc') @@ -34,6 +37,76 @@ def get_entities_module(): return import_module('langbot.pkg.pipeline.entities') +class FakeAgentRunnerRegistry: + def __init__(self, descriptor): + self.descriptor = descriptor + + async def get(self, runner_id, bound_plugins=None): + return self.descriptor + + +def make_host_model_runner_descriptor( + *, + multimodal_input: bool = True, + tool_calling: bool = True, + knowledge_retrieval: bool = True, +): + from langbot.pkg.agent.runner.descriptor import AgentRunnerDescriptor + + return AgentRunnerDescriptor( + id=RUNNER_ID, + source='plugin', + label={'en_US': 'Local Agent'}, + plugin_author='langbot', + plugin_name='local-agent', + runner_name='default', + config_schema=[ + {'name': 'model', 'type': 'model-fallback-selector'}, + {'name': 'prompt', 'type': 'prompt-editor', 'default': []}, + {'name': 'knowledge-bases', 'type': 'knowledge-base-multi-selector', 'default': []}, + ], + capabilities={ + 'tool_calling': tool_calling, + 'knowledge_retrieval': knowledge_retrieval, + 'multimodal_input': multimodal_input, + }, + permissions={ + 'models': ['list', 'invoke', 'stream'], + 'tools': ['list', 'detail', 'call'], + 'knowledge_bases': ['list', 'retrieve'], + }, + ) + + +def set_runner_descriptor(app, descriptor=None): + app.agent_runner_registry = FakeAgentRunnerRegistry( + descriptor or make_host_model_runner_descriptor() + ) + + +def make_runner_config( + *, + primary: str = 'test-model-uuid', + fallbacks: list[str] | None = None, + prompt: list[dict] | None = None, + knowledge_bases: list[str] | None = None, +): + return { + 'ai': { + 'runner': {'id': RUNNER_ID}, + 'runner_config': { + RUNNER_ID: { + 'model': {'primary': primary, 'fallbacks': fallbacks or []}, + 'prompt': prompt if prompt is not None else [], + 'knowledge-bases': knowledge_bases or [], + }, + }, + }, + 'output': {'misc': {'at-sender': False}}, + 'trigger': {'misc': {}}, + } + + class TestPreProcessorNormalText: """Tests for normal text message preprocessing.""" @@ -107,6 +180,7 @@ class TestPreProcessorNormalText: mock_model.model_entity = Mock(uuid='test-model', abilities=['func_call']) app.model_mgr.get_model_by_uuid = AsyncMock(return_value=mock_model) app.tool_mgr.get_all_tools = AsyncMock(return_value=[]) + set_runner_descriptor(app) mock_event_ctx = Mock() mock_event_ctx.event = Mock(default_prompt=[], prompt=[]) @@ -195,6 +269,7 @@ class TestPreProcessorImageSegment: stage = preproc.PreProcessor(app) # Image query with base64 query = image_query(text="look at this", url=None) + query.pipeline_config = make_runner_config(primary='vision-model') # Set base64 on the image component import langbot_plugin.api.entities.builtin.platform.message as platform_message chain = platform_message.MessageChain([ @@ -206,8 +281,8 @@ class TestPreProcessorImageSegment: result = await stage.process(query, 'PreProcessor') assert result.result_type == preproc.entities.ResultType.CONTINUE - # User message should have content - assert result.new_query.user_message.content is not None + content_types = [elem.type for elem in result.new_query.user_message.content] + assert 'image_base64' in content_types @pytest.mark.asyncio async def test_image_without_vision_model(self): @@ -232,6 +307,7 @@ class TestPreProcessorImageSegment: mock_model.model_entity = Mock(uuid='text-only-model', abilities=['func_call']) app.model_mgr.get_model_by_uuid = AsyncMock(return_value=mock_model) app.tool_mgr.get_all_tools = AsyncMock(return_value=[]) + set_runner_descriptor(app) mock_event_ctx = Mock() mock_event_ctx.event = Mock(default_prompt=[], prompt=[]) @@ -239,10 +315,13 @@ class TestPreProcessorImageSegment: stage = preproc.PreProcessor(app) query = image_query(text="describe this") + query.pipeline_config = make_runner_config(primary='text-only-model') result = await stage.process(query, 'PreProcessor') assert result.result_type == preproc.entities.ResultType.CONTINUE + content_types = [elem.type for elem in result.new_query.user_message.content] + assert 'image_url' not in content_types class TestPreProcessorModelSelection: @@ -270,6 +349,7 @@ class TestPreProcessorModelSelection: mock_model.model_entity = Mock(uuid='primary-model-uuid', abilities=['func_call']) app.model_mgr.get_model_by_uuid = AsyncMock(return_value=mock_model) app.tool_mgr.get_all_tools = AsyncMock(return_value=[]) + set_runner_descriptor(app) mock_event_ctx = Mock() mock_event_ctx.event = Mock(default_prompt=[], prompt=[]) @@ -279,17 +359,7 @@ class TestPreProcessorModelSelection: query = text_query("hello") # Set pipeline config with primary model - query.pipeline_config = { - 'ai': { - 'runner': {'runner': 'local-agent'}, - 'local-agent': { - 'model': {'primary': 'primary-model-uuid', 'fallbacks': []}, - 'prompt': 'default', - }, - }, - 'output': {'misc': {'at-sender': False}}, - 'trigger': {'misc': {}}, - } + query.pipeline_config = make_runner_config(primary='primary-model-uuid') result = await stage.process(query, 'PreProcessor') @@ -329,6 +399,7 @@ class TestPreProcessorModelSelection: app.model_mgr.get_model_by_uuid = AsyncMock(side_effect=mock_get_model) app.tool_mgr.get_all_tools = AsyncMock(return_value=[]) + set_runner_descriptor(app) mock_event_ctx = Mock() mock_event_ctx.event = Mock(default_prompt=[], prompt=[]) @@ -337,17 +408,7 @@ class TestPreProcessorModelSelection: stage = preproc.PreProcessor(app) query = text_query("hello") - query.pipeline_config = { - 'ai': { - 'runner': {'runner': 'local-agent'}, - 'local-agent': { - 'model': {'primary': 'primary-uuid', 'fallbacks': ['fallback-uuid']}, - 'prompt': 'default', - }, - }, - 'output': {'misc': {'at-sender': False}}, - 'trigger': {'misc': {}}, - } + query.pipeline_config = make_runner_config(primary='primary-uuid', fallbacks=['fallback-uuid']) result = await stage.process(query, 'PreProcessor') diff --git a/tests/unit_tests/provider/test_model_service.py b/tests/unit_tests/provider/test_model_service.py index 03e9a43a..cf579eed 100644 --- a/tests/unit_tests/provider/test_model_service.py +++ b/tests/unit_tests/provider/test_model_service.py @@ -12,6 +12,7 @@ import langbot_plugin.api.entities.builtin.platform.message as platform_message import langbot_plugin.api.entities.builtin.provider.session as provider_session from langbot.pkg.api.http.service.model import _runtime_model_data +from langbot.pkg.agent.runner.descriptor import AgentRunnerDescriptor from langbot.pkg.api.http.service.provider import ModelProviderService from langbot.pkg.entity.persistence import model as persistence_model from langbot.pkg.pipeline.preproc.preproc import PreProcessor @@ -23,6 +24,32 @@ from langbot.pkg.provider.modelmgr.token import TokenManager from langbot.pkg.provider.runners.localagent import LocalAgentRunner +DEFAULT_RUNNER_ID = 'plugin:langbot/local-agent/default' + + +class FakeAgentRunnerRegistry: + async def get(self, runner_id, bound_plugins=None): + return AgentRunnerDescriptor( + id=runner_id, + source='plugin', + label={'en_US': 'Local Agent'}, + plugin_author='langbot', + plugin_name='local-agent', + runner_name='default', + config_schema=[ + {'name': 'model', 'type': 'model-fallback-selector'}, + {'name': 'prompt', 'type': 'prompt-editor', 'default': []}, + {'name': 'knowledge-bases', 'type': 'knowledge-base-multi-selector', 'default': []}, + ], + capabilities={'tool_calling': True, 'knowledge_retrieval': True, 'multimodal_input': True}, + permissions={ + 'models': ['list', 'invoke', 'stream'], + 'tools': ['list', 'detail', 'call'], + 'knowledge_bases': ['list', 'retrieve'], + }, + ) + + def test_runtime_llm_model_data_preserves_uuid_after_update_payload_uuid_removed(): update_payload = { 'name': 'Qwen3.5-27B', @@ -190,6 +217,7 @@ async def test_updated_llm_model_is_immediately_usable_by_local_agent_pipeline() ap = SimpleNamespace() ap.logger = Mock() + ap.agent_runner_registry = FakeAgentRunnerRegistry() ap.persistence_mgr = SimpleNamespace(execute_async=AsyncMock()) ap.tool_mgr = SimpleNamespace(get_all_tools=AsyncMock(return_value=[])) ap.plugin_connector = SimpleNamespace( @@ -252,11 +280,13 @@ async def test_updated_llm_model_is_immediately_usable_by_local_agent_pipeline() ) pipeline_config = { 'ai': { - 'runner': {'runner': 'local-agent'}, - 'local-agent': { - 'model': {'primary': model_uuid, 'fallbacks': []}, - 'prompt': [], - 'knowledge-bases': [], + 'runner': {'id': DEFAULT_RUNNER_ID}, + 'runner_config': { + DEFAULT_RUNNER_ID: { + 'model': {'primary': model_uuid, 'fallbacks': []}, + 'prompt': [], + 'knowledge-bases': [], + }, }, }, 'trigger': {'misc': {'combine-quote-message': False}},