Source code for litestar_workflows.engine.graph

"""Workflow graph operations and navigation.

This module provides graph-based operations for workflow definitions,
including step navigation, validation, and path finding.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from litestar_workflows.core.context import WorkflowContext
    from litestar_workflows.core.definition import Edge, WorkflowDefinition

__all__ = ["WorkflowGraph"]


[docs] class WorkflowGraph: """Graph representation of a workflow for navigation and validation. The graph provides methods to navigate between steps, validate structure, and determine execution paths based on workflow state. Attributes: definition: The workflow definition this graph represents. _adjacency: Adjacency list mapping step names to outgoing edges. _reverse_adjacency: Reverse adjacency list for finding predecessors. """
[docs] def __init__(self, definition: WorkflowDefinition) -> None: """Initialize a workflow graph from a definition. Args: definition: The workflow definition to represent as a graph. """ self.definition = definition self._adjacency: dict[str, list[Edge]] = {} self._reverse_adjacency: dict[str, list[str]] = {} self._build_adjacency()
def _build_adjacency(self) -> None: """Build adjacency lists from workflow edges.""" # Initialize adjacency lists for all steps for step_name in self.definition.steps: self._adjacency[step_name] = [] self._reverse_adjacency[step_name] = [] # Build adjacency lists from edges for edge in self.definition.edges: source = edge.source if isinstance(edge.source, str) else edge.source.name target = edge.target if isinstance(edge.target, str) else edge.target.name self._adjacency[source].append(edge) if target not in self._reverse_adjacency: self._reverse_adjacency[target] = [] self._reverse_adjacency[target].append(source)
[docs] @classmethod def from_definition(cls, definition: WorkflowDefinition) -> WorkflowGraph: """Create a workflow graph from a definition. Args: definition: The workflow definition. Returns: A WorkflowGraph instance. Example: >>> graph = WorkflowGraph.from_definition(my_definition) """ return cls(definition)
[docs] def get_next_steps( self, current: str, context: WorkflowContext, ) -> list[str]: """Get the next steps from the current step. Evaluates edge conditions to determine which steps should execute next. If multiple unconditional edges exist, all targets are returned (parallel). Args: current: Name of the current step. context: The workflow context for condition evaluation. Returns: List of next step names. Empty list if current step is terminal. Example: >>> next_steps = graph.get_next_steps("approval", context) >>> if len(next_steps) > 1: ... print("Parallel execution required") """ if current not in self._adjacency: return [] next_steps = [] for edge in self._adjacency[current]: target = edge.target if isinstance(edge.target, str) else edge.target.name # If edge has a condition, evaluate it if edge.condition is not None: # Use the edge's evaluate_condition method for proper type handling if edge.evaluate_condition(context): next_steps.append(target) else: # Unconditional edge next_steps.append(target) return next_steps
[docs] def get_previous_steps(self, current: str) -> list[str]: """Get the steps that lead to the current step. Args: current: Name of the current step. Returns: List of predecessor step names. Example: >>> predecessors = graph.get_previous_steps("final_step") """ return self._reverse_adjacency.get(current, [])
[docs] def is_terminal(self, step_name: str) -> bool: """Check if a step is a terminal (end) step. Args: step_name: Name of the step to check. Returns: True if the step has no outgoing edges or is in terminal_steps. Example: >>> if graph.is_terminal("approval"): ... print("Workflow ends here") """ # Explicitly marked as terminal if step_name in self.definition.terminal_steps: return True # No outgoing edges return not self._adjacency.get(step_name, [])
[docs] def validate(self) -> list[str]: """Validate the workflow graph structure. Checks for common graph issues: - Unreachable steps (not connected to initial step) - Disconnected components - Steps with no outgoing edges that aren't marked terminal - Invalid edge references Returns: List of validation error messages. Empty list if valid. Example: >>> errors = graph.validate() >>> if errors: ... for error in errors: ... print(f"Validation error: {error}") """ errors = [] # Check that initial step exists if self.definition.initial_step not in self.definition.steps: errors.append(f"Initial step '{self.definition.initial_step}' not found in steps") # Check for invalid edge references for edge in self.definition.edges: source = edge.source if isinstance(edge.source, str) else edge.source.name target = edge.target if isinstance(edge.target, str) else edge.target.name if source not in self.definition.steps: errors.append(f"Edge source '{source}' not found in steps") if target not in self.definition.steps: errors.append(f"Edge target '{target}' not found in steps") # Check for unreachable steps reachable = self._get_reachable_steps() for step_name in self.definition.steps: if step_name not in reachable and step_name != self.definition.initial_step: errors.append(f"Step '{step_name}' is unreachable from initial step") # Check for steps with no outgoing edges that aren't terminal for step_name in self.definition.steps: if not self._adjacency.get(step_name) and step_name not in self.definition.terminal_steps: errors.append(f"Step '{step_name}' has no outgoing edges but is not marked as terminal") return errors
def _get_reachable_steps(self) -> set[str]: """Get all steps reachable from the initial step. Returns: Set of step names that can be reached. """ if self.definition.initial_step not in self.definition.steps: return set() reachable = set() to_visit = [self.definition.initial_step] while to_visit: current = to_visit.pop() if current in reachable: continue reachable.add(current) # Add all targets of outgoing edges for edge in self._adjacency.get(current, []): target = edge.target if isinstance(edge.target, str) else edge.target.name if target not in reachable: to_visit.append(target) return reachable
[docs] def get_all_paths( self, start: str, end: str, max_paths: int = 100, ) -> list[list[str]]: """Find all paths from start step to end step. Args: start: Starting step name. end: Ending step name. max_paths: Maximum number of paths to find (prevents infinite loops). Returns: List of paths, where each path is a list of step names. Example: >>> paths = graph.get_all_paths("start", "end") >>> for path in paths: ... print(" -> ".join(path)) """ paths: list[list[str]] = [] self._find_paths(start, end, [], paths, max_paths) return paths
def _find_paths( self, current: str, end: str, path: list[str], paths: list[list[str]], max_paths: int, ) -> None: """Recursive helper for finding all paths. Args: current: Current step being visited. end: Target step. path: Current path being explored. paths: Accumulated list of complete paths. max_paths: Maximum paths to find. """ # Prevent infinite loops and limit results if len(paths) >= max_paths or current in path: return path = [*path, current] if current == end: paths.append(path) return for edge in self._adjacency.get(current, []): target = edge.target if isinstance(edge.target, str) else edge.target.name self._find_paths(target, end, path, paths, max_paths)
[docs] def get_step_depth(self, step_name: str) -> int: """Get the minimum depth from initial step to the given step. Args: step_name: Name of the step. Returns: Minimum number of steps from initial to this step. Returns -1 if unreachable. Example: >>> depth = graph.get_step_depth("final_approval") >>> print(f"Step is at depth {depth}") """ if step_name == self.definition.initial_step: return 0 visited = {self.definition.initial_step: 0} queue = [self.definition.initial_step] while queue: current = queue.pop(0) current_depth = visited[current] for edge in self._adjacency.get(current, []): target = edge.target if isinstance(edge.target, str) else edge.target.name if target == step_name: return current_depth + 1 if target not in visited: visited[target] = current_depth + 1 queue.append(target) return -1 # Unreachable