"""Workflow execution engine""" from __future__ import annotations import ast import asyncio import logging import operator import traceback import uuid from datetime import datetime from typing import Any, Optional, TYPE_CHECKING import sqlalchemy from .entities import ( WorkflowDefinition, NodeDefinition, EdgeDefinition, ExecutionContext, ExecutionStatus, NodeState, NodeStatus, ExecutionStep, ) from ..entity.persistence import workflow as persistence_workflow from .registry import NodeTypeRegistry if TYPE_CHECKING: from ..core import app logger = logging.getLogger(__name__) class ExecutionLog: """Execution log entry""" def __init__(self, level: str, message: str, node_id: Optional[str] = None, data: Optional[dict] = None): self.id = str(uuid.uuid4()) self.timestamp = datetime.now().isoformat() self.level = level self.message = message self.node_id = node_id self.data = data or {} def to_dict(self) -> dict: return { 'id': self.id, 'timestamp': self.timestamp, 'level': self.level, 'message': self.message, 'node_id': self.node_id, 'data': self.data, } class DebugExecutionState: """State for a debug execution""" def __init__(self, execution_id: str, breakpoints: list[str] = None): self.execution_id = execution_id self.status: str = 'running' self.is_paused: bool = False self.is_stopped: bool = False self.current_node_id: Optional[str] = None self.breakpoints: set[str] = set(breakpoints or []) self.logs: list[ExecutionLog] = [] self.pending_logs: list[ExecutionLog] = [] self._pause_event = asyncio.Event() self._pause_event.set() # Initially not paused self._stop_event = asyncio.Event() def add_log(self, level: str, message: str, node_id: str = None, data: dict = None): """Add a log entry""" log = ExecutionLog(level, message, node_id, data) self.logs.append(log) self.pending_logs.append(log) logger.log( getattr(logging, level.upper(), logging.INFO), f'[Workflow Debug] {message}', extra={'node_id': node_id, 'data': data}, ) def get_pending_logs(self) -> list[dict]: """Get and clear pending logs""" logs = [log.to_dict() for log in self.pending_logs] self.pending_logs = [] return logs def pause(self): """Pause execution""" self.is_paused = True self._pause_event.clear() self.add_log('info', 'Execution paused') def resume(self): """Resume execution""" self.is_paused = False self._pause_event.set() self.add_log('info', 'Execution resumed') def stop(self): """Stop execution""" self.is_stopped = True self.status = 'cancelled' self._stop_event.set() self._pause_event.set() # Release any pause self.add_log('info', 'Execution stopped') async def wait_if_paused(self): """Wait if execution is paused""" if self.is_paused: self.add_log('info', 'Waiting for resume...') await self._pause_event.wait() def check_breakpoint(self, node_id: str) -> bool: """Check if there's a breakpoint at the given node""" return node_id in self.breakpoints # ─── Safe expression evaluator (replaces eval()) ───────────────────── # Uses Python's ast module to whitelist only comparison / boolean / arithmetic # operations. No function calls, attribute access, or subscript injection. _SAFE_OPS = { ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul, ast.Div: operator.truediv, ast.FloorDiv: operator.floordiv, ast.Mod: operator.mod, ast.Pow: operator.pow, ast.USub: operator.neg, ast.UAdd: operator.pos, ast.Not: operator.not_, ast.Eq: operator.eq, ast.NotEq: operator.ne, ast.Lt: operator.lt, ast.LtE: operator.le, ast.Gt: operator.gt, ast.GtE: operator.ge, ast.Is: operator.is_, ast.IsNot: operator.is_not, ast.In: lambda a, b: a in b, ast.NotIn: lambda a, b: a not in b, } def _safe_eval(expr: str) -> Any: """Evaluate a simple expression safely via AST whitelist. Supports: literals, comparisons (==, !=, <, >, <=, >=, in, not in, is, is not), boolean logic (and, or, not), arithmetic (+, -, *, /, //, %, **), and string operations (contains via ``in``). Raises ``ValueError`` on any disallowed construct (function calls, attribute access, imports, etc.). """ tree = ast.parse(expr.strip(), mode='eval') return _eval_node(tree.body) def _eval_node(node: ast.AST) -> Any: # Literals: numbers, strings, True/False/None if isinstance(node, ast.Constant): return node.value # Unary operators: -x, +x, not x if isinstance(node, ast.UnaryOp): op_fn = _SAFE_OPS.get(type(node.op)) if op_fn is None: raise ValueError(f'Unsupported unary op: {type(node.op).__name__}') return op_fn(_eval_node(node.operand)) # Binary operators: x + y, x * y, etc. if isinstance(node, ast.BinOp): op_fn = _SAFE_OPS.get(type(node.op)) if op_fn is None: raise ValueError(f'Unsupported binary op: {type(node.op).__name__}') return op_fn(_eval_node(node.left), _eval_node(node.right)) # Comparisons: x == y, x > y, x in y, etc. (chained) if isinstance(node, ast.Compare): left = _eval_node(node.left) for op, comparator in zip(node.ops, node.comparators): op_fn = _SAFE_OPS.get(type(op)) if op_fn is None: raise ValueError(f'Unsupported comparison: {type(op).__name__}') right = _eval_node(comparator) if not op_fn(left, right): return False left = right return True # Boolean operators: x and y, x or y if isinstance(node, ast.BoolOp): if isinstance(node.op, ast.And): return all(_eval_node(v) for v in node.values) if isinstance(node.op, ast.Or): return any(_eval_node(v) for v in node.values) # Ternary: x if cond else y if isinstance(node, ast.IfExp): return _eval_node(node.body) if _eval_node(node.test) else _eval_node(node.orelse) # Tuples / Lists (used in "x in [1,2,3]") if isinstance(node, (ast.Tuple, ast.List)): return [_eval_node(e) for e in node.elts] # Name lookup – only allow None, True, False if isinstance(node, ast.Name): if node.id == 'None': return None if node.id == 'True': return True if node.id == 'False': return False raise ValueError(f'Unsupported variable reference: {node.id}') raise ValueError(f'Unsupported expression node: {type(node).__name__}') class WorkflowExecutor: """ Workflow execution engine. Handles the execution of workflow definitions with proper control flow. """ def __init__(self, ap: Optional['app.Application'] = None): self.ap = ap self.registry = NodeTypeRegistry.instance() self._edges: list[EdgeDefinition] = [] async def execute( self, workflow: WorkflowDefinition, context: ExecutionContext, start_node_id: Optional[str] = None ) -> ExecutionContext: """ Execute a workflow. Args: workflow: Workflow definition context: Execution context start_node_id: Optional starting node (for resumption) Returns: Updated execution context """ context.status = ExecutionStatus.RUNNING context.start_time = datetime.now() try: # Build execution graph node_map = {node.id: node for node in workflow.nodes} edge_map = self._build_edge_map(workflow.edges) self._edges = workflow.edges # Initialize node states for node in workflow.nodes: if node.id not in context.node_states: context.node_states[node.id] = NodeState(node_id=node.id) # Find start node(s) if start_node_id: start_nodes = [node_map[start_node_id]] else: start_nodes = self._find_start_nodes(workflow.nodes, workflow.edges) if not start_nodes: raise ValueError('No start nodes found in workflow') # Execute from start nodes for start_node in start_nodes: await self._execute_from_node( start_node, node_map, edge_map, context, workflow.settings.max_retries, path=set() ) # Check final status all_completed = all( state.status in (NodeStatus.COMPLETED, NodeStatus.SKIPPED) for state in context.node_states.values() ) if all_completed: context.status = ExecutionStatus.COMPLETED else: # Some nodes might still be waiting has_failed = any(state.status == NodeStatus.FAILED for state in context.node_states.values()) if has_failed: context.status = ExecutionStatus.FAILED except Exception as e: context.status = ExecutionStatus.FAILED context.error = str(e) logger.error( 'Workflow execution failed', exc_info=True, extra={ 'workflow_id': workflow.uuid, 'execution_id': context.execution_id, 'node_states': { node_id: { 'status': state.status.value if state.status else None, 'error': state.error, } for node_id, state in context.node_states.items() }, }, ) finally: context.end_time = datetime.now() return context async def _execute_from_node( self, node: NodeDefinition, node_map: dict[str, NodeDefinition], edge_map: dict[str, list[EdgeDefinition]], context: ExecutionContext, max_retries: int = 3, path: set[str] | None = None, ): """Execute workflow starting from a specific node""" # Initialize path set for cycle detection (path-based, not global visited) if path is None: path = set() # Check for circular dependency on the *current path* only # This correctly allows diamond shapes (A→B, A→C, B→D, C→D) if node.id in path: logger.warning(f'Circular dependency detected at node: {node.id}') context.node_states[node.id].status = NodeStatus.SKIPPED context.node_states[node.id].error = 'Circular dependency detected' context.node_states[node.id].end_time = datetime.now() await self._persist_node_execution(node, context.node_states[node.id], context) return # Add node to current path path.add(node.id) # Check if node should be skipped if await self._should_skip_node(node, context): existing_state = context.node_states[node.id] if existing_state.status == NodeStatus.SKIPPED: existing_state.end_time = existing_state.end_time or datetime.now() await self._persist_node_execution(node, existing_state, context) path.discard(node.id) return # Execute current node await self._execute_node(node, context, max_retries) # If node failed and we should stop on error, return if context.node_states[node.id].status == NodeStatus.FAILED: path.discard(node.id) return node_state = context.node_states[node.id] node_type_name = node.type.split('.')[-1] if '.' in node.type else node.type # ── Control flow integration ──────────────────────────────── # For loop / iterator nodes: run the LoopExecutor over # downstream body nodes for each item, then continue to the # "completed" output edge. if node_type_name in ('loop', 'iterator'): items = node_state.outputs.get('_items') or [] if not items: # iterator: items come from inputs items = node_state.inputs.get('items', node_state.inputs.get('array', [])) if not isinstance(items, list): items = [items] if items else [] max_iter = int(node.config.get('max_iterations', 100)) items = items[:max_iter] # Collect downstream "body" nodes (connected via edges) outgoing_edges = edge_map.get(node.id, []) body_nodes = [] for edge in outgoing_edges: target = node_map.get(edge.target_node) if target: body_nodes.append(target) if body_nodes and items: loop_exec = LoopExecutor(self) results = await loop_exec.execute_loop(items, body_nodes, context, max_iter) node_state.outputs['results'] = results node_state.outputs['completed'] = True else: node_state.outputs['results'] = [] node_state.outputs['completed'] = True path.discard(node.id) return # body nodes already executed by LoopExecutor # For parallel nodes: run downstream branches concurrently if node_type_name == 'parallel': outgoing_edges = edge_map.get(node.id, []) branch_nodes = [] for edge in outgoing_edges: target = node_map.get(edge.target_node) if target: branch_nodes.append([target]) if branch_nodes: par_exec = ParallelExecutor(self) results = await par_exec.execute_parallel(branch_nodes, context) node_state.outputs['results'] = results path.discard(node.id) return # branch nodes already executed by ParallelExecutor # ── Standard edge-based continuation ──────────────────────── # Get outgoing edges outgoing_edges = edge_map.get(node.id, []) # Execute next nodes based on edge conditions for edge in outgoing_edges: target_node = node_map.get(edge.target_node) if not target_node: continue # Check edge condition if edge.condition: condition_met = await self._evaluate_condition(edge.condition, context) if not condition_met: continue # Check if all inputs are ready if await self._inputs_ready(target_node, edge_map, context): await self._execute_from_node(target_node, node_map, edge_map, context, max_retries, path) # Remove node from path when backtracking (allows diamond revisit) path.discard(node.id) async def _execute_node(self, node: NodeDefinition, context: ExecutionContext, max_retries: int = 3): """Execute a single node with retry logic""" node_state = context.node_states[node.id] node_state.status = NodeStatus.RUNNING node_state.start_time = datetime.now() # Get node instance (pass ap for access to services) node_instance = self.registry.create_instance(node.type, node.id, node.config, ap=self.ap) if not node_instance: node_state.status = NodeStatus.FAILED node_state.error = f'Unknown node type: {node.type}' node_state.end_time = datetime.now() self._record_execution_step(node, node_state, context) await self._persist_node_execution(node, node_state, context) return # Resolve inputs inputs = await self._resolve_inputs(node, context) node_state.inputs = inputs # Validate inputs validation_errors = await node_instance.validate_inputs(inputs) if validation_errors: node_state.status = NodeStatus.FAILED node_state.error = '; '.join(validation_errors) node_state.end_time = datetime.now() self._record_execution_step(node, node_state, context) await self._persist_node_execution(node, node_state, context) return # Execute with retries for attempt in range(max_retries + 1): try: outputs = await node_instance.execute(inputs, context) node_state.outputs = outputs node_state.status = NodeStatus.COMPLETED node_state.end_time = datetime.now() break except Exception as e: node_state.retry_count = attempt + 1 logger.error( f'Node {node.id} ({node.type}) execution failed (attempt {attempt + 1}/{max_retries + 1}): {e}', exc_info=True, extra={ 'node_id': node.id, 'node_type': node.type, 'attempt': attempt + 1, 'max_retries': max_retries, 'execution_id': context.execution_id, }, ) if attempt < max_retries: await asyncio.sleep(1) # Brief delay before retry else: node_state.status = NodeStatus.FAILED node_state.error = str(e) node_state.end_time = datetime.now() logger.error( f'Node {node.id} ({node.type}) permanently failed after {max_retries + 1} attempts', extra={ 'node_id': node.id, 'node_type': node.type, 'error': str(e), 'execution_id': context.execution_id, }, ) self._record_execution_step(node, node_state, context) await self._persist_node_execution(node, node_state, context) async def _resolve_inputs(self, node: NodeDefinition, context: ExecutionContext) -> dict[str, Any]: """Resolve input values for a node from connected nodes and context""" inputs = {} # Get inputs from context variables if 'message' in context.variables: inputs['message'] = context.variables['message'] # Get inputs from message context if context.message_context: inputs['message_content'] = context.message_context.message_content inputs['sender_id'] = context.message_context.sender_id inputs['platform'] = context.message_context.platform # Get inputs from node config that reference other nodes for key, value in node.config.items(): if isinstance(value, str) and value.startswith('{{') and value.endswith('}}'): resolved = await self._resolve_expression(value[2:-2], context) inputs[key] = resolved else: inputs[key] = value # Get inputs from connected upstream nodes via edges # Build a reverse map: for each incoming edge to this node, find the # source node and the specific source/target port. for edge in self._edges: if edge.target_node != node.id: continue source_state = context.node_states.get(edge.source_node) if not source_state or source_state.status != NodeStatus.COMPLETED: continue target_port = edge.target_port or 'input' source_port = edge.source_port or 'output' # Map the source node's output port value to this node's input port if source_port in source_state.outputs: inputs[target_port] = source_state.outputs[source_port] elif 'output' in source_state.outputs: # Fallback: if exact port not found, try generic 'output' inputs[target_port] = source_state.outputs['output'] elif source_state.outputs: # Last resort: use the first available output inputs[target_port] = next(iter(source_state.outputs.values())) return inputs async def _resolve_expression(self, expression: str, context: ExecutionContext) -> Any: """Resolve a variable expression like 'nodes.node1.outputs.text'""" parts = expression.strip().split('.') if not parts: return None if parts[0] == 'nodes' and len(parts) >= 4: # nodes.node_id.outputs.output_name node_id = parts[1] if parts[2] == 'outputs' and node_id in context.node_states: output_name = '.'.join(parts[3:]) return context.node_states[node_id].outputs.get(output_name) elif parts[0] == 'variables': # variables.var_name var_name = '.'.join(parts[1:]) return context.variables.get(var_name) elif parts[0] == 'conversation_variables': # conversation_variables.var_name var_name = '.'.join(parts[1:]) return context.conversation_variables.get(var_name) elif parts[0] == 'message': # message.content, message.sender_id, etc. if context.message_context: attr = parts[1] if len(parts) > 1 else None if attr == 'content': return context.message_context.message_content elif attr == 'sender_id': return context.message_context.sender_id elif attr == 'platform': return context.message_context.platform elif attr == 'conversation_id': return context.message_context.conversation_id return None async def _evaluate_condition(self, condition: str, context: ExecutionContext) -> bool: """Evaluate a condition expression safely using AST whitelist""" try: # Resolve variable references in condition if '{{' in condition: import re pattern = r'\{\{([^}]+)\}\}' # First pass: replace all variable references with placeholders placeholders = {} placeholder_idx = 0 def replace_with_placeholder(match): nonlocal placeholder_idx var_expr = match.group(1) placeholder = f'__PH{placeholder_idx}__' placeholders[placeholder] = var_expr placeholder_idx += 1 return placeholder condition_with_placeholders = re.sub(pattern, replace_with_placeholder, condition) # Second pass: resolve each placeholder asynchronously for placeholder, var_expr in placeholders.items(): value = await self._resolve_expression(var_expr, context) if isinstance(value, str): condition_with_placeholders = condition_with_placeholders.replace(placeholder, f'"{value}"') elif value is None: condition_with_placeholders = condition_with_placeholders.replace(placeholder, 'None') else: condition_with_placeholders = condition_with_placeholders.replace(placeholder, str(value)) condition = condition_with_placeholders # Safe expression evaluation using AST whitelist result = _safe_eval(condition) return bool(result) except Exception as e: logger.warning(f'Condition evaluation failed: {condition} - {e}') return False async def _should_skip_node(self, node: NodeDefinition, context: ExecutionContext) -> bool: """Check if a node should be skipped""" state = context.node_states.get(node.id) if state and state.status in (NodeStatus.COMPLETED, NodeStatus.RUNNING, NodeStatus.SKIPPED): return True return False async def _inputs_ready( self, node: NodeDefinition, edge_map: dict[str, list[EdgeDefinition]], context: ExecutionContext ) -> bool: """Check if all inputs for a node are ready""" # Find all edges that connect to this node incoming_nodes = set() for source_id, edges in edge_map.items(): for edge in edges: if edge.target_node == node.id: incoming_nodes.add(source_id) # Check if all incoming nodes have completed for source_id in incoming_nodes: state = context.node_states.get(source_id) if not state or state.status not in (NodeStatus.COMPLETED, NodeStatus.SKIPPED): return False return True def _find_start_nodes(self, nodes: list[NodeDefinition], edges: list[EdgeDefinition]) -> list[NodeDefinition]: """Find nodes that have no incoming edges (start nodes)""" target_nodes = {edge.target_node for edge in edges} start_nodes = [node for node in nodes if node.id not in target_nodes] # Also check for trigger nodes trigger_types = {'message_trigger', 'cron_trigger', 'webhook_trigger', 'event_trigger'} for node in nodes: if node.type in trigger_types and node not in start_nodes: start_nodes.insert(0, node) return start_nodes def _build_edge_map(self, edges: list[EdgeDefinition]) -> dict[str, list[EdgeDefinition]]: """Build a map of source node ID to outgoing edges""" edge_map: dict[str, list[EdgeDefinition]] = {} for edge in edges: if edge.source_node not in edge_map: edge_map[edge.source_node] = [] edge_map[edge.source_node].append(edge) return edge_map def _record_execution_step(self, node: NodeDefinition, node_state: NodeState, context: ExecutionContext): """Record an execution step in the history""" duration_ms = 0 if node_state.start_time and node_state.end_time: duration_ms = int((node_state.end_time - node_state.start_time).total_seconds() * 1000) step = ExecutionStep( timestamp=datetime.now(), node_id=node.id, node_type=node.type, status=node_state.status.value, inputs=node_state.inputs, outputs=node_state.outputs, duration_ms=duration_ms, error=node_state.error, ) context.history.append(step) async def _persist_node_execution( self, node: NodeDefinition, node_state: NodeState, context: ExecutionContext, ): """Persist node execution state for execution detail and logs.""" if not self.ap: return values = { 'execution_uuid': context.execution_id, 'node_id': node.id, 'node_type': node.type, 'status': node_state.status.value, 'inputs': node_state.inputs, 'outputs': node_state.outputs, 'start_time': node_state.start_time, 'end_time': node_state.end_time, 'error': node_state.error, 'retry_count': node_state.retry_count, } existing_query = sqlalchemy.select(persistence_workflow.WorkflowNodeExecution).where( persistence_workflow.WorkflowNodeExecution.execution_uuid == context.execution_id, persistence_workflow.WorkflowNodeExecution.node_id == node.id, ) existing_result = await self.ap.persistence_mgr.execute_async(existing_query) existing = existing_result.first() if existing is None: await self.ap.persistence_mgr.execute_async( sqlalchemy.insert(persistence_workflow.WorkflowNodeExecution).values(**values) ) else: await self.ap.persistence_mgr.execute_async( sqlalchemy.update(persistence_workflow.WorkflowNodeExecution) .where(persistence_workflow.WorkflowNodeExecution.id == existing.id) .values(**values) ) class ParallelExecutor: """Execute multiple branches in parallel""" def __init__(self, executor: WorkflowExecutor): self.executor = executor async def execute_parallel( self, branches: list[list[NodeDefinition]], context: ExecutionContext ) -> list[dict[str, Any]]: """ Execute multiple branches in parallel. Args: branches: List of node sequences to execute in parallel context: Execution context Returns: List of results from each branch """ tasks = [] for branch in branches: task = self._execute_branch(branch, context) tasks.append(task) results = await asyncio.gather(*tasks, return_exceptions=True) processed_results = [] for result in results: if isinstance(result, Exception): processed_results.append({'error': str(result)}) else: processed_results.append(result) return processed_results async def _execute_branch(self, nodes: list[NodeDefinition], context: ExecutionContext) -> dict[str, Any]: """Execute a single branch""" # Create a copy of context for this branch branch_outputs = {} for node in nodes: await self.executor._execute_node(node, context, max_retries=3) state = context.node_states.get(node.id) if state and state.status == NodeStatus.COMPLETED: branch_outputs[node.id] = state.outputs elif state and state.status == NodeStatus.FAILED: branch_outputs['error'] = state.error break return branch_outputs class LoopExecutor: """Execute loop iterations""" def __init__(self, executor: WorkflowExecutor): self.executor = executor async def execute_loop( self, items: list[Any], loop_body: list[NodeDefinition], context: ExecutionContext, max_iterations: int = 100 ) -> list[dict[str, Any]]: """ Execute a loop over items. Args: items: Items to iterate over loop_body: Nodes to execute for each item context: Execution context max_iterations: Maximum number of iterations Returns: List of results from each iteration """ results = [] for i, item in enumerate(items[:max_iterations]): # Set loop variables context.variables['loop_item'] = item context.variables['loop_index'] = i context.variables['loop_is_first'] = i == 0 context.variables['loop_is_last'] = i == len(items) - 1 iteration_result = {} for node in loop_body: # Reset node state for this iteration context.node_states[node.id] = NodeState(node_id=node.id) await self.executor._execute_node(node, context, max_retries=3) state = context.node_states.get(node.id) if state: iteration_result[node.id] = state.outputs # Check for break condition if state.outputs.get('break', False): results.append(iteration_result) return results results.append(iteration_result) # Clean up loop variables context.variables.pop('loop_item', None) context.variables.pop('loop_index', None) context.variables.pop('loop_is_first', None) context.variables.pop('loop_is_last', None) return results class DebugWorkflowExecutor(WorkflowExecutor): """ Debug-enabled workflow executor with step-by-step execution support. Extends WorkflowExecutor with debugging capabilities. """ # Class-level storage for active debug sessions _debug_states: dict[str, DebugExecutionState] = {} def __init__(self, ap: Optional['app.Application'] = None): super().__init__(ap) @classmethod def get_debug_state(cls, execution_id: str) -> Optional[DebugExecutionState]: """Get debug state for an execution""" return cls._debug_states.get(execution_id) @classmethod def create_debug_state(cls, execution_id: str, breakpoints: list[str] = None) -> DebugExecutionState: """Create a new debug state""" state = DebugExecutionState(execution_id, breakpoints) cls._debug_states[execution_id] = state return state @classmethod def remove_debug_state(cls, execution_id: str): """Remove debug state for an execution""" cls._debug_states.pop(execution_id, None) async def execute_debug( self, workflow: WorkflowDefinition, context: ExecutionContext, debug_state: DebugExecutionState, ) -> ExecutionContext: """ Execute a workflow in debug mode. Args: workflow: Workflow definition context: Execution context debug_state: Debug execution state Returns: Updated execution context """ context.status = ExecutionStatus.RUNNING context.start_time = datetime.now() debug_state.add_log('info', f'Starting debug execution for workflow: {workflow.name}') try: # Build execution graph node_map = {node.id: node for node in workflow.nodes} edge_map = self._build_edge_map(workflow.edges) self._edges = workflow.edges # Initialize node states for node in workflow.nodes: if node.id not in context.node_states: context.node_states[node.id] = NodeState(node_id=node.id) # Find start node(s) start_nodes = self._find_start_nodes(workflow.nodes, workflow.edges) if not start_nodes: raise ValueError('No start nodes found in workflow') debug_state.add_log('info', f'Found {len(start_nodes)} start node(s)') # Execute from start nodes for start_node in start_nodes: if debug_state.is_stopped: break await self._execute_debug_from_node( start_node, node_map, edge_map, context, debug_state, workflow.settings.max_retries ) # Set final status if debug_state.is_stopped: context.status = ExecutionStatus.CANCELLED debug_state.status = 'cancelled' else: all_completed = all( state.status in (NodeStatus.COMPLETED, NodeStatus.SKIPPED) for state in context.node_states.values() ) if all_completed: context.status = ExecutionStatus.COMPLETED debug_state.status = 'completed' debug_state.add_log('info', 'Workflow execution completed successfully') else: has_failed = any(state.status == NodeStatus.FAILED for state in context.node_states.values()) if has_failed: context.status = ExecutionStatus.FAILED debug_state.status = 'error' except Exception as e: context.status = ExecutionStatus.FAILED context.error = str(e) debug_state.status = 'error' debug_state.add_log('error', f'Workflow execution failed: {e}', data={'traceback': traceback.format_exc()}) logger.error(f'Debug workflow execution failed: {e}\n{traceback.format_exc()}') finally: context.end_time = datetime.now() return context async def _execute_debug_from_node( self, node: NodeDefinition, node_map: dict[str, NodeDefinition], edge_map: dict[str, list[EdgeDefinition]], context: ExecutionContext, debug_state: DebugExecutionState, max_retries: int = 3, ): """Execute workflow from a node with debug support""" # Check if stopped if debug_state.is_stopped: return # Wait if paused await debug_state.wait_if_paused() # Check if should skip if await self._should_skip_node(node, context): if context.node_states[node.id].status == NodeStatus.SKIPPED: debug_state.add_log('info', f'Skipping node: {node.id}', node_id=node.id) return # Check breakpoint if debug_state.check_breakpoint(node.id): debug_state.add_log('info', f'Hit breakpoint at node: {node.id}', node_id=node.id) debug_state.pause() await debug_state.wait_if_paused() # Update current node debug_state.current_node_id = node.id debug_state.add_log('info', f'Executing node: {node.id} ({node.type})', node_id=node.id) # Execute node await self._execute_debug_node(node, context, debug_state, max_retries) # Check if stopped or failed if debug_state.is_stopped: return if context.node_states[node.id].status == NodeStatus.FAILED: return # Get outgoing edges outgoing_edges = edge_map.get(node.id, []) # Execute next nodes for edge in outgoing_edges: if debug_state.is_stopped: break target_node = node_map.get(edge.target_node) if not target_node: continue # Check edge condition if edge.condition: condition_met = await self._evaluate_condition(edge.condition, context) if not condition_met: debug_state.add_log('debug', f'Edge condition not met: {edge.condition}', node_id=node.id) continue # Check if all inputs are ready if await self._inputs_ready(target_node, edge_map, context): await self._execute_debug_from_node(target_node, node_map, edge_map, context, debug_state, max_retries) async def _execute_debug_node( self, node: NodeDefinition, context: ExecutionContext, debug_state: DebugExecutionState, max_retries: int = 3 ): """Execute a single node with debug logging""" node_state = context.node_states[node.id] node_state.status = NodeStatus.RUNNING node_state.start_time = datetime.now() # Get node instance (pass ap for access to services) node_instance = self.registry.create_instance(node.type, node.id, node.config, ap=self.ap) if not node_instance: node_state.status = NodeStatus.FAILED node_state.error = f'Unknown node type: {node.type}' node_state.end_time = datetime.now() debug_state.add_log('error', f'Unknown node type: {node.type}', node_id=node.id) self._record_execution_step(node, node_state, context) await self._persist_node_execution(node, node_state, context) return # Resolve inputs inputs = await self._resolve_inputs(node, context) node_state.inputs = inputs debug_state.add_log( 'debug', 'Node inputs resolved', node_id=node.id, data={'inputs': self._safe_serialize(inputs)} ) # Validate inputs validation_errors = await node_instance.validate_inputs(inputs) if validation_errors: node_state.status = NodeStatus.FAILED node_state.error = '; '.join(validation_errors) node_state.end_time = datetime.now() debug_state.add_log('error', f'Input validation failed: {node_state.error}', node_id=node.id) self._record_execution_step(node, node_state, context) await self._persist_node_execution(node, node_state, context) return # Execute with retries for attempt in range(max_retries + 1): if debug_state.is_stopped: node_state.status = NodeStatus.FAILED node_state.error = 'Execution stopped' node_state.end_time = datetime.now() break try: outputs = await node_instance.execute(inputs, context) node_state.outputs = outputs node_state.status = NodeStatus.COMPLETED node_state.end_time = datetime.now() duration_ms = int((node_state.end_time - node_state.start_time).total_seconds() * 1000) debug_state.add_log( 'info', f'Node completed in {duration_ms}ms', node_id=node.id, data={'outputs': self._safe_serialize(outputs), 'duration_ms': duration_ms}, ) break except Exception as e: node_state.retry_count = attempt + 1 debug_state.add_log( 'warning', f'Node execution failed (attempt {attempt + 1}/{max_retries + 1}): {e}', node_id=node.id ) if attempt < max_retries: await asyncio.sleep(1) else: node_state.status = NodeStatus.FAILED node_state.error = str(e) node_state.end_time = datetime.now() debug_state.add_log( 'error', f'Node failed after {max_retries + 1} attempts: {e}', node_id=node.id, data={'error': str(e), 'traceback': traceback.format_exc()}, ) self._record_execution_step(node, node_state, context) await self._persist_node_execution(node, node_state, context) async def step_execute( self, workflow: WorkflowDefinition, context: ExecutionContext, debug_state: DebugExecutionState, ) -> dict: """ Execute one step (one node) in debug mode. Returns: Dict with node_id, node_state, and completed status """ # Find next node to execute next_node = self._find_next_executable_node(workflow, context) if not next_node: debug_state.status = 'completed' return {'completed': True} # Execute single node debug_state.current_node_id = next_node.id await self._execute_debug_node(next_node, context, debug_state, workflow.settings.max_retries) node_state = context.node_states.get(next_node.id) # Check if workflow is complete all_done = all( state.status in (NodeStatus.COMPLETED, NodeStatus.SKIPPED, NodeStatus.FAILED) for state in context.node_states.values() ) if all_done: debug_state.status = 'completed' context.status = ExecutionStatus.COMPLETED return { 'node_id': next_node.id, 'node_state': { 'status': node_state.status.value if node_state else 'unknown', 'inputs': self._safe_serialize(node_state.inputs) if node_state else {}, 'outputs': self._safe_serialize(node_state.outputs) if node_state else {}, 'error': node_state.error if node_state else None, }, 'completed': all_done, } def _find_next_executable_node( self, workflow: WorkflowDefinition, context: ExecutionContext ) -> Optional[NodeDefinition]: """Find the next node that can be executed""" edge_map = self._build_edge_map(workflow.edges) for node in workflow.nodes: state = context.node_states.get(node.id) # Skip completed, running, or failed nodes if state and state.status in ( NodeStatus.COMPLETED, NodeStatus.RUNNING, NodeStatus.FAILED, NodeStatus.SKIPPED, ): continue # Check if this node's inputs are ready incoming_nodes = set() for source_id, edges in edge_map.items(): for edge in edges: if edge.target_node == node.id: incoming_nodes.add(source_id) # If no incoming nodes, it's a start node if not incoming_nodes: return node # Check if all incoming nodes are done all_incoming_done = True for source_id in incoming_nodes: source_state = context.node_states.get(source_id) if not source_state or source_state.status not in (NodeStatus.COMPLETED, NodeStatus.SKIPPED): all_incoming_done = False break if all_incoming_done: return node return None def _safe_serialize(self, data: Any) -> Any: """Safely serialize data for logging""" if data is None: return None if isinstance(data, (str, int, float, bool)): return data if isinstance(data, (list, tuple)): return [self._safe_serialize(item) for item in data[:100]] # Limit list size if isinstance(data, dict): result = {} for key, value in list(data.items())[:50]: # Limit dict size result[str(key)] = self._safe_serialize(value) return result # For complex objects, try to convert to string try: return str(data)[:1000] # Limit string length except Exception: return '' def get_execution_state(self, context: ExecutionContext, debug_state: DebugExecutionState) -> dict: """Get current execution state for API response""" node_states = {} for node_id, state in context.node_states.items(): node_states[node_id] = { 'status': state.status.value, 'inputs': self._safe_serialize(state.inputs), 'outputs': self._safe_serialize(state.outputs), 'error': state.error, 'startTime': state.start_time.isoformat() if state.start_time else None, 'endTime': state.end_time.isoformat() if state.end_time else None, 'duration': int((state.end_time - state.start_time).total_seconds() * 1000) if state.start_time and state.end_time else None, } return { 'status': debug_state.status, 'current_node_id': debug_state.current_node_id, 'node_states': node_states, 'new_logs': debug_state.get_pending_logs(), 'error': context.error, }