# Source code for litestar_workflows.web.controllers

"""REST API controllers for workflow management.

This module provides three controller classes for managing workflows:
- WorkflowDefinitionController: Manage workflow definitions and schemas
- WorkflowInstanceController: Start, monitor, and control workflow executions
- HumanTaskController: Manage human approval tasks
"""

from __future__ import annotations

from typing import Any, ClassVar
from uuid import UUID

from litestar import Controller, get, post
from litestar.exceptions import NotFoundException, ValidationException
from litestar.params import Parameter

from litestar_workflows.core.types import StepStatus, WorkflowStatus
from litestar_workflows.engine.local import LocalExecutionEngine  # noqa: TC001 - needed for DI
from litestar_workflows.engine.registry import WorkflowRegistry  # noqa: TC001 - needed for DI
from litestar_workflows.web.dto import (
    CompleteTaskDTO,
    GraphDTO,
    HumanTaskDTO,
    ReassignTaskDTO,
    StartWorkflowDTO,
    StepExecutionDTO,
    WorkflowDefinitionDTO,
    WorkflowInstanceDetailDTO,
    WorkflowInstanceDTO,
)
from litestar_workflows.web.graph import (
    generate_mermaid_graph,
    generate_mermaid_graph_with_state,
    parse_graph_to_dict,
)

# Try to import DB repositories - will be None if [db] extra not installed
try:
    from litestar_workflows.db.repositories import (
        HumanTaskRepository,
        WorkflowInstanceRepository,
    )
except ImportError:  # pragma: no cover
    HumanTaskRepository = Any  # type: ignore[misc, assignment]
    WorkflowInstanceRepository = Any  # type: ignore[misc, assignment]

# Public API of this module: the three Litestar controllers exposed for mounting.
__all__ = ["HumanTaskController", "WorkflowDefinitionController", "WorkflowInstanceController"]


def _definition_to_dto(definition: Any) -> WorkflowDefinitionDTO:
    """Convert a registry workflow definition into its API DTO.

    Shared by the list and detail endpoints so both serialize definitions
    with exactly the same field mapping.

    Args:
        definition: A definition object returned by ``WorkflowRegistry``.

    Returns:
        The populated ``WorkflowDefinitionDTO``.
    """
    return WorkflowDefinitionDTO(
        name=definition.name,
        version=definition.version,
        description=definition.description,
        steps=list(definition.steps.keys()),
        edges=[
            {
                "source": edge.get_source_name(),
                "target": edge.get_target_name(),
                # Conditions are rendered via str(); a falsy condition means unconditional.
                "condition": str(edge.condition) if edge.condition else None,
            }
            for edge in definition.edges
        ],
        initial_step=definition.initial_step,
        terminal_steps=list(definition.terminal_steps),
    )


class WorkflowDefinitionController(Controller):
    """API controller for workflow definitions.

    Provides endpoints for listing and retrieving workflow definitions,
    including their schemas and graph visualizations.

    Tags: Workflow Definitions
    """

    path = "/definitions"
    tags: ClassVar[list[str]] = ["Workflow Definitions"]

    @get("/")
    async def list_definitions(
        self,
        workflow_registry: WorkflowRegistry,
        active_only: bool = Parameter(
            default=True,
            description="Filter to only active workflow definitions",
        ),
    ) -> list[WorkflowDefinitionDTO]:
        """List all registered workflow definitions.

        Returns every definition reported by ``workflow_registry.list_definitions()``.

        Args:
            workflow_registry: Injected workflow registry.
            active_only: Accepted for API compatibility; not currently applied.

        Returns:
            List of workflow definition DTOs.
        """
        # NOTE(review): ``active_only`` is not forwarded to the registry, so it has
        # no effect. Confirm whether ``WorkflowRegistry.list_definitions`` supports
        # an equivalent filter before wiring it through.
        return [_definition_to_dto(definition) for definition in workflow_registry.list_definitions()]

    @get("/{name:str}")
    async def get_definition(
        self,
        name: str,
        workflow_registry: WorkflowRegistry,
        version: str | None = Parameter(
            default=None,
            description="Specific version to retrieve. If omitted, returns latest.",
        ),
    ) -> WorkflowDefinitionDTO:
        """Get a specific workflow definition by name.

        Args:
            name: The workflow name.
            workflow_registry: Injected workflow registry.
            version: Optional specific version to retrieve.

        Returns:
            Workflow definition DTO.

        Raises:
            NotFoundException: If workflow definition not found.
        """
        try:
            definition = workflow_registry.get_definition(name, version=version)
        except KeyError as e:
            raise NotFoundException(detail=f"Workflow definition '{name}' not found") from e

        return _definition_to_dto(definition)

    @get("/{name:str}/graph")
    async def get_definition_graph(
        self,
        name: str,
        workflow_registry: WorkflowRegistry,
        graph_format: str = Parameter(
            default="mermaid",
            description="Graph format: 'mermaid' or 'json'",
        ),
        version: str | None = Parameter(
            default=None,
            description="Specific version to render. If omitted, uses latest.",
        ),
    ) -> GraphDTO:
        """Get workflow graph visualization.

        Returns a visual representation of the workflow graph, either as
        MermaidJS source or as a structured JSON object. Nodes and edges are
        always included; ``mermaid_source`` is empty for the ``json`` format.

        Args:
            name: The workflow name.
            workflow_registry: Injected workflow registry.
            graph_format: Graph format ('mermaid' or 'json').
            version: Optional specific version to render.

        Returns:
            Graph DTO with visualization data.

        Raises:
            NotFoundException: If workflow definition not found.
            ValidationException: If ``graph_format`` is not 'mermaid' or 'json'.
        """
        try:
            definition = workflow_registry.get_definition(name, version=version)
        except KeyError as e:
            raise NotFoundException(detail=f"Workflow definition '{name}' not found") from e

        if graph_format not in ("mermaid", "json"):
            # A bad query value is a client error (400), not a missing resource (404).
            raise ValidationException(detail=f"Unknown format: {graph_format}")

        graph_dict = parse_graph_to_dict(definition)
        mermaid_source = generate_mermaid_graph(definition) if graph_format == "mermaid" else ""
        return GraphDTO(
            mermaid_source=mermaid_source,
            nodes=graph_dict["nodes"],
            edges=graph_dict["edges"],
        )
def _instance_to_dto(instance: Any) -> WorkflowInstanceDTO:
    """Convert a persisted workflow instance into the summary DTO.

    Shared by every endpoint that returns a ``WorkflowInstanceDTO`` built from
    a repository record, so the field mapping lives in one place.

    Args:
        instance: A workflow instance record loaded from ``WorkflowInstanceRepository``.

    Returns:
        The populated ``WorkflowInstanceDTO``.
    """
    return WorkflowInstanceDTO(
        id=instance.id,
        definition_name=instance.workflow_name,
        status=instance.status.value,
        current_step=instance.current_step,
        started_at=instance.started_at,
        completed_at=instance.completed_at,
        created_by=instance.created_by,
    )


def _parse_workflow_status(status: str | None) -> WorkflowStatus | None:
    """Parse an optional status query value into a ``WorkflowStatus``.

    Args:
        status: Raw query value; ``None`` or ``""`` means "no filter".

    Returns:
        The matching ``WorkflowStatus`` member, or ``None`` when no filter was given.

    Raises:
        ValidationException: If ``status`` is not a valid ``WorkflowStatus`` value,
            so the client gets a 400 instead of an unhandled ``ValueError`` (500).
    """
    if not status:
        return None
    try:
        return WorkflowStatus(status)
    except ValueError as e:
        allowed = ", ".join(str(member.value) for member in WorkflowStatus)
        raise ValidationException(detail=f"Invalid status '{status}'. Expected one of: {allowed}") from e


def _step_executions(instance: Any) -> Any:
    """Return the instance's step executions, or an empty list if the attribute is absent or empty."""
    return getattr(instance, "step_executions", None) or []


class WorkflowInstanceController(Controller):
    """API controller for workflow instances.

    Provides endpoints for starting, listing, and managing workflow instance
    executions.

    Tags: Workflow Instances
    """

    path = "/instances"
    tags: ClassVar[list[str]] = ["Workflow Instances"]

    @post("/", dto=None, return_dto=None)
    async def start_workflow(
        self,
        data: StartWorkflowDTO,
        workflow_engine: LocalExecutionEngine,
        workflow_registry: WorkflowRegistry,
    ) -> WorkflowInstanceDTO:
        """Start a new workflow instance.

        Creates and starts a new instance of the specified workflow definition.

        Args:
            data: Workflow start parameters.
            workflow_engine: Injected execution engine.
            workflow_registry: Injected workflow registry.

        Returns:
            Workflow instance DTO.

        Raises:
            NotFoundException: If workflow definition not found.
        """
        try:
            workflow_class = workflow_registry.get_workflow_class(data.definition_name)
        except KeyError as e:
            raise NotFoundException(detail=f"Workflow '{data.definition_name}' not found") from e

        instance = await workflow_engine.start_workflow(
            workflow_class,
            initial_data=data.input_data or {},
        )

        # The engine returns a runtime instance whose timing data lives on
        # ``context``, so it cannot share ``_instance_to_dto`` with DB records.
        return WorkflowInstanceDTO(
            id=instance.id,
            definition_name=instance.workflow_name,
            status=instance.status.value,
            current_step=instance.context.current_step,
            started_at=instance.context.started_at,
            completed_at=None,
            created_by=data.user_id,
        )

    @get("/")
    async def list_instances(
        self,
        workflow_instance_repo: WorkflowInstanceRepository,
        workflow_name: str | None = Parameter(
            default=None,
            description="Filter by workflow name",
        ),
        status: str | None = Parameter(
            default=None,
            description="Filter by status",
        ),
        limit: int = Parameter(
            default=50,
            ge=1,
            le=100,
            description="Maximum number of results",
        ),
        offset: int = Parameter(
            default=0,
            ge=0,
            description="Number of results to skip",
        ),
    ) -> list[WorkflowInstanceDTO]:
        """List workflow instances with optional filtering.

        Args:
            workflow_instance_repo: Injected workflow instance repository.
            workflow_name: Optional workflow name filter.
            status: Optional status filter.
            limit: Maximum number of results.
            offset: Pagination offset.

        Returns:
            List of workflow instance DTOs.

        Raises:
            ValidationException: If ``status`` is not a valid workflow status.
        """
        workflow_status = _parse_workflow_status(status)

        if workflow_name:
            instances, _ = await workflow_instance_repo.find_by_workflow(
                workflow_name=workflow_name,
                status=workflow_status,
                limit=limit,
                offset=offset,
            )
        else:
            # The unfiltered ``list()`` call receives no arguments here, so the
            # status filter and pagination are applied in memory. Previously
            # both ``status`` and ``offset`` were silently ignored on this path.
            all_instances = await workflow_instance_repo.list()
            if workflow_status is not None:
                all_instances = [inst for inst in all_instances if inst.status == workflow_status]
            instances = all_instances[offset : offset + limit]

        # Defensive cap in case the repository returns more rows than requested.
        return [_instance_to_dto(instance) for instance in instances[:limit]]

    @get("/{instance_id:uuid}")
    async def get_instance(
        self,
        instance_id: UUID,
        workflow_instance_repo: WorkflowInstanceRepository,
    ) -> WorkflowInstanceDetailDTO:
        """Get detailed workflow instance information.

        Returns comprehensive information about a workflow instance including
        execution context, step history, and current state.

        Args:
            instance_id: The workflow instance ID.
            workflow_instance_repo: Injected workflow instance repository.

        Returns:
            Detailed workflow instance DTO.

        Raises:
            NotFoundException: If instance not found.
        """
        instance = await workflow_instance_repo.get(instance_id)
        if not instance:
            raise NotFoundException(detail=f"Workflow instance {instance_id} not found")

        step_history = [
            StepExecutionDTO(
                id=step_exec.id,
                step_name=step_exec.step_name,
                status=step_exec.status.value,
                started_at=step_exec.started_at,
                completed_at=step_exec.completed_at,
                error=step_exec.error,
            )
            for step_exec in _step_executions(instance)
        ]

        return WorkflowInstanceDetailDTO(
            id=instance.id,
            definition_name=instance.workflow_name,
            status=instance.status.value,
            current_step=instance.current_step,
            started_at=instance.started_at,
            completed_at=instance.completed_at,
            created_by=instance.created_by,
            context_data=instance.context_data,
            metadata=instance.metadata_,
            step_history=step_history,
            error=instance.error,
        )

    @get("/{instance_id:uuid}/graph")
    async def get_instance_graph(
        self,
        instance_id: UUID,
        workflow_instance_repo: WorkflowInstanceRepository,
        workflow_registry: WorkflowRegistry,
    ) -> GraphDTO:
        """Get workflow instance graph with execution state highlighting.

        Returns a visual representation of the workflow with the current
        execution state highlighted, showing completed and failed steps.

        Args:
            instance_id: The workflow instance ID.
            workflow_instance_repo: Injected workflow instance repository.
            workflow_registry: Injected workflow registry.

        Returns:
            Graph DTO with state highlighting.

        Raises:
            NotFoundException: If the instance or its workflow definition is not found.
        """
        instance = await workflow_instance_repo.get(instance_id)
        if not instance:
            raise NotFoundException(detail=f"Workflow instance {instance_id} not found")

        try:
            definition = workflow_registry.get_definition(instance.workflow_name)
        except KeyError as e:
            raise NotFoundException(detail=f"Workflow definition '{instance.workflow_name}' not found") from e

        # Partition executed steps by outcome; other statuses are not highlighted.
        completed_steps: list[str] = []
        failed_steps: list[str] = []
        for step_exec in _step_executions(instance):
            if step_exec.status == StepStatus.SUCCEEDED:
                completed_steps.append(step_exec.step_name)
            elif step_exec.status == StepStatus.FAILED:
                failed_steps.append(step_exec.step_name)

        mermaid_source = generate_mermaid_graph_with_state(
            definition,
            current_step=instance.current_step,
            completed_steps=completed_steps,
            failed_steps=failed_steps,
        )
        graph_dict = parse_graph_to_dict(definition)
        return GraphDTO(
            mermaid_source=mermaid_source,
            nodes=graph_dict["nodes"],
            edges=graph_dict["edges"],
        )

    @post("/{instance_id:uuid}/cancel")
    async def cancel_instance(
        self,
        instance_id: UUID,
        workflow_engine: LocalExecutionEngine,
        workflow_instance_repo: WorkflowInstanceRepository,
        reason: str = Parameter(
            default="User canceled",
            description="Reason for cancellation",
        ),
    ) -> WorkflowInstanceDTO:
        """Cancel a running workflow instance.

        Args:
            instance_id: The workflow instance ID.
            workflow_engine: Injected execution engine.
            workflow_instance_repo: Injected workflow instance repository.
            reason: Cancellation reason.

        Returns:
            Updated workflow instance DTO.

        Raises:
            NotFoundException: If instance not found (before or after cancellation).
        """
        instance = await workflow_instance_repo.get(instance_id)
        if not instance:
            raise NotFoundException(detail=f"Workflow instance {instance_id} not found")

        await workflow_engine.cancel_workflow(instance_id, reason=reason)

        # Reload so the response reflects any changes made by the cancel call;
        # guard against the record disappearing instead of crashing on None.
        instance = await workflow_instance_repo.get(instance_id)
        if not instance:
            raise NotFoundException(detail=f"Workflow instance {instance_id} not found")

        return _instance_to_dto(instance)

    @post("/{instance_id:uuid}/retry")
    async def retry_instance(
        self,
        instance_id: UUID,
        workflow_engine: LocalExecutionEngine,
        workflow_instance_repo: WorkflowInstanceRepository,
        from_step: str | None = Parameter(
            default=None,
            description="Step to retry from (defaults to failed step)",
        ),
    ) -> WorkflowInstanceDTO:
        """Retry a failed workflow instance.

        Currently a placeholder: no retry is performed and the instance is
        returned unchanged.

        Args:
            instance_id: The workflow instance ID.
            workflow_engine: Injected execution engine (unused until retry is implemented).
            workflow_instance_repo: Injected workflow instance repository.
            from_step: Optional step to retry from (unused until retry is implemented).

        Returns:
            Workflow instance DTO for the (unchanged) instance.

        Raises:
            NotFoundException: If instance not found.
        """
        instance = await workflow_instance_repo.get(instance_id)
        if not instance:
            raise NotFoundException(detail=f"Workflow instance {instance_id} not found")

        # TODO: invoke the engine's retry mechanism once one exists; until then
        # this endpoint only echoes the current instance state.
        return _instance_to_dto(instance)
def _task_to_dto(task: Any) -> HumanTaskDTO:
    """Map a human task record onto the ``HumanTaskDTO`` returned by the API.

    Args:
        task: A task record loaded from ``HumanTaskRepository``.

    Returns:
        The populated ``HumanTaskDTO``.
    """
    return HumanTaskDTO(
        id=task.id,
        instance_id=task.instance_id,
        step_name=task.step_name,
        title=task.title,
        description=task.description,
        assignee=task.assignee_id,
        status=task.status,
        due_date=task.due_at,
        created_at=task.created_at,
        form_schema=task.form_schema,
    )


class HumanTaskController(Controller):
    """API controller for human tasks.

    Exposes endpoints to list, inspect, complete, and reassign human
    approval tasks.

    Tags: Human Tasks
    """

    path = "/tasks"
    tags: ClassVar[list[str]] = ["Human Tasks"]

    @get("/")
    async def list_tasks(
        self,
        human_task_repo: HumanTaskRepository,
        assignee_id: str | None = Parameter(
            default=None,
            description="Filter by assignee ID",
        ),
        assignee_group: str | None = Parameter(
            default=None,
            description="Filter by assignee group",
        ),
        status: str = Parameter(
            default="pending",
            description="Filter by task status",
        ),
    ) -> list[HumanTaskDTO]:
        """List human tasks matching the given filters.

        Only the ``"pending"`` status is currently queryable; any other status
        value yields an empty list.

        Args:
            human_task_repo: Injected human task repository.
            assignee_id: Optional assignee ID filter.
            assignee_group: Optional group filter.
            status: Task status filter.

        Returns:
            List of human task DTOs.
        """
        if status != "pending":
            # Filtering on non-pending statuses is not implemented yet.
            return []

        pending = await human_task_repo.find_pending(
            assignee_id=assignee_id,
            assignee_group=assignee_group,
        )
        return [_task_to_dto(item) for item in pending]

    @get("/{task_id:uuid}")
    async def get_task(
        self,
        task_id: UUID,
        human_task_repo: HumanTaskRepository,
    ) -> HumanTaskDTO:
        """Fetch a single human task by ID.

        Args:
            task_id: The task ID.
            human_task_repo: Injected human task repository.

        Returns:
            Human task DTO.

        Raises:
            NotFoundException: If task not found.
        """
        found = await human_task_repo.get(task_id)
        if not found:
            raise NotFoundException(detail=f"Task {task_id} not found")
        return _task_to_dto(found)

    @post(
        "/{task_id:uuid}/complete",
        dto=None,
        return_dto=None,
    )
    async def complete_task(
        self,
        task_id: UUID,
        data: CompleteTaskDTO,
        workflow_engine: LocalExecutionEngine,
        human_task_repo: HumanTaskRepository,
    ) -> WorkflowInstanceDTO:
        """Complete a human task and hand its output back to the engine.

        The task is first marked complete in the repository, then the engine's
        ``complete_human_task`` is called with the submitted output.

        Args:
            task_id: The task ID.
            data: Task completion data.
            workflow_engine: Injected execution engine.
            human_task_repo: Injected human task repository.

        Returns:
            A placeholder workflow instance DTO (not reloaded from storage).

        Raises:
            NotFoundException: If task not found.
        """
        found = await human_task_repo.get(task_id)
        if not found:
            raise NotFoundException(detail=f"Task {task_id} not found")

        await human_task_repo.complete_task(task_id, completed_by=data.completed_by)
        await workflow_engine.complete_human_task(
            instance_id=found.instance_id,
            step_name=found.step_name,
            user_id=data.completed_by,
            data=data.output_data,
        )

        # NOTE(review): placeholder response; definition name and status are not
        # loaded from the actual workflow instance.
        return WorkflowInstanceDTO(
            id=found.instance_id,
            definition_name="",
            status="RUNNING",
            current_step=None,
            started_at=found.created_at,
            created_by=data.completed_by,
        )

    @post(
        "/{task_id:uuid}/reassign",
        dto=None,
        return_dto=None,
    )
    async def reassign_task(
        self,
        task_id: UUID,
        data: ReassignTaskDTO,
        human_task_repo: HumanTaskRepository,
    ) -> HumanTaskDTO:
        """Assign a task to a different user.

        Args:
            task_id: The task ID.
            data: Reassignment data.
            human_task_repo: Injected human task repository.

        Returns:
            Updated human task DTO.

        Raises:
            NotFoundException: If task not found.
        """
        found = await human_task_repo.get(task_id)
        if not found:
            raise NotFoundException(detail=f"Task {task_id} not found")

        found.assignee_id = data.new_assignee
        # Flush (not commit) so the change is sent to the DB within the current session.
        await human_task_repo.session.flush()

        return _task_to_dto(found)