# Source code for litestar_workflows.core.definition

"""Workflow definition and edge structures.

This module provides the data structures for defining workflow graphs, including
edges (transitions) and the complete workflow definition.
"""

from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from litestar_workflows.core.context import WorkflowContext
    from litestar_workflows.core.protocols import Step

# Public names exported by ``from litestar_workflows.core.definition import *``.
__all__ = ["Edge", "WorkflowDefinition"]


@dataclass
class Edge:
    """Defines a transition between workflow steps.

    An edge represents a directed connection from one step to another, optionally
    conditioned on a predicate function or expression.

    Attributes:
        source: Name of the source step or the Step class itself.
        target: Name of the target step or the Step class itself.
        condition: Optional condition for edge traversal. Can be a callable that takes
            WorkflowContext and returns bool, or a string expression. String expressions
            are currently not evaluated and always pass (see ``evaluate_condition``).

    Example:
        >>> edge = Edge(
        ...     source="submit",
        ...     target="review",
        ...     condition=lambda ctx: ctx.get("auto_approve") is False,
        ... )
        >>> conditional_edge = Edge(
        ...     source="review", target="approve", condition="context.get('approved') == True"
        ... )
    """

    source: str | type[Step]
    target: str | type[Step]
    condition: str | Callable[[WorkflowContext], bool] | None = None

    def evaluate_condition(self, context: WorkflowContext) -> bool:
        """Evaluate the edge condition against the workflow context.

        Args:
            context: The current workflow execution context. It is passed unchanged to a
                callable condition.

        Returns:
            ``True`` if there is no condition, if the condition is a string expression
            (not evaluated yet), or if the callable condition returns a truthy value.
            Otherwise ``False``. The result is always a real ``bool``.

        Example:
            Assuming ``context`` is a WorkflowContext whose ``"value"`` entry is 15:

            >>> edge = Edge(source="a", target="b", condition=lambda ctx: ctx.get("value") > 10)
            >>> edge.evaluate_condition(context)
            True
        """
        if self.condition is None:
            return True
        if callable(self.condition):
            # Coerce so callers always receive a bool, even from predicates that return a
            # merely truthy/falsy value such as ``ctx.get("flag")``.
            return bool(self.condition(context))
        # String expressions are not evaluated yet (planned enhancement), so they always pass.
        return True

    def get_source_name(self) -> str:
        """Get the name of the source step.

        Returns:
            The source step name as a string.
        """
        return self._step_name(self.source)

    def get_target_name(self) -> str:
        """Get the name of the target step.

        Returns:
            The target step name as a string.
        """
        return self._step_name(self.target)

    @staticmethod
    def _step_name(step: str | type[Step]) -> str:
        """Return ``step`` if it is already a name, else the step class's ``name`` attribute."""
        return step if isinstance(step, str) else step.name
@dataclass
class WorkflowDefinition:
    """Declarative workflow structure.

    The WorkflowDefinition captures the complete structure of a workflow including all
    steps, edges, and metadata. It serves as the blueprint for workflow execution.

    Attributes:
        name: Unique identifier for the workflow.
        version: Version string for workflow versioning.
        description: Human-readable description of the workflow's purpose.
        steps: Dictionary mapping step names to Step instances.
        edges: List of Edge instances defining the workflow graph.
        initial_step: Name of the step to execute first.
        terminal_steps: Set of step names that mark workflow completion.

    Example:
        >>> definition = WorkflowDefinition(
        ...     name="approval_flow",
        ...     version="1.0.0",
        ...     description="Document approval workflow",
        ...     steps={
        ...         "submit": SubmitStep(),
        ...         "review": ReviewStep(),
        ...         "approve": ApproveStep(),
        ...     },
        ...     edges=[
        ...         Edge(source="submit", target="review"),
        ...         Edge(source="review", target="approve"),
        ...     ],
        ...     initial_step="submit",
        ...     terminal_steps={"approve"},
        ... )
    """

    name: str
    version: str
    description: str
    steps: dict[str, Step[Any]]
    edges: list[Edge]
    initial_step: str
    terminal_steps: set[str] = field(default_factory=set)

    def validate(self) -> list[str]:
        """Validate the workflow definition for common issues.

        Checks, in this order:

        1. The initial step exists.
        2. Every terminal step exists.
        3. Both endpoints of every edge exist.
        4. Every non-terminal step is reachable from the initial step. Edge conditions
           are ignored here; any edge counts as a path.

        Returns:
            List of validation error messages, in a deterministic order for a given
            definition. Empty list if valid.

        Example:
            >>> errors = definition.validate()
            >>> if errors:
            ...     print("Validation errors:", errors)
        """
        errors: list[str] = []

        if self.initial_step not in self.steps:
            errors.append(f"Initial step '{self.initial_step}' not found in steps")

        # terminal_steps is a set, whose iteration order for strings depends on hash
        # randomization; sort it so identical definitions yield identical error lists.
        for terminal in sorted(self.terminal_steps):
            if terminal not in self.steps:
                errors.append(f"Terminal step '{terminal}' not found in steps")

        for i, edge in enumerate(self.edges):
            source_name = edge.get_source_name()
            target_name = edge.get_target_name()
            if source_name not in self.steps:
                errors.append(f"Edge {i}: source step '{source_name}' not found")
            if target_name not in self.steps:
                errors.append(f"Edge {i}: target step '{target_name}' not found")

        # Walk self.steps (insertion-ordered dict) instead of a set difference, so the
        # "unreachable" messages also come out in a stable order. Terminal steps are exempt.
        reachable = self._reachable_from(self.initial_step)
        errors.extend(
            f"Step '{step_name}' is unreachable from initial step"
            for step_name in self.steps
            if step_name not in reachable and step_name not in self.terminal_steps
        )
        return errors

    def _reachable_from(self, start: str) -> set[str]:
        """Return ``start`` plus every step name reachable from it via edges (conditions ignored)."""
        successors: dict[str, list[str]] = {}
        for edge in self.edges:
            successors.setdefault(edge.get_source_name(), []).append(edge.get_target_name())

        reachable = {start}
        pending = [start]
        while pending:
            current = pending.pop()
            for target in successors.get(current, ()):
                if target not in reachable:
                    reachable.add(target)
                    pending.append(target)
        return reachable

    def get_next_steps(self, current_step: str, context: WorkflowContext) -> list[str]:
        """Get the list of next steps from the current step based on edge conditions.

        Args:
            current_step: Name of the current step.
            context: The workflow execution context for condition evaluation.

        Returns:
            Target names of the outgoing edges of ``current_step`` whose condition passes,
            in edge-list order.

        Example:
            >>> next_steps = definition.get_next_steps("review", context)
            >>> if "approve" in next_steps:
            ...     print("Moving to approval")
        """
        return [
            edge.get_target_name()
            for edge in self.edges
            if edge.get_source_name() == current_step and edge.evaluate_condition(context)
        ]

    @staticmethod
    def _mermaid_node_shape(step: Any) -> tuple[str, str]:
        """Return the opening and closing Mermaid brackets used to draw ``step``."""
        if not hasattr(step, "step_type"):
            return "[", "]"
        # Imported inside the function, so the types module is only loaded when a step
        # actually declares a step_type.
        from litestar_workflows.core.types import StepType

        if step.step_type == StepType.HUMAN:
            return "{{", "}}"  # Mermaid hexagon
        if step.step_type == StepType.GATEWAY:
            return "{", "}"  # Mermaid rhombus (decision)
        if step.step_type == StepType.TIMER:
            # NOTE(review): "([[ ... ]])" is not one of Mermaid's documented node shapes
            # (stadium is "([...])", subroutine is "[[...]]"); confirm it renders as intended.
            return "([[", "]])"
        return "[", "]"

    def to_mermaid(self) -> str:
        """Generate a MermaidJS graph representation of the workflow.

        Each step becomes a node labelled with its title-cased name. The initial step is
        prefixed with ``START: `` and terminal steps with ``END: ``. The node shape depends
        on the step's ``step_type`` when it has one. String edge conditions are shown as
        sanitized edge labels; callable conditions are labelled ``conditional``.

        Returns:
            MermaidJS graph definition as a string.

        Example:
            For plain steps (no ``step_type``) chained submit -> review -> approve, with
            "submit" initial and "approve" terminal:

            >>> print(definition.to_mermaid())
            graph TD
             submit[START: Submit]
             review[Review]
             approve[END: Approve]
             submit --> review
             review --> approve
        """
        lines = ["graph TD"]

        for step_name, step in self.steps.items():
            shape_start, shape_end = self._mermaid_node_shape(step)
            if step_name == self.initial_step:
                prefix = "START: "
            elif step_name in self.terminal_steps:
                prefix = "END: "
            else:
                prefix = ""
            title = step_name.replace("_", " ").title()
            lines.append(f" {step_name}{shape_start}{prefix}{title}{shape_end}")

        # Quotes and bracket/pipe characters break Mermaid edge-label syntax and are deleted.
        # Line breaks become spaces because each Mermaid statement must stay on one line.
        label_table = str.maketrans("\r\n", "  ", "'\"()<>|[]{}")
        for edge in self.edges:
            if edge.condition is None:
                label = ""
            elif isinstance(edge.condition, str):
                label = f"|{edge.condition.translate(label_table)}|"
            else:
                label = "|conditional|"
            lines.append(f" {edge.get_source_name()} -->{label} {edge.get_target_name()}")

        return "\n".join(lines)

    def to_mermaid_with_state(
        self,
        current_step: str | None = None,
        completed_steps: list[str] | None = None,
        failed_steps: list[str] | None = None,
    ) -> str:
        """Generate a MermaidJS graph with execution state highlighting.

        The output of ``to_mermaid()`` is followed by ``style`` lines in this order:
        completed steps (green), failed steps (pink), then the current step (gold).

        Args:
            current_step: Name of the currently executing step. Skipped when falsy.
            completed_steps: List of successfully completed step names.
            failed_steps: List of failed step names.

        Returns:
            MermaidJS graph definition with state styling.

        Example:
            >>> mermaid = definition.to_mermaid_with_state(
            ...     current_step="review", completed_steps=["submit"], failed_steps=[]
            ... )
        """
        style_lines = [
            f" style {step_name} fill:#90EE90,stroke:#006400,stroke-width:2px"
            for step_name in completed_steps or []
        ]
        style_lines.extend(
            f" style {step_name} fill:#FFB6C1,stroke:#8B0000,stroke-width:2px"
            for step_name in failed_steps or []
        )
        if current_step:
            style_lines.append(f" style {current_step} fill:#FFD700,stroke:#FFA500,stroke-width:3px")
        return "\n".join([self.to_mermaid(), *style_lines])