mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
change
This commit is contained in:
@@ -342,6 +342,23 @@ class WorkflowService:
|
||||
|
||||
if trigger_type == 'message':
|
||||
message_context_data = raw_trigger_data.get('message_context') or {}
|
||||
# Fallback: if message_context is missing but trigger_data has 'message',
|
||||
# construct a minimal message_context so rerun and downstream nodes work.
|
||||
if not message_context_data and raw_trigger_data.get('message'):
|
||||
raw_msg = raw_trigger_data['message']
|
||||
message_context_data = {
|
||||
'message_id': str(raw_trigger_data.get('message_id', execution_uuid)),
|
||||
'message_content': raw_msg if isinstance(raw_msg, str) else str(raw_msg),
|
||||
'sender_id': str(raw_trigger_data.get('sender_id', '')),
|
||||
'sender_name': str(raw_trigger_data.get('sender_name', 'User')),
|
||||
'platform': str(raw_trigger_data.get('platform', '')),
|
||||
'conversation_id': str(raw_trigger_data.get('connection_id', '')),
|
||||
'is_group': bool(raw_trigger_data.get('is_group', False)),
|
||||
'group_id': raw_trigger_data.get('group_id'),
|
||||
'mentions': raw_trigger_data.get('mentions', []),
|
||||
'reply_to': raw_trigger_data.get('reply_to'),
|
||||
'raw_message': raw_trigger_data.get('raw_message', {}),
|
||||
}
|
||||
if message_context_data:
|
||||
context.message_context = MessageContext(
|
||||
message_id=str(message_context_data.get('message_id', execution_uuid)),
|
||||
|
||||
@@ -44,6 +44,11 @@ def __getattr__(name: str) -> Any:
|
||||
|
||||
return WorkflowExecutor
|
||||
|
||||
if name in ('DebugWorkflowExecutor', 'DebugExecutionState', 'ExecutionLog'):
|
||||
from . import debug
|
||||
|
||||
return getattr(debug, name)
|
||||
|
||||
if name == 'nodes':
|
||||
return import_module('.nodes', __name__)
|
||||
|
||||
@@ -69,4 +74,8 @@ __all__ = [
|
||||
'NodeTypeRegistry',
|
||||
# Executor
|
||||
'WorkflowExecutor',
|
||||
# Debug
|
||||
'DebugWorkflowExecutor',
|
||||
'DebugExecutionState',
|
||||
'ExecutionLog',
|
||||
]
|
||||
|
||||
509
src/langbot/pkg/workflow/debug.py
Normal file
509
src/langbot/pkg/workflow/debug.py
Normal file
@@ -0,0 +1,509 @@
|
||||
"""Workflow debug execution support.
|
||||
|
||||
This module provides debugging capabilities for workflow execution, including:
|
||||
- ExecutionLog: Structured log entries for execution tracking
|
||||
- DebugExecutionState: State management for debug sessions (pause, resume, breakpoints)
|
||||
- DebugWorkflowExecutor: Extended executor with step-by-step debugging support
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, TYPE_CHECKING
|
||||
|
||||
from .entities import (
|
||||
WorkflowDefinition,
|
||||
NodeDefinition,
|
||||
EdgeDefinition,
|
||||
ExecutionContext,
|
||||
ExecutionStatus,
|
||||
NodeState,
|
||||
NodeStatus,
|
||||
)
|
||||
from .executor import WorkflowExecutor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..core import app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExecutionLog:
|
||||
"""Execution log entry"""
|
||||
|
||||
def __init__(self, level: str, message: str, node_id: Optional[str] = None, data: Optional[dict] = None):
|
||||
self.id = str(uuid.uuid4())
|
||||
self.timestamp = datetime.now().isoformat()
|
||||
self.level = level
|
||||
self.message = message
|
||||
self.node_id = node_id
|
||||
self.data = data or {}
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
'id': self.id,
|
||||
'timestamp': self.timestamp,
|
||||
'level': self.level,
|
||||
'message': self.message,
|
||||
'node_id': self.node_id,
|
||||
'data': self.data,
|
||||
}
|
||||
|
||||
|
||||
class DebugExecutionState:
|
||||
"""State for a debug execution"""
|
||||
|
||||
def __init__(self, execution_id: str, breakpoints: list[str] = None):
|
||||
self.execution_id = execution_id
|
||||
self.status: str = 'running'
|
||||
self.is_paused: bool = False
|
||||
self.is_stopped: bool = False
|
||||
self.current_node_id: Optional[str] = None
|
||||
self.breakpoints: set[str] = set(breakpoints or [])
|
||||
self.logs: list[ExecutionLog] = []
|
||||
self.pending_logs: list[ExecutionLog] = []
|
||||
self._pause_event = asyncio.Event()
|
||||
self._pause_event.set() # Initially not paused
|
||||
self._stop_event = asyncio.Event()
|
||||
|
||||
def add_log(self, level: str, message: str, node_id: str = None, data: dict = None):
|
||||
"""Add a log entry"""
|
||||
log = ExecutionLog(level, message, node_id, data)
|
||||
self.logs.append(log)
|
||||
self.pending_logs.append(log)
|
||||
logger.log(
|
||||
getattr(logging, level.upper(), logging.INFO),
|
||||
f'[Workflow Debug] {message}',
|
||||
extra={'node_id': node_id, 'data': data},
|
||||
)
|
||||
|
||||
def get_pending_logs(self) -> list[dict]:
|
||||
"""Get and clear pending logs"""
|
||||
logs = [log.to_dict() for log in self.pending_logs]
|
||||
self.pending_logs = []
|
||||
return logs
|
||||
|
||||
def pause(self):
|
||||
"""Pause execution"""
|
||||
self.is_paused = True
|
||||
self._pause_event.clear()
|
||||
self.add_log('info', 'Execution paused')
|
||||
|
||||
def resume(self):
|
||||
"""Resume execution"""
|
||||
self.is_paused = False
|
||||
self._pause_event.set()
|
||||
self.add_log('info', 'Execution resumed')
|
||||
|
||||
def stop(self):
|
||||
"""Stop execution"""
|
||||
self.is_stopped = True
|
||||
self.status = 'cancelled'
|
||||
self._stop_event.set()
|
||||
self._pause_event.set() # Release any pause
|
||||
self.add_log('info', 'Execution stopped')
|
||||
|
||||
async def wait_if_paused(self):
|
||||
"""Wait if execution is paused"""
|
||||
if self.is_paused:
|
||||
self.add_log('info', 'Waiting for resume...')
|
||||
await self._pause_event.wait()
|
||||
|
||||
def check_breakpoint(self, node_id: str) -> bool:
|
||||
"""Check if there's a breakpoint at the given node"""
|
||||
return node_id in self.breakpoints
|
||||
|
||||
|
||||
class DebugWorkflowExecutor(WorkflowExecutor):
|
||||
"""
|
||||
Debug-enabled workflow executor with step-by-step execution support.
|
||||
Extends WorkflowExecutor with debugging capabilities.
|
||||
"""
|
||||
|
||||
# Class-level storage for active debug sessions
|
||||
_debug_states: dict[str, DebugExecutionState] = {}
|
||||
|
||||
def __init__(self, ap: Optional['app.Application'] = None):
|
||||
super().__init__(ap)
|
||||
|
||||
@classmethod
|
||||
def get_debug_state(cls, execution_id: str) -> Optional[DebugExecutionState]:
|
||||
"""Get debug state for an execution"""
|
||||
return cls._debug_states.get(execution_id)
|
||||
|
||||
@classmethod
|
||||
def create_debug_state(cls, execution_id: str, breakpoints: list[str] = None) -> DebugExecutionState:
|
||||
"""Create a new debug state"""
|
||||
state = DebugExecutionState(execution_id, breakpoints)
|
||||
cls._debug_states[execution_id] = state
|
||||
return state
|
||||
|
||||
@classmethod
|
||||
def remove_debug_state(cls, execution_id: str):
|
||||
"""Remove debug state for an execution"""
|
||||
cls._debug_states.pop(execution_id, None)
|
||||
|
||||
async def execute_debug(
|
||||
self,
|
||||
workflow: WorkflowDefinition,
|
||||
context: ExecutionContext,
|
||||
debug_state: DebugExecutionState,
|
||||
) -> ExecutionContext:
|
||||
"""
|
||||
Execute a workflow in debug mode.
|
||||
|
||||
Args:
|
||||
workflow: Workflow definition
|
||||
context: Execution context
|
||||
debug_state: Debug execution state
|
||||
|
||||
Returns:
|
||||
Updated execution context
|
||||
"""
|
||||
context.status = ExecutionStatus.RUNNING
|
||||
context.start_time = datetime.now()
|
||||
debug_state.add_log('info', f'Starting debug execution for workflow: {workflow.name}')
|
||||
|
||||
try:
|
||||
# Build execution graph
|
||||
node_map = {node.id: node for node in workflow.nodes}
|
||||
edge_map = self._build_edge_map(workflow.edges)
|
||||
self._edges = workflow.edges
|
||||
|
||||
# Initialize node states
|
||||
for node in workflow.nodes:
|
||||
if node.id not in context.node_states:
|
||||
context.node_states[node.id] = NodeState(node_id=node.id)
|
||||
|
||||
# Find start node(s)
|
||||
start_nodes = self._find_start_nodes(workflow.nodes, workflow.edges)
|
||||
|
||||
if not start_nodes:
|
||||
raise ValueError('No start nodes found in workflow')
|
||||
|
||||
debug_state.add_log('info', f'Found {len(start_nodes)} start node(s)')
|
||||
|
||||
# Execute from start nodes
|
||||
for start_node in start_nodes:
|
||||
if debug_state.is_stopped:
|
||||
break
|
||||
|
||||
await self._execute_debug_from_node(
|
||||
start_node, node_map, edge_map, context, debug_state, workflow.settings.max_retries
|
||||
)
|
||||
|
||||
# Set final status
|
||||
if debug_state.is_stopped:
|
||||
context.status = ExecutionStatus.CANCELLED
|
||||
debug_state.status = 'cancelled'
|
||||
else:
|
||||
all_completed = all(
|
||||
state.status in (NodeStatus.COMPLETED, NodeStatus.SKIPPED) for state in context.node_states.values()
|
||||
)
|
||||
|
||||
if all_completed:
|
||||
context.status = ExecutionStatus.COMPLETED
|
||||
debug_state.status = 'completed'
|
||||
debug_state.add_log('info', 'Workflow execution completed successfully')
|
||||
else:
|
||||
has_failed = any(state.status == NodeStatus.FAILED for state in context.node_states.values())
|
||||
if has_failed:
|
||||
context.status = ExecutionStatus.FAILED
|
||||
debug_state.status = 'error'
|
||||
|
||||
except Exception as e:
|
||||
context.status = ExecutionStatus.FAILED
|
||||
context.error = str(e)
|
||||
debug_state.status = 'error'
|
||||
debug_state.add_log('error', f'Workflow execution failed: {e}', data={'traceback': traceback.format_exc()})
|
||||
logger.error(f'Debug workflow execution failed: {e}\n{traceback.format_exc()}')
|
||||
|
||||
finally:
|
||||
context.end_time = datetime.now()
|
||||
|
||||
return context
|
||||
|
||||
async def _execute_debug_from_node(
|
||||
self,
|
||||
node: NodeDefinition,
|
||||
node_map: dict[str, NodeDefinition],
|
||||
edge_map: dict[str, list[EdgeDefinition]],
|
||||
context: ExecutionContext,
|
||||
debug_state: DebugExecutionState,
|
||||
max_retries: int = 3,
|
||||
):
|
||||
"""Execute workflow from a node with debug support"""
|
||||
|
||||
# Check if stopped
|
||||
if debug_state.is_stopped:
|
||||
return
|
||||
|
||||
# Wait if paused
|
||||
await debug_state.wait_if_paused()
|
||||
|
||||
# Check if should skip
|
||||
if await self._should_skip_node(node, context):
|
||||
if context.node_states[node.id].status == NodeStatus.SKIPPED:
|
||||
debug_state.add_log('info', f'Skipping node: {node.id}', node_id=node.id)
|
||||
return
|
||||
|
||||
# Check breakpoint
|
||||
if debug_state.check_breakpoint(node.id):
|
||||
debug_state.add_log('info', f'Hit breakpoint at node: {node.id}', node_id=node.id)
|
||||
debug_state.pause()
|
||||
await debug_state.wait_if_paused()
|
||||
|
||||
# Update current node
|
||||
debug_state.current_node_id = node.id
|
||||
debug_state.add_log('info', f'Executing node: {node.id} ({node.type})', node_id=node.id)
|
||||
|
||||
# Execute node
|
||||
await self._execute_debug_node(node, context, debug_state, max_retries)
|
||||
|
||||
# Check if stopped or failed
|
||||
if debug_state.is_stopped:
|
||||
return
|
||||
if context.node_states[node.id].status == NodeStatus.FAILED:
|
||||
return
|
||||
|
||||
# Get outgoing edges
|
||||
outgoing_edges = edge_map.get(node.id, [])
|
||||
|
||||
# Execute next nodes
|
||||
for edge in outgoing_edges:
|
||||
if debug_state.is_stopped:
|
||||
break
|
||||
|
||||
target_node = node_map.get(edge.target_node)
|
||||
if not target_node:
|
||||
continue
|
||||
|
||||
# Check edge condition
|
||||
if edge.condition:
|
||||
condition_met = await self._evaluate_condition(edge.condition, context)
|
||||
if not condition_met:
|
||||
debug_state.add_log('debug', f'Edge condition not met: {edge.condition}', node_id=node.id)
|
||||
continue
|
||||
|
||||
# Check if all inputs are ready
|
||||
if await self._inputs_ready(target_node, edge_map, context):
|
||||
await self._execute_debug_from_node(target_node, node_map, edge_map, context, debug_state, max_retries)
|
||||
|
||||
async def _execute_debug_node(
|
||||
self, node: NodeDefinition, context: ExecutionContext, debug_state: DebugExecutionState, max_retries: int = 3
|
||||
):
|
||||
"""Execute a single node with debug logging"""
|
||||
|
||||
node_state = context.node_states[node.id]
|
||||
node_state.status = NodeStatus.RUNNING
|
||||
node_state.start_time = datetime.now()
|
||||
|
||||
# Get node instance (pass ap for access to services)
|
||||
node_instance = self.registry.create_instance(node.type, node.id, node.config, ap=self.ap)
|
||||
|
||||
if not node_instance:
|
||||
node_state.status = NodeStatus.FAILED
|
||||
node_state.error = f'Unknown node type: {node.type}'
|
||||
node_state.end_time = datetime.now()
|
||||
debug_state.add_log('error', f'Unknown node type: {node.type}', node_id=node.id)
|
||||
self._record_execution_step(node, node_state, context)
|
||||
await self._persist_node_execution(node, node_state, context)
|
||||
return
|
||||
|
||||
# Resolve inputs
|
||||
inputs = await self._resolve_inputs(node, context)
|
||||
node_state.inputs = inputs
|
||||
debug_state.add_log(
|
||||
'debug', 'Node inputs resolved', node_id=node.id, data={'inputs': self._safe_serialize(inputs)}
|
||||
)
|
||||
|
||||
# Validate inputs
|
||||
validation_errors = await node_instance.validate_inputs(inputs)
|
||||
if validation_errors:
|
||||
node_state.status = NodeStatus.FAILED
|
||||
node_state.error = '; '.join(validation_errors)
|
||||
node_state.end_time = datetime.now()
|
||||
debug_state.add_log('error', f'Input validation failed: {node_state.error}', node_id=node.id)
|
||||
self._record_execution_step(node, node_state, context)
|
||||
await self._persist_node_execution(node, node_state, context)
|
||||
return
|
||||
|
||||
# Execute with retries
|
||||
for attempt in range(max_retries + 1):
|
||||
if debug_state.is_stopped:
|
||||
node_state.status = NodeStatus.FAILED
|
||||
node_state.error = 'Execution stopped'
|
||||
node_state.end_time = datetime.now()
|
||||
break
|
||||
|
||||
try:
|
||||
outputs = await node_instance.execute(inputs, context)
|
||||
node_state.outputs = outputs
|
||||
node_state.status = NodeStatus.COMPLETED
|
||||
node_state.end_time = datetime.now()
|
||||
|
||||
duration_ms = int((node_state.end_time - node_state.start_time).total_seconds() * 1000)
|
||||
debug_state.add_log(
|
||||
'info',
|
||||
f'Node completed in {duration_ms}ms',
|
||||
node_id=node.id,
|
||||
data={'outputs': self._safe_serialize(outputs), 'duration_ms': duration_ms},
|
||||
)
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
node_state.retry_count = attempt + 1
|
||||
debug_state.add_log(
|
||||
'warning', f'Node execution failed (attempt {attempt + 1}/{max_retries + 1}): {e}', node_id=node.id
|
||||
)
|
||||
|
||||
if attempt < max_retries:
|
||||
await asyncio.sleep(1)
|
||||
else:
|
||||
node_state.status = NodeStatus.FAILED
|
||||
node_state.error = str(e)
|
||||
node_state.end_time = datetime.now()
|
||||
debug_state.add_log(
|
||||
'error',
|
||||
f'Node failed after {max_retries + 1} attempts: {e}',
|
||||
node_id=node.id,
|
||||
data={'error': str(e), 'traceback': traceback.format_exc()},
|
||||
)
|
||||
|
||||
self._record_execution_step(node, node_state, context)
|
||||
await self._persist_node_execution(node, node_state, context)
|
||||
|
||||
async def step_execute(
|
||||
self,
|
||||
workflow: WorkflowDefinition,
|
||||
context: ExecutionContext,
|
||||
debug_state: DebugExecutionState,
|
||||
) -> dict:
|
||||
"""
|
||||
Execute one step (one node) in debug mode.
|
||||
|
||||
Returns:
|
||||
Dict with node_id, node_state, and completed status
|
||||
"""
|
||||
# Find next node to execute
|
||||
next_node = self._find_next_executable_node(workflow, context)
|
||||
|
||||
if not next_node:
|
||||
debug_state.status = 'completed'
|
||||
return {'completed': True}
|
||||
|
||||
# Execute single node
|
||||
debug_state.current_node_id = next_node.id
|
||||
await self._execute_debug_node(next_node, context, debug_state, workflow.settings.max_retries)
|
||||
|
||||
node_state = context.node_states.get(next_node.id)
|
||||
|
||||
# Check if workflow is complete
|
||||
all_done = all(
|
||||
state.status in (NodeStatus.COMPLETED, NodeStatus.SKIPPED, NodeStatus.FAILED)
|
||||
for state in context.node_states.values()
|
||||
)
|
||||
|
||||
if all_done:
|
||||
debug_state.status = 'completed'
|
||||
context.status = ExecutionStatus.COMPLETED
|
||||
|
||||
return {
|
||||
'node_id': next_node.id,
|
||||
'node_state': {
|
||||
'status': node_state.status.value if node_state else 'unknown',
|
||||
'inputs': self._safe_serialize(node_state.inputs) if node_state else {},
|
||||
'outputs': self._safe_serialize(node_state.outputs) if node_state else {},
|
||||
'error': node_state.error if node_state else None,
|
||||
},
|
||||
'completed': all_done,
|
||||
}
|
||||
|
||||
def _find_next_executable_node(
|
||||
self, workflow: WorkflowDefinition, context: ExecutionContext
|
||||
) -> Optional[NodeDefinition]:
|
||||
"""Find the next node that can be executed"""
|
||||
edge_map = self._build_edge_map(workflow.edges)
|
||||
|
||||
for node in workflow.nodes:
|
||||
state = context.node_states.get(node.id)
|
||||
|
||||
# Skip completed, running, or failed nodes
|
||||
if state and state.status in (
|
||||
NodeStatus.COMPLETED,
|
||||
NodeStatus.RUNNING,
|
||||
NodeStatus.FAILED,
|
||||
NodeStatus.SKIPPED,
|
||||
):
|
||||
continue
|
||||
|
||||
# Check if this node's inputs are ready
|
||||
incoming_nodes = set()
|
||||
for source_id, edges in edge_map.items():
|
||||
for edge in edges:
|
||||
if edge.target_node == node.id:
|
||||
incoming_nodes.add(source_id)
|
||||
|
||||
# If no incoming nodes, it's a start node
|
||||
if not incoming_nodes:
|
||||
return node
|
||||
|
||||
# Check if all incoming nodes are done
|
||||
all_incoming_done = True
|
||||
for source_id in incoming_nodes:
|
||||
source_state = context.node_states.get(source_id)
|
||||
if not source_state or source_state.status not in (NodeStatus.COMPLETED, NodeStatus.SKIPPED):
|
||||
all_incoming_done = False
|
||||
break
|
||||
|
||||
if all_incoming_done:
|
||||
return node
|
||||
|
||||
return None
|
||||
|
||||
def _safe_serialize(self, data: Any) -> Any:
|
||||
"""Safely serialize data for logging"""
|
||||
if data is None:
|
||||
return None
|
||||
if isinstance(data, (str, int, float, bool)):
|
||||
return data
|
||||
if isinstance(data, (list, tuple)):
|
||||
return [self._safe_serialize(item) for item in data[:100]] # Limit list size
|
||||
if isinstance(data, dict):
|
||||
result = {}
|
||||
for key, value in list(data.items())[:50]: # Limit dict size
|
||||
result[str(key)] = self._safe_serialize(value)
|
||||
return result
|
||||
# For complex objects, try to convert to string
|
||||
try:
|
||||
return str(data)[:1000] # Limit string length
|
||||
except Exception:
|
||||
return '<non-serializable>'
|
||||
|
||||
def get_execution_state(self, context: ExecutionContext, debug_state: DebugExecutionState) -> dict:
|
||||
"""Get current execution state for API response"""
|
||||
node_states = {}
|
||||
for node_id, state in context.node_states.items():
|
||||
node_states[node_id] = {
|
||||
'status': state.status.value,
|
||||
'inputs': self._safe_serialize(state.inputs),
|
||||
'outputs': self._safe_serialize(state.outputs),
|
||||
'error': state.error,
|
||||
'startTime': state.start_time.isoformat() if state.start_time else None,
|
||||
'endTime': state.end_time.isoformat() if state.end_time else None,
|
||||
'duration': int((state.end_time - state.start_time).total_seconds() * 1000)
|
||||
if state.start_time and state.end_time
|
||||
else None,
|
||||
}
|
||||
|
||||
return {
|
||||
'status': debug_state.status,
|
||||
'current_node_id': debug_state.current_node_id,
|
||||
'node_states': node_states,
|
||||
'new_logs': debug_state.get_pending_logs(),
|
||||
'error': context.error,
|
||||
}
|
||||
@@ -1,4 +1,12 @@
|
||||
"""Workflow execution engine"""
|
||||
"""Workflow execution engine.
|
||||
|
||||
This module contains the core workflow execution logic:
|
||||
- WorkflowExecutor: Main execution engine with control flow handling
|
||||
- ParallelExecutor: Parallel branch execution
|
||||
- LoopExecutor: Loop/iterator execution
|
||||
|
||||
Debug execution support has been moved to the ``debug`` module.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -6,7 +14,6 @@ import ast
|
||||
import asyncio
|
||||
import logging
|
||||
import operator
|
||||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, TYPE_CHECKING
|
||||
@@ -32,92 +39,6 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExecutionLog:
|
||||
"""Execution log entry"""
|
||||
|
||||
def __init__(self, level: str, message: str, node_id: Optional[str] = None, data: Optional[dict] = None):
|
||||
self.id = str(uuid.uuid4())
|
||||
self.timestamp = datetime.now().isoformat()
|
||||
self.level = level
|
||||
self.message = message
|
||||
self.node_id = node_id
|
||||
self.data = data or {}
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
'id': self.id,
|
||||
'timestamp': self.timestamp,
|
||||
'level': self.level,
|
||||
'message': self.message,
|
||||
'node_id': self.node_id,
|
||||
'data': self.data,
|
||||
}
|
||||
|
||||
|
||||
class DebugExecutionState:
|
||||
"""State for a debug execution"""
|
||||
|
||||
def __init__(self, execution_id: str, breakpoints: list[str] = None):
|
||||
self.execution_id = execution_id
|
||||
self.status: str = 'running'
|
||||
self.is_paused: bool = False
|
||||
self.is_stopped: bool = False
|
||||
self.current_node_id: Optional[str] = None
|
||||
self.breakpoints: set[str] = set(breakpoints or [])
|
||||
self.logs: list[ExecutionLog] = []
|
||||
self.pending_logs: list[ExecutionLog] = []
|
||||
self._pause_event = asyncio.Event()
|
||||
self._pause_event.set() # Initially not paused
|
||||
self._stop_event = asyncio.Event()
|
||||
|
||||
def add_log(self, level: str, message: str, node_id: str = None, data: dict = None):
|
||||
"""Add a log entry"""
|
||||
log = ExecutionLog(level, message, node_id, data)
|
||||
self.logs.append(log)
|
||||
self.pending_logs.append(log)
|
||||
logger.log(
|
||||
getattr(logging, level.upper(), logging.INFO),
|
||||
f'[Workflow Debug] {message}',
|
||||
extra={'node_id': node_id, 'data': data},
|
||||
)
|
||||
|
||||
def get_pending_logs(self) -> list[dict]:
|
||||
"""Get and clear pending logs"""
|
||||
logs = [log.to_dict() for log in self.pending_logs]
|
||||
self.pending_logs = []
|
||||
return logs
|
||||
|
||||
def pause(self):
|
||||
"""Pause execution"""
|
||||
self.is_paused = True
|
||||
self._pause_event.clear()
|
||||
self.add_log('info', 'Execution paused')
|
||||
|
||||
def resume(self):
|
||||
"""Resume execution"""
|
||||
self.is_paused = False
|
||||
self._pause_event.set()
|
||||
self.add_log('info', 'Execution resumed')
|
||||
|
||||
def stop(self):
|
||||
"""Stop execution"""
|
||||
self.is_stopped = True
|
||||
self.status = 'cancelled'
|
||||
self._stop_event.set()
|
||||
self._pause_event.set() # Release any pause
|
||||
self.add_log('info', 'Execution stopped')
|
||||
|
||||
async def wait_if_paused(self):
|
||||
"""Wait if execution is paused"""
|
||||
if self.is_paused:
|
||||
self.add_log('info', 'Waiting for resume...')
|
||||
await self._pause_event.wait()
|
||||
|
||||
def check_breakpoint(self, node_id: str) -> bool:
|
||||
"""Check if there's a breakpoint at the given node"""
|
||||
return node_id in self.breakpoints
|
||||
|
||||
|
||||
# ─── Safe expression evaluator (replaces eval()) ─────────────────────
|
||||
# Uses Python's ast module to whitelist only comparison / boolean / arithmetic
|
||||
# operations. No function calls, attribute access, or subscript injection.
|
||||
@@ -465,9 +386,30 @@ class WorkflowExecutor:
|
||||
await self._persist_node_execution(node, node_state, context)
|
||||
return
|
||||
|
||||
# Check if node supports streaming (has execute_stream method and stream config is enabled)
|
||||
use_streaming = hasattr(node_instance, 'execute_stream') and node.config.get('stream', False)
|
||||
|
||||
# Execute with retries
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
if use_streaming:
|
||||
# Streaming execution with aggregation and timeout
|
||||
aggregated_response = ''
|
||||
try:
|
||||
async with asyncio.timeout(300): # 5 minute timeout for streaming
|
||||
async for chunk in node_instance.execute_stream(inputs, context):
|
||||
if chunk:
|
||||
aggregated_response += chunk
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f'Node {node.id} ({node.type}) streaming timed out, falling back to non-streaming')
|
||||
use_streaming = False
|
||||
outputs = await node_instance.execute(inputs, context)
|
||||
else:
|
||||
# Get response from context if set by execute_stream, otherwise use aggregated
|
||||
final_response = context.variables.pop('_last_llm_response', aggregated_response)
|
||||
outputs = {'response': final_response, 'usage': {'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0}}
|
||||
logger.info(f'Node {node.id} ({node.type}) streaming completed, response length: {len(final_response)}')
|
||||
else:
|
||||
outputs = await node_instance.execute(inputs, context)
|
||||
node_state.outputs = outputs
|
||||
node_state.status = NodeStatus.COMPLETED
|
||||
@@ -516,9 +458,25 @@ class WorkflowExecutor:
|
||||
|
||||
# Get inputs from message context
|
||||
if context.message_context:
|
||||
inputs['message'] = context.message_context.message_content
|
||||
inputs['message_content'] = context.message_context.message_content
|
||||
inputs['sender_id'] = context.message_context.sender_id
|
||||
inputs['platform'] = context.message_context.platform
|
||||
else:
|
||||
logger.warning(
|
||||
f'[_resolve_inputs] node={node.id} ({node.type}): message_context is None!',
|
||||
extra={
|
||||
'node_id': node.id,
|
||||
'node_type': node.type,
|
||||
'execution_id': context.execution_id,
|
||||
'variables_keys': list(context.variables.keys()) if context.variables else [],
|
||||
},
|
||||
)
|
||||
|
||||
# Log current inputs state after message_context processing
|
||||
logger.debug(
|
||||
f'[_resolve_inputs] node={node.id} after message_context: {list(inputs.keys())}',
|
||||
)
|
||||
|
||||
# Get inputs from node config that reference other nodes
|
||||
for key, value in node.config.items():
|
||||
@@ -549,6 +507,22 @@ class WorkflowExecutor:
|
||||
# Last resort: use the first available output
|
||||
inputs[target_port] = next(iter(source_state.outputs.values()))
|
||||
|
||||
# Smart input mapping: if a node needs 'message' but received a different
|
||||
# port name (e.g., 'content' from llm_call), copy the value to 'message'.
|
||||
# This handles edge connection mismatches where the sender uses a different
|
||||
# port name than what the receiver expects.
|
||||
if 'message' not in inputs or inputs.get('message') is None:
|
||||
for fallback_key in ('content', 'response', 'input', 'output', 'result', 'text'):
|
||||
if fallback_key in inputs and inputs[fallback_key] is not None:
|
||||
inputs['message'] = inputs[fallback_key]
|
||||
logger.debug(
|
||||
f'[_resolve_inputs] node={node.id}: mapped {fallback_key} -> message',
|
||||
)
|
||||
break
|
||||
|
||||
logger.debug(
|
||||
f'[_resolve_inputs] node={node.id} final inputs keys: {list(inputs.keys())}, message={repr(inputs.get("message", "<missing>")[:100] if isinstance(inputs.get("message"), str) else inputs.get("message"))}',
|
||||
)
|
||||
return inputs
|
||||
|
||||
async def _resolve_expression(self, expression: str, context: ExecutionContext) -> Any:
|
||||
@@ -851,392 +825,3 @@ class LoopExecutor:
|
||||
return results
|
||||
|
||||
|
||||
class DebugWorkflowExecutor(WorkflowExecutor):
|
||||
"""
|
||||
Debug-enabled workflow executor with step-by-step execution support.
|
||||
Extends WorkflowExecutor with debugging capabilities.
|
||||
"""
|
||||
|
||||
# Class-level storage for active debug sessions
|
||||
_debug_states: dict[str, DebugExecutionState] = {}
|
||||
|
||||
def __init__(self, ap: Optional['app.Application'] = None):
|
||||
super().__init__(ap)
|
||||
|
||||
@classmethod
|
||||
def get_debug_state(cls, execution_id: str) -> Optional[DebugExecutionState]:
|
||||
"""Get debug state for an execution"""
|
||||
return cls._debug_states.get(execution_id)
|
||||
|
||||
@classmethod
|
||||
def create_debug_state(cls, execution_id: str, breakpoints: list[str] = None) -> DebugExecutionState:
|
||||
"""Create a new debug state"""
|
||||
state = DebugExecutionState(execution_id, breakpoints)
|
||||
cls._debug_states[execution_id] = state
|
||||
return state
|
||||
|
||||
@classmethod
|
||||
def remove_debug_state(cls, execution_id: str):
|
||||
"""Remove debug state for an execution"""
|
||||
cls._debug_states.pop(execution_id, None)
|
||||
|
||||
async def execute_debug(
|
||||
self,
|
||||
workflow: WorkflowDefinition,
|
||||
context: ExecutionContext,
|
||||
debug_state: DebugExecutionState,
|
||||
) -> ExecutionContext:
|
||||
"""
|
||||
Execute a workflow in debug mode.
|
||||
|
||||
Args:
|
||||
workflow: Workflow definition
|
||||
context: Execution context
|
||||
debug_state: Debug execution state
|
||||
|
||||
Returns:
|
||||
Updated execution context
|
||||
"""
|
||||
context.status = ExecutionStatus.RUNNING
|
||||
context.start_time = datetime.now()
|
||||
debug_state.add_log('info', f'Starting debug execution for workflow: {workflow.name}')
|
||||
|
||||
try:
|
||||
# Build execution graph
|
||||
node_map = {node.id: node for node in workflow.nodes}
|
||||
edge_map = self._build_edge_map(workflow.edges)
|
||||
self._edges = workflow.edges
|
||||
|
||||
# Initialize node states
|
||||
for node in workflow.nodes:
|
||||
if node.id not in context.node_states:
|
||||
context.node_states[node.id] = NodeState(node_id=node.id)
|
||||
|
||||
# Find start node(s)
|
||||
start_nodes = self._find_start_nodes(workflow.nodes, workflow.edges)
|
||||
|
||||
if not start_nodes:
|
||||
raise ValueError('No start nodes found in workflow')
|
||||
|
||||
debug_state.add_log('info', f'Found {len(start_nodes)} start node(s)')
|
||||
|
||||
# Execute from start nodes
|
||||
for start_node in start_nodes:
|
||||
if debug_state.is_stopped:
|
||||
break
|
||||
|
||||
await self._execute_debug_from_node(
|
||||
start_node, node_map, edge_map, context, debug_state, workflow.settings.max_retries
|
||||
)
|
||||
|
||||
# Set final status
|
||||
if debug_state.is_stopped:
|
||||
context.status = ExecutionStatus.CANCELLED
|
||||
debug_state.status = 'cancelled'
|
||||
else:
|
||||
all_completed = all(
|
||||
state.status in (NodeStatus.COMPLETED, NodeStatus.SKIPPED) for state in context.node_states.values()
|
||||
)
|
||||
|
||||
if all_completed:
|
||||
context.status = ExecutionStatus.COMPLETED
|
||||
debug_state.status = 'completed'
|
||||
debug_state.add_log('info', 'Workflow execution completed successfully')
|
||||
else:
|
||||
has_failed = any(state.status == NodeStatus.FAILED for state in context.node_states.values())
|
||||
if has_failed:
|
||||
context.status = ExecutionStatus.FAILED
|
||||
debug_state.status = 'error'
|
||||
|
||||
except Exception as e:
|
||||
context.status = ExecutionStatus.FAILED
|
||||
context.error = str(e)
|
||||
debug_state.status = 'error'
|
||||
debug_state.add_log('error', f'Workflow execution failed: {e}', data={'traceback': traceback.format_exc()})
|
||||
logger.error(f'Debug workflow execution failed: {e}\n{traceback.format_exc()}')
|
||||
|
||||
finally:
|
||||
context.end_time = datetime.now()
|
||||
|
||||
return context
|
||||
|
||||
async def _execute_debug_from_node(
|
||||
self,
|
||||
node: NodeDefinition,
|
||||
node_map: dict[str, NodeDefinition],
|
||||
edge_map: dict[str, list[EdgeDefinition]],
|
||||
context: ExecutionContext,
|
||||
debug_state: DebugExecutionState,
|
||||
max_retries: int = 3,
|
||||
):
|
||||
"""Execute workflow from a node with debug support"""
|
||||
|
||||
# Check if stopped
|
||||
if debug_state.is_stopped:
|
||||
return
|
||||
|
||||
# Wait if paused
|
||||
await debug_state.wait_if_paused()
|
||||
|
||||
# Check if should skip
|
||||
if await self._should_skip_node(node, context):
|
||||
if context.node_states[node.id].status == NodeStatus.SKIPPED:
|
||||
debug_state.add_log('info', f'Skipping node: {node.id}', node_id=node.id)
|
||||
return
|
||||
|
||||
# Check breakpoint
|
||||
if debug_state.check_breakpoint(node.id):
|
||||
debug_state.add_log('info', f'Hit breakpoint at node: {node.id}', node_id=node.id)
|
||||
debug_state.pause()
|
||||
await debug_state.wait_if_paused()
|
||||
|
||||
# Update current node
|
||||
debug_state.current_node_id = node.id
|
||||
debug_state.add_log('info', f'Executing node: {node.id} ({node.type})', node_id=node.id)
|
||||
|
||||
# Execute node
|
||||
await self._execute_debug_node(node, context, debug_state, max_retries)
|
||||
|
||||
# Check if stopped or failed
|
||||
if debug_state.is_stopped:
|
||||
return
|
||||
if context.node_states[node.id].status == NodeStatus.FAILED:
|
||||
return
|
||||
|
||||
# Get outgoing edges
|
||||
outgoing_edges = edge_map.get(node.id, [])
|
||||
|
||||
# Execute next nodes
|
||||
for edge in outgoing_edges:
|
||||
if debug_state.is_stopped:
|
||||
break
|
||||
|
||||
target_node = node_map.get(edge.target_node)
|
||||
if not target_node:
|
||||
continue
|
||||
|
||||
# Check edge condition
|
||||
if edge.condition:
|
||||
condition_met = await self._evaluate_condition(edge.condition, context)
|
||||
if not condition_met:
|
||||
debug_state.add_log('debug', f'Edge condition not met: {edge.condition}', node_id=node.id)
|
||||
continue
|
||||
|
||||
# Check if all inputs are ready
|
||||
if await self._inputs_ready(target_node, edge_map, context):
|
||||
await self._execute_debug_from_node(target_node, node_map, edge_map, context, debug_state, max_retries)
|
||||
|
||||
async def _execute_debug_node(
|
||||
self, node: NodeDefinition, context: ExecutionContext, debug_state: DebugExecutionState, max_retries: int = 3
|
||||
):
|
||||
"""Execute a single node with debug logging"""
|
||||
|
||||
node_state = context.node_states[node.id]
|
||||
node_state.status = NodeStatus.RUNNING
|
||||
node_state.start_time = datetime.now()
|
||||
|
||||
# Get node instance (pass ap for access to services)
|
||||
node_instance = self.registry.create_instance(node.type, node.id, node.config, ap=self.ap)
|
||||
|
||||
if not node_instance:
|
||||
node_state.status = NodeStatus.FAILED
|
||||
node_state.error = f'Unknown node type: {node.type}'
|
||||
node_state.end_time = datetime.now()
|
||||
debug_state.add_log('error', f'Unknown node type: {node.type}', node_id=node.id)
|
||||
self._record_execution_step(node, node_state, context)
|
||||
await self._persist_node_execution(node, node_state, context)
|
||||
return
|
||||
|
||||
# Resolve inputs
|
||||
inputs = await self._resolve_inputs(node, context)
|
||||
node_state.inputs = inputs
|
||||
debug_state.add_log(
|
||||
'debug', 'Node inputs resolved', node_id=node.id, data={'inputs': self._safe_serialize(inputs)}
|
||||
)
|
||||
|
||||
# Validate inputs
|
||||
validation_errors = await node_instance.validate_inputs(inputs)
|
||||
if validation_errors:
|
||||
node_state.status = NodeStatus.FAILED
|
||||
node_state.error = '; '.join(validation_errors)
|
||||
node_state.end_time = datetime.now()
|
||||
debug_state.add_log('error', f'Input validation failed: {node_state.error}', node_id=node.id)
|
||||
self._record_execution_step(node, node_state, context)
|
||||
await self._persist_node_execution(node, node_state, context)
|
||||
return
|
||||
|
||||
# Execute with retries
|
||||
for attempt in range(max_retries + 1):
|
||||
if debug_state.is_stopped:
|
||||
node_state.status = NodeStatus.FAILED
|
||||
node_state.error = 'Execution stopped'
|
||||
node_state.end_time = datetime.now()
|
||||
break
|
||||
|
||||
try:
|
||||
outputs = await node_instance.execute(inputs, context)
|
||||
node_state.outputs = outputs
|
||||
node_state.status = NodeStatus.COMPLETED
|
||||
node_state.end_time = datetime.now()
|
||||
|
||||
duration_ms = int((node_state.end_time - node_state.start_time).total_seconds() * 1000)
|
||||
debug_state.add_log(
|
||||
'info',
|
||||
f'Node completed in {duration_ms}ms',
|
||||
node_id=node.id,
|
||||
data={'outputs': self._safe_serialize(outputs), 'duration_ms': duration_ms},
|
||||
)
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
node_state.retry_count = attempt + 1
|
||||
debug_state.add_log(
|
||||
'warning', f'Node execution failed (attempt {attempt + 1}/{max_retries + 1}): {e}', node_id=node.id
|
||||
)
|
||||
|
||||
if attempt < max_retries:
|
||||
await asyncio.sleep(1)
|
||||
else:
|
||||
node_state.status = NodeStatus.FAILED
|
||||
node_state.error = str(e)
|
||||
node_state.end_time = datetime.now()
|
||||
debug_state.add_log(
|
||||
'error',
|
||||
f'Node failed after {max_retries + 1} attempts: {e}',
|
||||
node_id=node.id,
|
||||
data={'error': str(e), 'traceback': traceback.format_exc()},
|
||||
)
|
||||
|
||||
self._record_execution_step(node, node_state, context)
|
||||
await self._persist_node_execution(node, node_state, context)
|
||||
|
||||
async def step_execute(
|
||||
self,
|
||||
workflow: WorkflowDefinition,
|
||||
context: ExecutionContext,
|
||||
debug_state: DebugExecutionState,
|
||||
) -> dict:
|
||||
"""
|
||||
Execute one step (one node) in debug mode.
|
||||
|
||||
Returns:
|
||||
Dict with node_id, node_state, and completed status
|
||||
"""
|
||||
# Find next node to execute
|
||||
next_node = self._find_next_executable_node(workflow, context)
|
||||
|
||||
if not next_node:
|
||||
debug_state.status = 'completed'
|
||||
return {'completed': True}
|
||||
|
||||
# Execute single node
|
||||
debug_state.current_node_id = next_node.id
|
||||
await self._execute_debug_node(next_node, context, debug_state, workflow.settings.max_retries)
|
||||
|
||||
node_state = context.node_states.get(next_node.id)
|
||||
|
||||
# Check if workflow is complete
|
||||
all_done = all(
|
||||
state.status in (NodeStatus.COMPLETED, NodeStatus.SKIPPED, NodeStatus.FAILED)
|
||||
for state in context.node_states.values()
|
||||
)
|
||||
|
||||
if all_done:
|
||||
debug_state.status = 'completed'
|
||||
context.status = ExecutionStatus.COMPLETED
|
||||
|
||||
return {
|
||||
'node_id': next_node.id,
|
||||
'node_state': {
|
||||
'status': node_state.status.value if node_state else 'unknown',
|
||||
'inputs': self._safe_serialize(node_state.inputs) if node_state else {},
|
||||
'outputs': self._safe_serialize(node_state.outputs) if node_state else {},
|
||||
'error': node_state.error if node_state else None,
|
||||
},
|
||||
'completed': all_done,
|
||||
}
|
||||
|
||||
def _find_next_executable_node(
|
||||
self, workflow: WorkflowDefinition, context: ExecutionContext
|
||||
) -> Optional[NodeDefinition]:
|
||||
"""Find the next node that can be executed"""
|
||||
edge_map = self._build_edge_map(workflow.edges)
|
||||
|
||||
for node in workflow.nodes:
|
||||
state = context.node_states.get(node.id)
|
||||
|
||||
# Skip completed, running, or failed nodes
|
||||
if state and state.status in (
|
||||
NodeStatus.COMPLETED,
|
||||
NodeStatus.RUNNING,
|
||||
NodeStatus.FAILED,
|
||||
NodeStatus.SKIPPED,
|
||||
):
|
||||
continue
|
||||
|
||||
# Check if this node's inputs are ready
|
||||
incoming_nodes = set()
|
||||
for source_id, edges in edge_map.items():
|
||||
for edge in edges:
|
||||
if edge.target_node == node.id:
|
||||
incoming_nodes.add(source_id)
|
||||
|
||||
# If no incoming nodes, it's a start node
|
||||
if not incoming_nodes:
|
||||
return node
|
||||
|
||||
# Check if all incoming nodes are done
|
||||
all_incoming_done = True
|
||||
for source_id in incoming_nodes:
|
||||
source_state = context.node_states.get(source_id)
|
||||
if not source_state or source_state.status not in (NodeStatus.COMPLETED, NodeStatus.SKIPPED):
|
||||
all_incoming_done = False
|
||||
break
|
||||
|
||||
if all_incoming_done:
|
||||
return node
|
||||
|
||||
return None
|
||||
|
||||
def _safe_serialize(self, data: Any) -> Any:
|
||||
"""Safely serialize data for logging"""
|
||||
if data is None:
|
||||
return None
|
||||
if isinstance(data, (str, int, float, bool)):
|
||||
return data
|
||||
if isinstance(data, (list, tuple)):
|
||||
return [self._safe_serialize(item) for item in data[:100]] # Limit list size
|
||||
if isinstance(data, dict):
|
||||
result = {}
|
||||
for key, value in list(data.items())[:50]: # Limit dict size
|
||||
result[str(key)] = self._safe_serialize(value)
|
||||
return result
|
||||
# For complex objects, try to convert to string
|
||||
try:
|
||||
return str(data)[:1000] # Limit string length
|
||||
except Exception:
|
||||
return '<non-serializable>'
|
||||
|
||||
def get_execution_state(self, context: ExecutionContext, debug_state: DebugExecutionState) -> dict:
|
||||
"""Get current execution state for API response"""
|
||||
node_states = {}
|
||||
for node_id, state in context.node_states.items():
|
||||
node_states[node_id] = {
|
||||
'status': state.status.value,
|
||||
'inputs': self._safe_serialize(state.inputs),
|
||||
'outputs': self._safe_serialize(state.outputs),
|
||||
'error': state.error,
|
||||
'startTime': state.start_time.isoformat() if state.start_time else None,
|
||||
'endTime': state.end_time.isoformat() if state.end_time else None,
|
||||
'duration': int((state.end_time - state.start_time).total_seconds() * 1000)
|
||||
if state.start_time and state.end_time
|
||||
else None,
|
||||
}
|
||||
|
||||
return {
|
||||
'status': debug_state.status,
|
||||
'current_node_id': debug_state.current_node_id,
|
||||
'node_states': node_states,
|
||||
'new_logs': debug_state.get_pending_logs(),
|
||||
'error': context.error,
|
||||
}
|
||||
|
||||
@@ -2,17 +2,32 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||
|
||||
from ..entities import ExecutionContext
|
||||
from ..node import WorkflowNode, workflow_node
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Pre-compiled regex patterns for CoT content removal (performance optimization)
|
||||
_THINK_PATTERNS = [
|
||||
re.compile(r'<think>.*?</think>', re.DOTALL | re.IGNORECASE),
|
||||
re.compile(r'<thought>.*?</thought>', re.DOTALL | re.IGNORECASE),
|
||||
re.compile(r'<reasoning>.*?</reasoning>', re.DOTALL | re.IGNORECASE),
|
||||
re.compile(r'<\u601d\u8003>.*?</\u601d\u8003>', re.DOTALL | re.IGNORECASE),
|
||||
re.compile(r'<\u63a8\u7406>.*?</\u63a8\u7406>', re.DOTALL | re.IGNORECASE),
|
||||
]
|
||||
|
||||
# Template variable regex
|
||||
_TEMPLATE_VAR_RE = re.compile(r'\{\{([^}]+)\}\}')
|
||||
|
||||
|
||||
@workflow_node('llm_call')
|
||||
class LLMCallNode(WorkflowNode):
|
||||
"""LLM call node - invoke large language model"""
|
||||
@@ -21,6 +36,10 @@ class LLMCallNode(WorkflowNode):
|
||||
|
||||
def _resolve_template(self, template: str, inputs: dict[str, Any], context: ExecutionContext) -> str:
|
||||
"""Resolve {{variable}} placeholders in a template string."""
|
||||
if not template:
|
||||
return ''
|
||||
|
||||
unresolved_vars = []
|
||||
|
||||
def replacer(match: re.Match) -> str:
|
||||
expr = match.group(1).strip()
|
||||
@@ -35,9 +54,121 @@ class LLMCallNode(WorkflowNode):
|
||||
if expr.startswith('message.') and context.message_context:
|
||||
attr = expr[len('message.'):]
|
||||
return str(getattr(context.message_context, attr, ''))
|
||||
unresolved_vars.append(expr)
|
||||
return match.group(0) # leave unresolved
|
||||
|
||||
return re.sub(r'\{\{([^}]+)\}\}', replacer, template)
|
||||
result = _TEMPLATE_VAR_RE.sub(replacer, template)
|
||||
|
||||
# Log warning for unresolved variables
|
||||
if unresolved_vars:
|
||||
logger.warning(
|
||||
f'LLM call node {self.node_id}: unresolved template variables: {unresolved_vars}'
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _remove_think_content(self, text: str) -> str:
|
||||
"""Remove CoT (Chain of Thought) thinking content from response."""
|
||||
if not text:
|
||||
return text
|
||||
|
||||
result = text
|
||||
for pattern in _THINK_PATTERNS:
|
||||
result = pattern.sub('', result)
|
||||
|
||||
return result.strip()
|
||||
|
||||
def _apply_content_filter(self, text: str) -> tuple[str, bool, str]:
|
||||
"""Apply content safety filter to text.
|
||||
|
||||
Returns:
|
||||
(filtered_text, is_blocked, user_notice)
|
||||
"""
|
||||
if not text or not self.ap:
|
||||
return text, False, ''
|
||||
|
||||
# Check if content filter is enabled
|
||||
safety_config = getattr(self.ap, 'pipeline_cfg', None)
|
||||
if not safety_config:
|
||||
return text, False, ''
|
||||
|
||||
# Check sensitive words
|
||||
sensitive_words = []
|
||||
try:
|
||||
if hasattr(self.ap, 'sensitive_meta') and hasattr(self.ap.sensitive_meta, 'data'):
|
||||
sensitive_words = self.ap.sensitive_meta.data.get('words', [])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not sensitive_words:
|
||||
return text, False, ''
|
||||
|
||||
found = False
|
||||
filtered_text = text
|
||||
for word in sensitive_words:
|
||||
try:
|
||||
matches = re.findall(word, filtered_text, re.IGNORECASE)
|
||||
if matches:
|
||||
found = True
|
||||
mask_word = ''
|
||||
mask = '*'
|
||||
try:
|
||||
if hasattr(self.ap, 'sensitive_meta') and hasattr(self.ap.sensitive_meta, 'data'):
|
||||
mask_word = self.ap.sensitive_meta.data.get('mask_word', '')
|
||||
mask = self.ap.sensitive_meta.data.get('mask', '*')
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for m in matches:
|
||||
if mask_word:
|
||||
filtered_text = filtered_text.replace(m, mask_word)
|
||||
else:
|
||||
filtered_text = filtered_text.replace(m, mask * len(m))
|
||||
except re.error:
|
||||
# Invalid regex pattern, skip
|
||||
continue
|
||||
|
||||
if found:
|
||||
return filtered_text, False, '消息中存在不合适的内容, 请修改'
|
||||
|
||||
return text, False, ''
|
||||
|
||||
def _parse_tools_config(self, tools_config: Any) -> list[dict]:
|
||||
"""Parse tools configuration from YAML config format."""
|
||||
if not tools_config:
|
||||
return []
|
||||
|
||||
# If it's already a list, return as-is
|
||||
if isinstance(tools_config, list):
|
||||
return tools_config
|
||||
|
||||
# If it's a string, try to parse as JSON
|
||||
if isinstance(tools_config, str):
|
||||
tools_config = tools_config.strip()
|
||||
if not tools_config:
|
||||
return []
|
||||
try:
|
||||
parsed = json.loads(tools_config)
|
||||
if isinstance(parsed, list):
|
||||
return parsed
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f'Failed to parse tools config as JSON: {tools_config}')
|
||||
return []
|
||||
|
||||
return []
|
||||
|
||||
def _build_system_prompt_with_format(self, base_prompt: str, output_format: str, json_schema: str) -> str:
|
||||
"""Build system prompt with output format instructions."""
|
||||
prompt = base_prompt
|
||||
|
||||
if output_format == 'json':
|
||||
prompt += '\n\nPlease respond in valid JSON format.'
|
||||
if json_schema:
|
||||
prompt += f'\nFollow this JSON schema:\n{json_schema}'
|
||||
elif output_format == 'markdown':
|
||||
prompt += '\n\nPlease respond in Markdown format.'
|
||||
|
||||
return prompt
|
||||
|
||||
async def execute(self, inputs: dict[str, Any], context: ExecutionContext) -> dict[str, Any]:
|
||||
model_uuid = self.get_config('model', '')
|
||||
@@ -45,11 +176,32 @@ class LLMCallNode(WorkflowNode):
|
||||
raise ValueError('No model configured for LLM call node')
|
||||
|
||||
if not self.ap:
|
||||
raise RuntimeError('Application instance not available — cannot call LLM')
|
||||
raise RuntimeError('Application instance not available - cannot call LLM')
|
||||
|
||||
# Resolve prompts
|
||||
system_prompt = self._resolve_template(self.get_config('system_prompt', ''), inputs, context)
|
||||
user_prompt = self._resolve_template(self.get_config('user_prompt_template', '{{input}}'), inputs, context)
|
||||
# Get error handling config
|
||||
exception_handling = self.get_config('exception_handling', 'show-error')
|
||||
failure_hint = self.get_config('failure_hint', 'Request failed.')
|
||||
remove_think = self.get_config('remove_think', False)
|
||||
track_function_calls = self.get_config('track_function_calls', False)
|
||||
|
||||
# Get output format and json_schema config
|
||||
output_format = self.get_config('output_format', 'text')
|
||||
json_schema = self.get_config('json_schema', '')
|
||||
|
||||
# Get tools config
|
||||
enable_tools = self.get_config('enable_tools', False)
|
||||
tools_config = self.get_config('tools', [])
|
||||
|
||||
# Resolve prompts - handle null user_prompt_template
|
||||
system_prompt = self._resolve_template(self.get_config('system_prompt') or '', inputs, context)
|
||||
user_prompt_template = self.get_config('user_prompt_template')
|
||||
if user_prompt_template is None:
|
||||
# Default to input if not set
|
||||
user_prompt_template = '{{input}}'
|
||||
user_prompt = self._resolve_template(user_prompt_template, inputs, context)
|
||||
|
||||
# Build system prompt with format instructions
|
||||
system_prompt = self._build_system_prompt_with_format(system_prompt, output_format, json_schema)
|
||||
|
||||
# Build messages
|
||||
messages: list[provider_message.Message] = []
|
||||
@@ -69,30 +221,89 @@ class LLMCallNode(WorkflowNode):
|
||||
if max_tokens and int(max_tokens) > 0:
|
||||
extra_args['max_tokens'] = int(max_tokens)
|
||||
|
||||
# Invoke LLM
|
||||
# Build tools list if enabled
|
||||
funcs: list[resource_tool.LLMTool] | None = None
|
||||
if enable_tools and tools_config:
|
||||
try:
|
||||
tool_names = self._parse_tools_config(tools_config)
|
||||
if tool_names:
|
||||
all_tools = await self.ap.tool_mgr.get_tools()
|
||||
funcs = [t for t in all_tools if t.name in tool_names]
|
||||
if funcs:
|
||||
logger.info(f'LLM call node {self.node_id}: using tools: {[t.name for t in funcs]}')
|
||||
except Exception as e:
|
||||
logger.warning(f'LLM call node {self.node_id}: failed to load tools - {e}')
|
||||
funcs = None
|
||||
|
||||
# Invoke LLM with error handling
|
||||
logger.info(f'LLM call node {self.node_id}: invoking model {model_uuid}')
|
||||
try:
|
||||
result_message = await runtime_model.provider.invoke_llm(
|
||||
query=None,
|
||||
model=runtime_model,
|
||||
messages=messages,
|
||||
funcs=None,
|
||||
funcs=funcs,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f'LLM call node {self.node_id}: request failed - {e}')
|
||||
|
||||
# Handle based on exception handling strategy
|
||||
if exception_handling == 'show-error':
|
||||
raise
|
||||
elif exception_handling == 'show-hint':
|
||||
return {
|
||||
'response': failure_hint,
|
||||
'usage': {
|
||||
'prompt_tokens': 0,
|
||||
'completion_tokens': 0,
|
||||
'total_tokens': 0,
|
||||
},
|
||||
'error': str(e),
|
||||
'error_hint_shown': True,
|
||||
}
|
||||
else: # hide
|
||||
return {
|
||||
'response': '',
|
||||
'usage': {
|
||||
'prompt_tokens': 0,
|
||||
'completion_tokens': 0,
|
||||
'total_tokens': 0,
|
||||
},
|
||||
'error': str(e),
|
||||
}
|
||||
|
||||
# Extract response text
|
||||
response_text = ''
|
||||
if isinstance(result_message.content, str):
|
||||
response_text = result_message.content
|
||||
elif isinstance(result_message.content, list):
|
||||
# ContentElement list — concatenate text elements
|
||||
for elem in result_message.content:
|
||||
if hasattr(elem, 'text') and elem.text:
|
||||
response_text += elem.text
|
||||
elif isinstance(elem, str):
|
||||
response_text += elem
|
||||
|
||||
# Remove CoT (Chain of Thought) content if configured
|
||||
if remove_think:
|
||||
response_text = self._remove_think_content(response_text)
|
||||
|
||||
# Apply content safety filter
|
||||
response_text, is_blocked, filter_notice = self._apply_content_filter(response_text)
|
||||
if is_blocked:
|
||||
logger.warning(f'LLM call node {self.node_id}: response blocked by content filter - {filter_notice}')
|
||||
return {
|
||||
'response': filter_notice,
|
||||
'usage': usage,
|
||||
'blocked_by_filter': True,
|
||||
}
|
||||
|
||||
# Extract usage info if available
|
||||
usage = {'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0}
|
||||
usage = {
|
||||
'prompt_tokens': 0,
|
||||
'completion_tokens': 0,
|
||||
'total_tokens': 0,
|
||||
}
|
||||
if hasattr(result_message, 'usage') and result_message.usage:
|
||||
u = result_message.usage
|
||||
usage = {
|
||||
@@ -108,7 +319,136 @@ class LLMCallNode(WorkflowNode):
|
||||
'total_tokens': getattr(u, 'total_tokens', 0) or 0,
|
||||
}
|
||||
|
||||
return {
|
||||
result: dict[str, Any] = {
|
||||
'response': response_text,
|
||||
'usage': usage,
|
||||
}
|
||||
|
||||
# Parse JSON output if format is json
|
||||
if output_format == 'json' and response_text:
|
||||
try:
|
||||
result['parsed'] = json.loads(response_text)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f'LLM call node {self.node_id}: failed to parse JSON response - {e}')
|
||||
result['parsed'] = None
|
||||
result['parse_error'] = str(e)
|
||||
|
||||
# Add function call tracking info if configured
|
||||
if track_function_calls:
|
||||
result['function_calls'] = []
|
||||
|
||||
return result
|
||||
|
||||
async def execute_stream(
|
||||
self, inputs: dict[str, Any], context: ExecutionContext
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Execute the LLM call with streaming output.
|
||||
|
||||
Yields chunks of response text as they arrive.
|
||||
Falls back to non-streaming if streaming is not available.
|
||||
"""
|
||||
model_uuid = self.get_config('model', '')
|
||||
if not model_uuid:
|
||||
raise ValueError('No model configured for LLM call node')
|
||||
|
||||
if not self.ap:
|
||||
raise RuntimeError('Application instance not available - cannot call LLM')
|
||||
|
||||
remove_think = self.get_config('remove_think', False)
|
||||
exception_handling = self.get_config('exception_handling', 'show-error')
|
||||
failure_hint = self.get_config('failure_hint', 'Request failed.')
|
||||
|
||||
# Resolve prompts
|
||||
system_prompt = self._resolve_template(self.get_config('system_prompt') or '', inputs, context)
|
||||
user_prompt_template = self.get_config('user_prompt_template')
|
||||
if user_prompt_template is None:
|
||||
user_prompt_template = '{{input}}'
|
||||
user_prompt = self._resolve_template(user_prompt_template, inputs, context)
|
||||
|
||||
# Build messages
|
||||
messages: list[provider_message.Message] = []
|
||||
if system_prompt:
|
||||
messages.append(provider_message.Message(role='system', content=system_prompt))
|
||||
messages.append(provider_message.Message(role='user', content=user_prompt))
|
||||
|
||||
# Get model
|
||||
runtime_model = await self.ap.model_mgr.get_model_by_uuid(model_uuid)
|
||||
|
||||
# Build extra args
|
||||
extra_args: dict[str, Any] = {}
|
||||
temperature = self.get_config('temperature')
|
||||
if temperature is not None:
|
||||
extra_args['temperature'] = float(temperature)
|
||||
max_tokens = self.get_config('max_tokens', 0)
|
||||
if max_tokens and int(max_tokens) > 0:
|
||||
extra_args['max_tokens'] = int(max_tokens)
|
||||
|
||||
logger.info(f'LLM call node {self.node_id}: streaming model {model_uuid}')
|
||||
|
||||
try:
|
||||
# Try streaming first
|
||||
stream = runtime_model.provider.invoke_llm_stream(
|
||||
query=None,
|
||||
model=runtime_model,
|
||||
messages=messages,
|
||||
funcs=None,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
|
||||
full_response = ''
|
||||
async for chunk in stream:
|
||||
chunk_text = ''
|
||||
if hasattr(chunk, 'content'):
|
||||
if isinstance(chunk.content, str):
|
||||
chunk_text = chunk.content
|
||||
elif isinstance(chunk.content, list):
|
||||
for elem in chunk.content:
|
||||
if hasattr(elem, 'text') and elem.text:
|
||||
chunk_text += elem.text
|
||||
elif isinstance(elem, str):
|
||||
chunk_text += elem
|
||||
|
||||
if chunk_text:
|
||||
if remove_think:
|
||||
chunk_text = self._remove_think_content(chunk_text)
|
||||
full_response += chunk_text
|
||||
yield chunk_text
|
||||
|
||||
# Store in context for downstream nodes
|
||||
context.variables['_last_llm_response'] = full_response
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f'LLM call node {self.node_id}: streaming failed, falling back - {e}')
|
||||
# Fallback to non-streaming
|
||||
try:
|
||||
result_message = await runtime_model.provider.invoke_llm(
|
||||
query=None,
|
||||
model=runtime_model,
|
||||
messages=messages,
|
||||
funcs=None,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
response_text = self._extract_response_text(result_message)
|
||||
if remove_think:
|
||||
response_text = self._remove_think_content(response_text)
|
||||
yield response_text
|
||||
context.variables['_last_llm_response'] = response_text
|
||||
except Exception as e2:
|
||||
logger.error(f'LLM call node {self.node_id}: fallback also failed - {e2}')
|
||||
if exception_handling == 'show-hint':
|
||||
yield failure_hint
|
||||
elif exception_handling != 'hide':
|
||||
raise
|
||||
|
||||
def _extract_response_text(self, result_message: provider_message.Message) -> str:
|
||||
"""Extract response text from LLM result message."""
|
||||
response_text = ''
|
||||
if isinstance(result_message.content, str):
|
||||
response_text = result_message.content
|
||||
elif isinstance(result_message.content, list):
|
||||
for elem in result_message.content:
|
||||
if hasattr(elem, 'text') and elem.text:
|
||||
response_text += elem.text
|
||||
elif isinstance(elem, str):
|
||||
response_text += elem
|
||||
return response_text
|
||||
|
||||
@@ -153,3 +153,58 @@ config:
|
||||
description:
|
||||
en_US: Select tools that the model can use
|
||||
zh_Hans: 选择模型可以使用的工具
|
||||
|
||||
- name: exception_handling
|
||||
type: select
|
||||
required: true
|
||||
default: show-hint
|
||||
options:
|
||||
- name: show-error
|
||||
label:
|
||||
en_US: Show Full Error
|
||||
zh_Hans: 显示完整报错信息
|
||||
- name: show-hint
|
||||
label:
|
||||
en_US: Show Failure Hint
|
||||
zh_Hans: 仅文字提示
|
||||
- name: hide
|
||||
label:
|
||||
en_US: Hide All
|
||||
zh_Hans: 不显示任何异常信息
|
||||
label:
|
||||
en_US: Exception Handling Strategy
|
||||
zh_Hans: 异常处理策略
|
||||
description:
|
||||
en_US: Controls how error messages are displayed to the user when an AI request fails
|
||||
zh_Hans: 控制 AI 请求失败时向用户展示错误信息的方式
|
||||
|
||||
- name: failure_hint
|
||||
type: string
|
||||
required: false
|
||||
default: 'Request failed.'
|
||||
label:
|
||||
en_US: Failure Hint Text
|
||||
zh_Hans: 失败提示文本
|
||||
description:
|
||||
en_US: The text to display when a request fails. Only effective when Exception Handling Strategy is set to "Show Failure Hint"
|
||||
zh_Hans: 请求失败时显示的提示文本,仅在异常处理策略设置为"仅文字提示"时生效
|
||||
|
||||
- name: remove_think
|
||||
type: boolean
|
||||
default: false
|
||||
label:
|
||||
en_US: Remove CoT
|
||||
zh_Hans: 删除思维链
|
||||
description:
|
||||
en_US: 'If enabled, the model thinking content in the response will be automatically removed. Note: When using streaming response, removing CoT may cause the first token to wait for a long time.'
|
||||
zh_Hans: '如果启用,将自动删除大模型回复中的模型思考内容。注意:当您使用流式响应时,删除思维链可能会导致首个 Token 的等待时间过长'
|
||||
|
||||
- name: track_function_calls
|
||||
type: boolean
|
||||
default: false
|
||||
label:
|
||||
en_US: Track Function Calls
|
||||
zh_Hans: 跟踪函数调用
|
||||
description:
|
||||
en_US: If enabled, the Agent will output a hint to the user each time a tool is called
|
||||
zh_Hans: 启用后,Agent 每次调用工具时都会输出一个提示给用户
|
||||
|
||||
@@ -13,6 +13,7 @@ description:
|
||||
inputs:
|
||||
- name: message
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Message
|
||||
zh_Hans: 消息
|
||||
|
||||
@@ -447,7 +447,7 @@ function WorkflowEditorInner() {
|
||||
panOnDrag={[1, 2]} // Middle click and right click to pan
|
||||
selectNodesOnDrag={false}
|
||||
defaultEdgeOptions={{
|
||||
type: 'bezier',
|
||||
type: 'default',
|
||||
animated: true,
|
||||
markerEnd: {
|
||||
type: MarkerType.ArrowClosed,
|
||||
|
||||
@@ -257,7 +257,7 @@ export const useWorkflowStore = create<WorkflowState>((set, get) => ({
|
||||
const newEdge: WorkflowEdge = {
|
||||
...connection,
|
||||
id: generateEdgeId(),
|
||||
type: 'bezier',
|
||||
type: 'default',
|
||||
} as WorkflowEdge;
|
||||
|
||||
set((state) => ({
|
||||
@@ -464,7 +464,7 @@ export const useWorkflowStore = create<WorkflowState>((set, get) => ({
|
||||
target: edge.target,
|
||||
sourceHandle: edge.source_port,
|
||||
targetHandle: edge.target_port,
|
||||
type: 'bezier',
|
||||
type: 'default',
|
||||
data: {
|
||||
label: edge.label,
|
||||
condition: edge.condition,
|
||||
|
||||
Reference in New Issue
Block a user