Source code for litestar_workflows.steps.groups

"""Composable step groups for litestar-workflows."""

from __future__ import annotations

import asyncio
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from litestar_workflows.core import WorkflowContext
    from litestar_workflows.core.protocols import ExecutionEngine, Step


[docs] class StepGroup(ABC): """Base for composable step patterns. Step groups allow you to compose multiple steps into reusable patterns like sequences, parallel execution, and conditional branching. """
[docs] @abstractmethod async def execute(self, context: WorkflowContext, engine: ExecutionEngine) -> Any: """Execute the step group. Args: context: The workflow execution context. engine: The execution engine to delegate step execution. Returns: The result of the group execution. Raises: Exception: Any exception during group execution. """ ...
[docs] class SequentialGroup(StepGroup): """Execute steps in sequence, passing results. This implements the Chain pattern where each step receives the result of the previous step as input. The final step's result is returned. Example: >>> group = SequentialGroup(step1, step2, step3) >>> result = await group.execute(context, engine) # step1 -> step2(result1) -> step3(result2) -> result3 """
[docs] def __init__(self, *steps: Step[Any] | StepGroup) -> None: """Initialize a sequential group. Args: *steps: Steps or groups to execute in sequence. """ self.steps = steps
[docs] async def execute(self, context: WorkflowContext, engine: ExecutionEngine) -> Any: """Execute steps sequentially, passing results forward. Args: context: The workflow execution context. engine: The execution engine to delegate step execution. Returns: The result of the final step. Raises: Exception: Any exception from step execution. """ result: Any = None for step in self.steps: if isinstance(step, StepGroup): result = await step.execute(context, engine) else: result = await engine.execute_step(step, context, previous_result=result) return result
[docs] class ParallelGroup(StepGroup): """Execute steps in parallel. This implements the Group pattern where multiple steps execute concurrently. Optionally supports a callback step (Chord pattern) that receives all results. Example: >>> # Simple parallel execution >>> group = ParallelGroup(step1, step2, step3) >>> results = await group.execute(context, engine) # [result1, result2, result3] >>> # Chord pattern with callback >>> group = ParallelGroup(step1, step2, step3, callback=aggregate_step) >>> result = await group.execute(context, engine) # aggregate_step([r1, r2, r3]) """
[docs] def __init__( self, *steps: Step[Any] | StepGroup, callback: Step[Any] | None = None, ) -> None: """Initialize a parallel group. Args: *steps: Steps or groups to execute in parallel. callback: Optional callback step to process results (Chord pattern). """ self.steps = steps self.callback = callback
[docs] async def execute(self, context: WorkflowContext, engine: ExecutionEngine) -> list[Any] | Any: """Execute steps in parallel using asyncio.gather. Args: context: The workflow execution context. engine: The execution engine to delegate step execution. Returns: List of results if no callback, otherwise callback result. Raises: Exception: Any exception from step execution. """ tasks = [] for step in self.steps: if isinstance(step, StepGroup): tasks.append(step.execute(context, engine)) else: tasks.append(engine.execute_step(step, context)) results = await asyncio.gather(*tasks) if self.callback: return await engine.execute_step(self.callback, context, previous_result=results) return results
[docs] class ConditionalGroup(StepGroup): """Execute one of multiple branches based on condition. This implements the Gateway pattern where a condition function determines which branch to execute. Similar to if/else or switch statements. Example: >>> def check_status(ctx: WorkflowContext) -> str: ... return "approved" if ctx.get("approved") else "rejected" >>> group = ConditionalGroup( ... condition=check_status, ... branches={ ... "approved": approve_step, ... "rejected": reject_step, ... }, ... ) >>> result = await group.execute(context, engine) """
[docs] def __init__( self, condition: Callable[[WorkflowContext], str], branches: dict[str, Step[Any] | StepGroup], ) -> None: """Initialize a conditional group. Args: condition: Function that evaluates context and returns branch key. branches: Map of branch keys to steps or groups. """ self.condition = condition self.branches = branches
[docs] async def execute(self, context: WorkflowContext, engine: ExecutionEngine) -> Any: """Execute the branch selected by the condition. Args: context: The workflow execution context. engine: The execution engine to delegate step execution. Returns: The result of the selected branch, or None if no match. Raises: Exception: Any exception from step execution. """ branch_key = self.condition(context) if branch_key not in self.branches: return None branch = self.branches[branch_key] if isinstance(branch, StepGroup): return await branch.execute(context, engine) return await engine.execute_step(branch, context)