# Source code for litestar_workflows.db.engine

"""Persistent execution engine with database storage.

This module provides a persistence-aware execution engine that stores
workflow state in a database using SQLAlchemy.
"""

from __future__ import annotations

import asyncio
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
from uuid import UUID, uuid4

from litestar_workflows.core.context import StepExecution, WorkflowContext
from litestar_workflows.core.models import WorkflowInstanceData
from litestar_workflows.core.types import StepStatus, StepType, WorkflowStatus
from litestar_workflows.db.models import (
    HumanTaskModel,
    StepExecutionModel,
    WorkflowDefinitionModel,
    WorkflowInstanceModel,
)
from litestar_workflows.db.repositories import (
    HumanTaskRepository,
    StepExecutionRepository,
    WorkflowDefinitionRepository,
    WorkflowInstanceRepository,
)
from litestar_workflows.engine.graph import WorkflowGraph

if TYPE_CHECKING:
    from sqlalchemy.ext.asyncio import AsyncSession

    from litestar_workflows.core.definition import WorkflowDefinition
    from litestar_workflows.core.protocols import Step, Workflow
    from litestar_workflows.engine.registry import WorkflowRegistry

__all__ = ["PersistentExecutionEngine"]


class PersistentExecutionEngine:
    """Execution engine with database persistence.

    This engine persists all workflow state to a database, enabling
    durability, recovery, and querying of workflow history.

    Attributes:
        registry: The workflow registry for looking up definitions.
        session: SQLAlchemy async session for database operations.
        event_bus: Optional event bus for emitting workflow events.
    """

    def __init__(
        self,
        registry: WorkflowRegistry,
        session: AsyncSession,
        event_bus: Any | None = None,
    ) -> None:
        """Initialize the persistent execution engine.

        Args:
            registry: The workflow registry.
            session: SQLAlchemy async session.
            event_bus: Optional event bus for events.
        """
        self.registry = registry
        self.session = session
        self.event_bus = event_bus

        # Initialize repositories (all share the same session)
        self._definition_repo = WorkflowDefinitionRepository(session=session)
        self._instance_repo = WorkflowInstanceRepository(session=session)
        self._step_repo = StepExecutionRepository(session=session)
        self._task_repo = HumanTaskRepository(session=session)

        # Track running background tasks in memory, keyed by instance ID
        self._running: dict[UUID, asyncio.Task[None]] = {}

    async def _get_or_create_definition(
        self,
        workflow: type[Workflow],
    ) -> WorkflowDefinitionModel:
        """Get or create a workflow definition in the database.

        Args:
            workflow: The workflow class.

        Returns:
            The database definition model.
        """
        definition = workflow.get_definition()

        # Reuse the stored definition for this name/version when present
        existing = await self._definition_repo.get_by_name(
            definition.name,
            definition.version,
        )
        if existing:
            return existing

        model = WorkflowDefinitionModel(
            name=definition.name,
            version=definition.version,
            description=definition.description,
            definition_json=self._serialize_definition(definition),
            is_active=True,
        )
        return await self._definition_repo.add(model, auto_commit=True)

    @staticmethod
    def _serialize_condition(condition: Any) -> str | None:
        """Return a JSON-safe representation of an edge condition.

        ``None`` and strings pass through unchanged. Any other object (for
        example a callable predicate) is stored by its ``__name__`` or, failing
        that, its ``repr``, so the payload stays JSON-serializable.
        """
        if condition is None or isinstance(condition, str):
            return condition
        return getattr(condition, "__name__", None) or repr(condition)

    def _serialize_definition(self, definition: WorkflowDefinition) -> dict[str, Any]:
        """Serialize a workflow definition to JSON-compatible dict.

        Args:
            definition: The workflow definition.

        Returns:
            JSON-serializable dictionary.
        """
        return {
            "name": definition.name,
            "version": definition.version,
            "description": definition.description,
            "steps": list(definition.steps),
            "edges": [
                {
                    "source": edge.get_source_name(),
                    "target": edge.get_target_name(),
                    "condition": self._serialize_condition(edge.condition),
                }
                for edge in definition.edges
            ],
            "initial_step": definition.initial_step,
            "terminal_steps": list(definition.terminal_steps),
        }

    async def start_workflow(
        self,
        workflow: type[Workflow],
        initial_data: dict[str, Any] | None = None,
        *,
        tenant_id: str | None = None,
        created_by: str | None = None,
    ) -> WorkflowInstanceData:
        """Start a new workflow instance with persistence.

        Args:
            workflow: The workflow class to execute.
            initial_data: Optional initial data for the workflow.
            tenant_id: Optional tenant ID for multi-tenancy.
            created_by: Optional user ID who started the workflow.

        Returns:
            The created WorkflowInstanceData.
        """
        # Get or create definition in database
        definition_model = await self._get_or_create_definition(workflow)
        definition = workflow.get_definition()

        instance_id = uuid4()
        workflow_id = uuid4()
        now = datetime.now(timezone.utc)

        # Copy the caller's dict so later mutations by the caller do not leak
        # into the returned context or the persisted row.
        data = dict(initial_data or {})

        context = WorkflowContext(
            workflow_id=workflow_id,
            instance_id=instance_id,
            data=data,
            metadata={
                "workflow_name": definition.name,
                "workflow_version": definition.version,
            },
            current_step=definition.initial_step,
            step_history=[],
            started_at=now,
            tenant_id=tenant_id,
        )

        instance_model = WorkflowInstanceModel(
            id=instance_id,
            definition_id=definition_model.id,
            workflow_name=definition.name,
            workflow_version=definition.version,
            status=WorkflowStatus.RUNNING,
            current_step=definition.initial_step,
            context_data=dict(data),
            metadata_={
                "workflow_id": str(workflow_id),
            },
            started_at=now,
            tenant_id=tenant_id,
            created_by=created_by,
        )

        # Persist instance
        await self._instance_repo.add(instance_model, auto_commit=True)

        if self.event_bus:
            await self.event_bus.emit("workflow.started", instance_id=instance_id)

        # In-memory snapshot returned to the caller
        instance_data = WorkflowInstanceData(
            id=instance_id,
            workflow_name=definition.name,
            workflow_version=definition.version,
            status=WorkflowStatus.RUNNING,
            context=context,
            current_step=definition.initial_step,
            error=None,
            started_at=now,
            completed_at=None,
        )

        # Start execution in background; the reference kept in ``_running``
        # prevents the task from being garbage collected.
        self._running[instance_id] = asyncio.create_task(
            self._run_workflow(instance_id, definition),
        )

        return instance_data

    async def _run_workflow(
        self,
        instance_id: UUID,
        definition: WorkflowDefinition,
    ) -> None:
        """Main workflow execution loop with persistence.

        Runs machine steps one after another until the workflow completes,
        fails, or reaches a human step (where it parks in ``WAITING``).
        The ``_running`` entry for the instance is always released on exit,
        including the early-return paths, so a later
        ``complete_human_task`` can schedule a fresh run.

        Args:
            instance_id: The workflow instance ID.
            definition: The workflow definition.
        """
        try:
            # Load instance from database
            instance = await self._instance_repo.get(instance_id)
            if not instance:
                return

            graph = WorkflowGraph.from_definition(definition)

            while instance.status == WorkflowStatus.RUNNING:
                current_step_name = instance.current_step
                if not current_step_name:
                    break

                if current_step_name not in definition.steps:
                    instance.status = WorkflowStatus.FAILED
                    instance.error = f"Step '{current_step_name}' not found"
                    instance.completed_at = datetime.now(timezone.utc)
                    await self.session.commit()
                    break

                step = definition.steps[current_step_name]

                # Human steps park the workflow until complete_human_task()
                if step.step_type == StepType.HUMAN:
                    instance.status = WorkflowStatus.WAITING
                    await self.session.commit()

                    await self._create_human_task(instance_id, step)

                    if self.event_bus:
                        await self.event_bus.emit(
                            "workflow.waiting",
                            instance_id=instance_id,
                            step_name=current_step_name,
                        )
                    return

                # Build the context before the try block so the failure
                # handler below can always pass it to on_failure().
                context = self._build_context(instance)

                step_execution = StepExecutionModel(
                    instance_id=instance_id,
                    step_name=current_step_name,
                    step_type=step.step_type,
                    status=StepStatus.RUNNING,
                    started_at=datetime.now(timezone.utc),
                )
                await self._step_repo.add(step_execution, auto_commit=False)

                try:
                    if not await step.can_execute(context):
                        step_execution.status = StepStatus.SKIPPED
                        step_execution.completed_at = datetime.now(timezone.utc)
                        await self.session.commit()
                    else:
                        result = await step.execute(context)
                        await step.on_success(context, result)

                        step_execution.status = StepStatus.SUCCEEDED
                        step_execution.output_data = (
                            result if isinstance(result, dict) else {"result": result}
                        )
                        step_execution.completed_at = datetime.now(timezone.utc)

                        # context.data is a copy (see _build_context), so this
                        # assigns a new object the ORM can detect as a change.
                        instance.context_data = context.data
                        await self.session.commit()

                except Exception as exc:
                    await step.on_failure(context, exc)

                    failed_at = datetime.now(timezone.utc)
                    step_execution.status = StepStatus.FAILED
                    step_execution.error = str(exc)
                    step_execution.completed_at = failed_at

                    instance.status = WorkflowStatus.FAILED
                    instance.error = str(exc)
                    instance.completed_at = failed_at
                    await self.session.commit()
                    break

                # A terminal step, or a step with no successors, ends the workflow
                if graph.is_terminal(current_step_name):
                    next_steps = []
                else:
                    next_steps = graph.get_next_steps(current_step_name, context)

                if not next_steps:
                    instance.status = WorkflowStatus.COMPLETED
                    instance.completed_at = datetime.now(timezone.utc)
                    instance.current_step = None
                    await self.session.commit()
                    break

                # Advance to the first eligible successor
                instance.current_step = next_steps[0]
                await self.session.commit()

                # Reload instance so external changes (e.g. cancellation) are seen
                await self.session.refresh(instance)

            # Only announce genuine end states; any other status observed here
            # (e.g. a cancellation picked up by the refresh) is not a failure.
            if self.event_bus and instance.status in (
                WorkflowStatus.COMPLETED,
                WorkflowStatus.FAILED,
            ):
                event_type = (
                    "workflow.completed"
                    if instance.status == WorkflowStatus.COMPLETED
                    else "workflow.failed"
                )
                await self.event_bus.emit(event_type, instance_id=instance_id)
        finally:
            # Always release the slot, including on early return or error
            self._running.pop(instance_id, None)

    @staticmethod
    def _resolve_workflow_id(instance: WorkflowInstanceModel) -> UUID:
        """Return the workflow ID stored in instance metadata.

        Falls back to the instance's own ID when the metadata is missing or
        has no ``workflow_id`` entry.
        """
        metadata = instance.metadata_ or {}
        return UUID(str(metadata.get("workflow_id") or instance.id))

    def _build_context(self, instance: WorkflowInstanceModel) -> WorkflowContext:
        """Build a WorkflowContext from a database instance.

        The data and metadata dictionaries are shallow copies, so steps that
        mutate the context do not silently mutate the ORM-tracked values.

        Args:
            instance: The database instance model.

        Returns:
            A WorkflowContext for step execution.
        """
        return WorkflowContext(
            workflow_id=self._resolve_workflow_id(instance),
            instance_id=instance.id,
            data=dict(instance.context_data or {}),
            metadata=dict(instance.metadata_ or {}),
            current_step=instance.current_step or "",
            step_history=[],  # Could load from step_executions if needed
            started_at=instance.started_at,
            tenant_id=instance.tenant_id,
        )

    async def _create_human_task(
        self,
        instance_id: UUID,
        step: Step[Any],
    ) -> HumanTaskModel:
        """Create a human task record for a human step.

        Args:
            instance_id: The workflow instance ID.
            step: The human step.

        Returns:
            The created human task model.
        """
        # Reuse the step execution record when one already exists
        step_exec = await self._step_repo.find_by_step_name(instance_id, step.name)

        if not step_exec:
            step_exec = StepExecutionModel(
                instance_id=instance_id,
                step_name=step.name,
                step_type=StepType.HUMAN,
                status=StepStatus.WAITING,
                started_at=datetime.now(timezone.utc),
            )
            await self._step_repo.add(step_exec, auto_commit=True)

        task = HumanTaskModel(
            instance_id=instance_id,
            step_execution_id=step_exec.id,
            step_name=step.name,
            # Optional step attributes: fall back when the step lacks them
            title=getattr(step, "title", step.name),
            description=step.description,
            form_schema=getattr(step, "form_schema", None),
            status="pending",
        )

        return await self._task_repo.add(task, auto_commit=True)

    async def complete_human_task(
        self,
        instance_id: UUID,
        step_name: str,
        user_id: str,
        data: dict[str, Any],
    ) -> None:
        """Complete a human task with user-provided data.

        Merges ``data`` into the instance context and then either advances
        the workflow to the next step and resumes background execution, or
        marks the workflow completed when the human step has no successor.

        Args:
            instance_id: The workflow instance ID.
            step_name: Name of the human task step.
            user_id: ID of the user completing the task.
            data: User-provided data to merge into context.

        Raises:
            ValueError: If the instance does not exist, is not waiting, or is
                waiting at a different step.
        """
        instance = await self._instance_repo.get(instance_id)
        if not instance:
            msg = f"Instance {instance_id} not found"
            raise ValueError(msg)

        if instance.status != WorkflowStatus.WAITING:
            msg = f"Instance {instance_id} is not waiting"
            raise ValueError(msg)

        if instance.current_step != step_name:
            msg = f"Instance is at step '{instance.current_step}', not '{step_name}'"
            raise ValueError(msg)

        # Resolve the definition before mutating anything, so a registry
        # failure cannot leave the instance half-updated.
        definition = self.registry.get_definition(instance.workflow_name)
        graph = WorkflowGraph.from_definition(definition)

        # Update step execution
        step_exec = await self._step_repo.find_by_step_name(instance_id, step_name)
        if step_exec:
            step_exec.status = StepStatus.SUCCEEDED
            step_exec.output_data = data
            step_exec.completed_at = datetime.now(timezone.utc)
            step_exec.completed_by = user_id

        # Complete the first pending human task record for this step
        tasks = await self._task_repo.find_by_instance(instance_id)
        for task in tasks:
            if task.step_name == step_name and task.status == "pending":
                await self._task_repo.complete_task(task.id, user_id)
                break

        # Assign a *new* dict: in-place mutation of a JSON column value is not
        # detected by the ORM unless mutation tracking is configured.
        instance.context_data = {**(instance.context_data or {}), **data}

        # Decide where to go next, using the merged context
        context = self._build_context(instance)
        if graph.is_terminal(step_name):
            next_steps = []
        else:
            next_steps = graph.get_next_steps(step_name, context)

        if not next_steps:
            # Nothing follows the human step. Finishing here avoids re-entering
            # the same human step and creating a duplicate task.
            instance.status = WorkflowStatus.COMPLETED
            instance.completed_at = datetime.now(timezone.utc)
            instance.current_step = None
            await self.session.commit()
            if self.event_bus:
                await self.event_bus.emit("workflow.completed", instance_id=instance_id)
            return

        # Resume workflow at the first eligible successor
        instance.status = WorkflowStatus.RUNNING
        instance.current_step = next_steps[0]
        await self.session.commit()

        # A finished task left in the registry must not block resumption
        existing = self._running.get(instance_id)
        if existing is None or existing.done():
            self._running[instance_id] = asyncio.create_task(
                self._run_workflow(instance_id, definition),
            )

    async def cancel_workflow(self, instance_id: UUID, reason: str) -> None:
        """Cancel a running workflow.

        Args:
            instance_id: The workflow instance ID.
            reason: Reason for cancellation.

        Raises:
            ValueError: If the instance does not exist.
        """
        instance = await self._instance_repo.get(instance_id)
        if not instance:
            msg = f"Instance {instance_id} not found"
            raise ValueError(msg)

        instance.status = WorkflowStatus.CANCELED
        instance.error = f"Canceled: {reason}"
        instance.completed_at = datetime.now(timezone.utc)
        await self.session.commit()

        # Cancel any pending human tasks
        tasks = await self._task_repo.find_by_instance(instance_id)
        for task in tasks:
            if task.status == "pending":
                await self._task_repo.cancel_task(task.id)

        # Stop the background task, if one is still registered
        running = self._running.pop(instance_id, None)
        if running is not None:
            running.cancel()

        if self.event_bus:
            await self.event_bus.emit(
                "workflow.canceled",
                instance_id=instance_id,
                reason=reason,
            )

    async def get_instance(self, instance_id: UUID) -> WorkflowInstanceData:
        """Get a workflow instance by ID.

        Args:
            instance_id: The workflow instance ID.

        Returns:
            The WorkflowInstanceData.

        Raises:
            KeyError: If instance not found.
        """
        instance = await self._instance_repo.get(instance_id)
        if not instance:
            msg = f"Instance {instance_id} not found"
            raise KeyError(msg)

        # Load step history
        step_executions = await self._step_repo.find_by_instance(instance_id)
        step_history = [
            StepExecution(
                step_name=se.step_name,
                status=se.status,
                started_at=se.started_at,
                completed_at=se.completed_at,
                result=se.output_data,
                error=se.error,
            )
            for se in step_executions
        ]

        context = WorkflowContext(
            workflow_id=self._resolve_workflow_id(instance),
            instance_id=instance.id,
            data=instance.context_data or {},
            metadata=instance.metadata_ or {},
            current_step=instance.current_step or "",
            step_history=step_history,
            started_at=instance.started_at,
            tenant_id=instance.tenant_id,
        )

        return WorkflowInstanceData(
            id=instance.id,
            workflow_name=instance.workflow_name,
            workflow_version=instance.workflow_version,
            status=instance.status,
            context=context,
            current_step=instance.current_step,
            error=instance.error,
            started_at=instance.started_at,
            completed_at=instance.completed_at,
        )

    def get_running_instances(self) -> list[UUID]:
        """Get IDs of currently running instances.

        Returns:
            List of running instance IDs.
        """
        return list(self._running)