"""Workflow node type registry."""

from __future__ import annotations

import copy
import logging
from typing import Any, Optional

from .metadata import build_node_type
from .node import WorkflowNode, clear_pending_registrations, get_pending_registrations

logger = logging.getLogger(__name__)


class NodeConflictError(Exception):
    """Raised when two workflow node metadata definitions conflict."""


class NodeTypeRegistry:
    """
    Central registry for workflow node types.

    YAML metadata is the UI-facing source of truth. Python node classes are
    registered separately and provide execution logic only.

    The registry keeps three indexes keyed by canonical type string:
    implementation classes (``_nodes``), cached metadata (``_metadata``) and the
    source that registered each metadata entry (``_metadata_sources``). A fourth
    index (``_categories``) maps a category name to the ordered list of types
    shown under it; a type is listed under its metadata category when metadata
    exists, otherwise under the implementation class's ``category`` attribute.
    """

    _instance: Optional['NodeTypeRegistry'] = None

    # Metadata config field types that are aliases of a frontend field type.
    _CONFIG_TYPE_ALIASES = {
        'number': 'float',
        'json': 'text',
        'textarea': 'text',
    }

    def __init__(self):
        self._nodes: dict[str, type[WorkflowNode]] = {}
        self._metadata: dict[str, dict[str, Any]] = {}
        self._metadata_sources: dict[str, str] = {}
        self._categories: dict[str, list[str]] = {
            'trigger': [],
            'process': [],
            'control': [],
            'action': [],
            'integration': [],
            'misc': [],
        }
        self._conflicts: list[dict[str, str]] = []

    @classmethod
    def instance(cls) -> 'NodeTypeRegistry':
        """Get singleton instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def register_metadata(self, metadata: dict[str, Any], source: str = 'core') -> bool:
        """Register YAML metadata for a workflow node type.

        Core metadata cannot be overridden by plugin metadata. Plugin-plugin
        conflicts are allowed with a warning so hot-reload/development flows can
        replace plugin definitions.

        Args:
            metadata: Metadata mapping; its type key is computed by ``build_node_type``.
            source: Identifier of the registering party (``'core'`` or a plugin source).

        Returns:
            ``True`` when the metadata was stored, ``False`` when a non-core
            source tried to override core metadata (the attempt is recorded and
            can be read back through ``get_conflicts``).
        """
        node_type = build_node_type(metadata)
        existing_source = self._metadata_sources.get(node_type)
        # ``is not None`` rather than truthiness so an empty-string source still
        # counts as an existing registration.
        if existing_source is not None:
            if existing_source == 'core' and source != 'core':
                self._conflicts.append(
                    {'type': node_type, 'existing_source': existing_source, 'new_source': source}
                )
                logger.error('Plugin source %s attempted to override core workflow node %s', source, node_type)
                return False
            logger.warning(
                'Workflow node metadata %s from %s overrides previous source %s', node_type, source, existing_source
            )

        # Store a private copy so later mutation of the caller's dict cannot
        # change the registry contents.
        cached_metadata = copy.deepcopy(metadata)
        cached_metadata['_source'] = source
        self._metadata[node_type] = cached_metadata
        self._metadata_sources[node_type] = source
        # Move (not just add): an override may change the category, and the type
        # must not remain listed under the previous one.
        self._set_category(metadata.get('category', 'misc'), node_type)
        return True

    def register(self, node_type: str, node_class: type[WorkflowNode]):
        """Register a Python workflow node implementation class.

        The class is stored under the matching metadata key when metadata for
        ``node_type`` (or its short name) exists; otherwise under ``node_type``
        itself, prefixed with the class category when it has no dot.
        """
        canonical_type = self._canonical_type_for_class(node_type, node_class)
        self._nodes[canonical_type] = node_class

        class_category = getattr(node_class, 'category', 'misc')
        metadata_key = self._resolve_metadata_key(canonical_type)
        if metadata_key is not None:
            # Read the cached entry directly; no copy is needed just to get the category.
            category = self._metadata[metadata_key].get('category', class_category)
        else:
            category = class_category
            logger.warning('Workflow node implementation %s has no YAML metadata', canonical_type)
        self._set_category(category, canonical_type)

    def unregister(self, node_type: str):
        """Unregister a Python workflow node implementation.

        The type stays listed in its category while YAML metadata is still
        registered under the same key, because metadata-only nodes are exposed
        by the listing methods.
        """
        canonical_type = self._resolve_registered_node_key(node_type)
        if canonical_type is None:
            return
        del self._nodes[canonical_type]
        if canonical_type not in self._metadata:
            self._drop_from_categories(canonical_type)

    def unregister_metadata(self, node_type: str):
        """Unregister YAML metadata for a node type, primarily for plugin unload.

        When an implementation class is still registered under the same key the
        type is re-listed under the class's own category instead of disappearing
        from the category index.
        """
        canonical_type = self._resolve_metadata_key(node_type)
        if canonical_type is None:
            return
        del self._metadata[canonical_type]
        self._metadata_sources.pop(canonical_type, None)

        node_class = self._nodes.get(canonical_type)
        if node_class is None:
            self._drop_from_categories(canonical_type)
        else:
            self._set_category(getattr(node_class, 'category', 'misc'), canonical_type)

    def get(self, node_type: str) -> Optional[type[WorkflowNode]]:
        """Get node class by type. Supports both ``category.name`` and short names."""
        canonical_type = self._resolve_registered_node_key(node_type)
        if canonical_type:
            return self._nodes[canonical_type]

        # The class may have been declared via decorator but not flushed into
        # the registry yet; process the pending queue and retry once.
        if get_pending_registrations():
            self.process_pending_registrations()
            canonical_type = self._resolve_registered_node_key(node_type)
            if canonical_type:
                return self._nodes[canonical_type]
        return None

    def get_metadata(self, node_type: str) -> Optional[dict[str, Any]]:
        """Get a copy of the YAML metadata by full type or short node name."""
        canonical_type = self._resolve_metadata_key(node_type)
        if canonical_type:
            return copy.deepcopy(self._metadata[canonical_type])
        return None

    def create_instance(
        self, node_type: str, node_id: str, config: dict[str, Any], ap: Optional['app.Application'] = None
    ) -> Optional[WorkflowNode]:
        """Create a node instance. Supports both ``category.name`` and short names.

        Returns ``None`` (after logging a warning) when no implementation class
        is registered for ``node_type``.
        """
        node_class = self.get(node_type)
        if node_class:
            return node_class(node_id, config, ap=ap)
        logger.warning('No workflow node implementation registered for type: %s', node_type)
        return None

    def get_merged_schema(self, node_type: str) -> Optional[dict[str, Any]]:
        """Get frontend schema from YAML metadata.

        Python node classes no longer carry UI metadata. If a node class is
        registered but has no YAML metadata, a minimal schema is generated from
        the class attributes (category, type_name) so it still appears in the
        editor.

        Returns ``None`` when neither metadata nor an implementation is known.
        """
        metadata = self.get_metadata(node_type)
        node_class = self.get(node_type)

        if metadata:
            schema = self._metadata_to_schema(metadata)
            if node_class:
                # Supplement pipeline config reuse fields from Python class
                for key in ('config_schema_source', 'config_stages'):
                    if not schema.get(key) and getattr(node_class, key, None):
                        schema[key] = getattr(node_class, key)
            return schema

        if node_class:
            # Fallback: node has Python class but no YAML metadata
            short_name = getattr(node_class, 'type_name', '') or node_type.split('.')[-1]
            category = getattr(node_class, 'category', 'misc')
            return {
                'type': f'{category}.{short_name}',
                'name': short_name,
                'label': self._normalize_i18n(None, self._prettify_name(short_name)),
                'description': self._normalize_i18n(None, ''),
                'category': category,
                'icon': '',
                'color': '',
                'inputs': [],
                'outputs': [],
                'config_schema': [],
                'config_schema_source': getattr(node_class, 'config_schema_source', None),
                'config_stages': getattr(node_class, 'config_stages', []),
                'source': 'python-only',
            }
        return None

    def list_all(self) -> list[dict[str, Any]]:
        """Get all registered node type schemas, including metadata-only nodes."""
        node_types = self._ordered_node_types(set(self._metadata) | set(self._nodes))
        return [schema for node_type in node_types if (schema := self.get_merged_schema(node_type)) is not None]

    def list_by_category(self, category: str) -> list[dict[str, Any]]:
        """Get node type schemas by category; unknown categories yield an empty list."""
        if category not in self._categories:
            return []
        return [
            schema
            for node_type in self._categories[category]
            if (schema := self.get_merged_schema(node_type)) is not None
        ]

    def get_categories(self) -> dict[str, list[dict[str, Any]]]:
        """Get all nodes organized by category."""
        return {category: self.list_by_category(category) for category in self._categories}

    def has_type(self, node_type: str) -> bool:
        """Check whether a node has metadata or an implementation registered."""
        # Resolve the key directly instead of deep-copying the metadata just to
        # test for presence.
        return self._resolve_metadata_key(node_type) is not None or self.get(node_type) is not None

    def process_pending_registrations(self):
        """Process all pending node registrations from decorators."""
        for node_type, node_class in get_pending_registrations():
            self.register(node_type, node_class)
        clear_pending_registrations()

    def count(self) -> int:
        """Get total number of node types exposed by metadata or implementation."""
        return len(self._metadata.keys() | self._nodes.keys())

    def metadata_count(self) -> int:
        """Get number of registered YAML metadata definitions."""
        return len(self._metadata)

    def get_conflicts(self) -> list[dict[str, str]]:
        """Return a copy of the recorded metadata registration conflicts."""
        return copy.deepcopy(self._conflicts)

    def clear(self):
        """Clear all registrations (for testing). Category names are kept, emptied."""
        self._nodes.clear()
        self._metadata.clear()
        self._metadata_sources.clear()
        self._conflicts.clear()
        for category in self._categories:
            self._categories[category] = []

    def _canonical_type_for_class(self, node_type: str, node_class: type[WorkflowNode]) -> str:
        """Pick the registry key for an implementation class.

        Prefers the key of matching metadata; otherwise uses ``node_type`` as
        given when it contains a dot, else ``<class category>.<short name>``.
        """
        short_name = node_type.split('.')[-1]
        metadata_key = self._resolve_metadata_key(node_type) or self._resolve_metadata_key(short_name)
        if metadata_key:
            return metadata_key
        category = getattr(node_class, 'category', 'misc')
        return node_type if '.' in node_type else f'{category}.{short_name}'

    def _resolve_registered_node_key(self, node_type: str) -> Optional[str]:
        """Map a full or short type name to the key of a registered class, if any."""
        if node_type in self._nodes:
            return node_type
        short_name = node_type.split('.')[-1]
        for registered_type, node_class in self._nodes.items():
            if registered_type.split('.')[-1] == short_name or getattr(node_class, 'type_name', None) == short_name:
                return registered_type
        return None

    def _resolve_metadata_key(self, node_type: str) -> Optional[str]:
        """Map a full or short type name to the key of registered metadata, if any."""
        if node_type in self._metadata:
            return node_type
        short_name = node_type.split('.')[-1]
        for registered_type, metadata in self._metadata.items():
            if registered_type.split('.')[-1] == short_name or metadata.get('name') == short_name:
                return registered_type
        return None

    def _ordered_node_types(self, node_types: set[str]) -> list[str]:
        """Order types by category listing first, then any remaining ones alphabetically."""
        ordered: list[str] = []
        seen: set[str] = set()  # O(1) duplicate check instead of scanning ``ordered``
        for members in self._categories.values():
            for node_type in members:
                if node_type in node_types and node_type not in seen:
                    seen.add(node_type)
                    ordered.append(node_type)
        ordered.extend(sorted(set(node_types) - seen))
        return ordered

    def _add_to_category(self, category: str, node_type: str) -> None:
        """List ``node_type`` under ``category`` (created on demand) unless already listed."""
        members = self._categories.setdefault(category, [])
        if node_type not in members:
            members.append(node_type)

    def _remove_from_category(self, category: str, node_type: str) -> None:
        """Remove ``node_type`` from ``category`` when both exist; otherwise do nothing."""
        if category in self._categories and node_type in self._categories[category]:
            self._categories[category].remove(node_type)

    def _drop_from_categories(self, node_type: str, keep: Optional[str] = None) -> None:
        """Remove ``node_type`` from every category except ``keep``."""
        for name, members in self._categories.items():
            if name != keep and node_type in members:
                members.remove(node_type)

    def _set_category(self, category: str, node_type: str) -> None:
        """Make ``category`` the only category listing ``node_type``.

        The position inside ``category`` is preserved when the type is already
        listed there.
        """
        self._drop_from_categories(node_type, keep=category)
        self._add_to_category(category, node_type)

    def _metadata_to_schema(self, metadata: dict[str, Any]) -> dict[str, Any]:
        """Convert one metadata mapping into the frontend schema dictionary."""
        node_type = build_node_type(metadata)
        node_name = metadata.get('name', node_type.split('.')[-1])
        return {
            'type': node_type,
            'name': node_name,
            'label': self._normalize_i18n(metadata.get('label'), self._prettify_name(node_name)),
            'description': self._normalize_i18n(metadata.get('description'), ''),
            'category': metadata.get('category', 'misc'),
            'icon': metadata.get('icon', ''),
            'color': metadata.get('color', ''),
            'inputs': [self._normalize_port_item(item) for item in metadata.get('inputs', [])],
            'outputs': [self._normalize_port_item(item) for item in metadata.get('outputs', [])],
            'config_schema': [self._normalize_config_item(item) for item in metadata.get('config', [])],
            'config_schema_source': metadata.get('config_schema_source'),
            'config_stages': metadata.get('config_stages', []),
            'source': metadata.get('_source', 'core'),
        }

    def _merge_missing_schema_fields(self, yaml_schema: dict[str, Any], python_schema: dict[str, Any]) -> dict[str, Any]:
        """Return a copy of ``yaml_schema`` with empty config-reuse fields filled from ``python_schema``."""
        result = copy.deepcopy(yaml_schema)
        for key in ('config_schema_source', 'config_stages'):
            if not result.get(key) and python_schema.get(key):
                result[key] = python_schema[key]
        return result

    def _normalize_port_item(self, port: dict[str, Any]) -> dict[str, Any]:
        """Return a copy of a port definition with i18n label/description and defaults applied."""
        item = copy.deepcopy(port)
        name = item.get('name', '')
        item['label'] = self._normalize_i18n(item.get('label'), self._prettify_name(name))
        item['description'] = self._normalize_i18n(item.get('description'), '')
        item.setdefault('type', 'any')
        item.setdefault('required', True)
        return item

    def _normalize_config_item(self, config: dict[str, Any]) -> dict[str, Any]:
        """Return a copy of a config field definition normalized for the frontend."""
        item = copy.deepcopy(config)
        name = item.get('name', '')
        frontend_type = self._normalize_config_type(item.get('type', 'string'))
        item['id'] = item.get('id') or name
        item['type'] = frontend_type
        item['label'] = self._normalize_i18n(item.get('label'), self._prettify_name(name))
        item['description'] = self._normalize_i18n(item.get('description'), '')
        item['required'] = bool(item.get('required', False))
        item['default'] = item.get('default', self._default_value_for_type(frontend_type))
        if 'options' in item:
            item['options'] = self._normalize_options(item.get('options'), name)
        return item

    def _normalize_options(self, options: Any, field_name: str) -> list[dict[str, Any]]:
        """Normalize select options into dicts with ``name`` and i18n ``label``.

        Anything that is not a list yields an empty list. ``field_name`` is
        accepted for interface compatibility and is not used.
        """
        if not isinstance(options, list):
            return []

        normalized: list[dict[str, Any]] = []
        for option in options:
            if isinstance(option, dict):
                option_item = copy.deepcopy(option)
                option_name = option_item.get('name', option_item.get('value', ''))
                option_item['name'] = str(option_name)
                option_item['label'] = self._normalize_i18n(option_item.get('label'), str(option_name))
                normalized.append(option_item)
            else:
                option_name = str(option)
                normalized.append({'name': option_name, 'label': self._normalize_i18n(None, option_name)})
        return normalized

    @staticmethod
    def _i18n_dict(en_value: str, zh_value: str) -> dict[str, str]:
        """Build the locale mapping emitted for every label/description."""
        return {
            'en_US': en_value,
            'en': en_value,
            'en-US': en_value,
            'zh_Hans': zh_value,
            'zh-Hans': zh_value,
            'zh-CN': zh_value,
        }

    def _normalize_i18n(self, value: Any, fallback: str) -> dict[str, str]:
        """Expand a label/description into a mapping covering all locale key spellings.

        A dict is read through its English and Chinese key variants (Chinese
        falls back to English, English to ``fallback``); a non-empty string is
        used for every locale; anything else yields ``fallback`` everywhere.
        """
        if isinstance(value, dict):
            en_value = value.get('en_US') or value.get('en-US') or value.get('en') or fallback
            zh_value = (
                value.get('zh_Hans') or value.get('zh-Hans') or value.get('zh-CN') or value.get('zh') or en_value
            )
            return self._i18n_dict(str(en_value), str(zh_value))
        if isinstance(value, str) and value:
            return self._i18n_dict(value, value)
        return self._i18n_dict(fallback, fallback)

    def _normalize_config_type(self, field_type: str) -> str:
        """Map metadata field type aliases to frontend types; other types pass through."""
        return self._CONFIG_TYPE_ALIASES.get(field_type, field_type)

    def _default_value_for_type(self, field_type: str) -> Any:
        """Return the default value for a frontend field type (a fresh object on every call)."""
        if field_type == 'boolean':
            return False
        if field_type in {'integer', 'float'}:
            return 0
        if field_type in {'array[string]', 'knowledge-base-multi-selector', 'tools-selector'}:
            return []
        if field_type == 'model-fallback-selector':
            return {'primary': '', 'fallbacks': []}
        if field_type == 'prompt-editor':
            return [{'role': 'system', 'content': ''}]
        return ''

    def _prettify_name(self, name: str) -> str:
        """Turn ``snake_case`` / ``kebab-case`` into space-separated capitalized words."""
        return ' '.join(part.capitalize() for part in str(name).replace('-', '_').split('_') if part)


# Convenience functions for module-level access


def register_node(node_type: str, node_class: type[WorkflowNode]):
    """Register a node type to the global registry."""
    NodeTypeRegistry.instance().register(node_type, node_class)


def get_node_class(node_type: str) -> Optional[type[WorkflowNode]]:
    """Get a node class from the global registry."""
    return NodeTypeRegistry.instance().get(node_type)


def list_node_types() -> list[dict[str, Any]]:
    """List all registered node types."""
    return NodeTypeRegistry.instance().list_all()