Source code for litestar_workflows.db.repositories

"""Repository implementations for workflow persistence.

This module provides async repositories for CRUD operations on workflow
models using advanced-alchemy's repository pattern.
"""

from __future__ import annotations

from typing import TYPE_CHECKING
from uuid import UUID

from advanced_alchemy.filters import LimitOffset, OrderBy
from advanced_alchemy.repository import SQLAlchemyAsyncRepository
from sqlalchemy import and_, or_, select

from litestar_workflows.core.types import StepStatus, WorkflowStatus
from litestar_workflows.db.models import (
    HumanTaskModel,
    StepExecutionModel,
    WorkflowDefinitionModel,
    WorkflowInstanceModel,
)

if TYPE_CHECKING:
    from collections.abc import Sequence

__all__ = [
    "HumanTaskRepository",
    "StepExecutionRepository",
    "WorkflowDefinitionRepository",
    "WorkflowInstanceRepository",
]


class WorkflowDefinitionRepository(SQLAlchemyAsyncRepository[WorkflowDefinitionModel]):
    """Async repository for persisted workflow definitions.

    On top of the generic repository operations this class offers lookups
    by name/version, listing of active definitions, and deactivation of a
    single version.
    """

    model_type = WorkflowDefinitionModel

    async def get_by_name(
        self,
        name: str,
        version: str | None = None,
        *,
        active_only: bool = True,
    ) -> WorkflowDefinitionModel | None:
        """Fetch one workflow definition by name, optionally pinned to a version.

        When several rows match, the most recently created one wins.

        Args:
            name: Name of the workflow.
            version: Version to match. A falsy value (``None`` or ``""``)
                disables the version filter.
            active_only: When True, restrict the lookup to active definitions.

        Returns:
            The matching definition, or None when nothing matches.
        """
        filters = [WorkflowDefinitionModel.name == name]
        if version:
            filters.append(WorkflowDefinitionModel.version == version)
        if active_only:
            filters.append(WorkflowDefinitionModel.is_active == True)  # noqa: E712

        # Newest row first, and only that row is fetched.
        newest_first = WorkflowDefinitionModel.created_at.desc()
        query = select(WorkflowDefinitionModel).where(*filters).order_by(newest_first).limit(1)
        rows = await self.session.execute(query)
        return rows.scalar_one_or_none()

    async def get_latest_version(self, name: str) -> WorkflowDefinitionModel | None:
        """Fetch the most recently created active definition for a workflow.

        Args:
            name: Name of the workflow.

        Returns:
            The newest active definition, or None when there is none.
        """
        latest = await self.get_by_name(name, active_only=True)
        return latest

    async def list_active(self) -> Sequence[WorkflowDefinitionModel]:
        """Return every active workflow definition.

        Returns:
            Active definitions sorted by name, then by version descending.
        """
        only_active = WorkflowDefinitionModel.is_active == True  # noqa: E712
        query = (
            select(WorkflowDefinitionModel)
            .where(only_active)
            .order_by(
                WorkflowDefinitionModel.name,
                WorkflowDefinitionModel.version.desc(),
            )
        )
        rows = await self.session.execute(query)
        return rows.scalars().all()

    async def deactivate_version(self, name: str, version: str) -> bool:
        """Mark one version of a workflow as inactive.

        The lookup ignores the current activation state, so an already
        inactive definition is still found and reported as deactivated.

        Args:
            name: Name of the workflow.
            version: Version to deactivate.

        Returns:
            True when a definition was found and flagged inactive, else False.
        """
        target = await self.get_by_name(name, version, active_only=False)
        if not target:
            return False

        target.is_active = False
        await self.session.flush()
        return True
class WorkflowInstanceRepository(SQLAlchemyAsyncRepository[WorkflowInstanceModel]):
    """Repository for workflow instance CRUD operations.

    Provides methods for querying and managing workflow instances including
    filtering by status, user, tenant, and workflow name.
    """

    model_type = WorkflowInstanceModel

    async def find_by_workflow(
        self,
        workflow_name: str,
        status: WorkflowStatus | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> tuple[Sequence[WorkflowInstanceModel], int]:
        """Find instances by workflow name with optional status filter.

        Results are ordered by ``created_at`` descending.

        Args:
            workflow_name: The workflow name to filter by.
            status: Optional status filter.
            limit: Maximum number of results.
            offset: Number of results to skip.

        Returns:
            Tuple of (instances, total_count).
        """
        conditions = [WorkflowInstanceModel.workflow_name == workflow_name]
        if status:
            conditions.append(WorkflowInstanceModel.status == status)

        return await self.list_and_count(
            *conditions,
            LimitOffset(limit=limit, offset=offset),
            OrderBy(field_name="created_at", sort_order="desc"),
        )

    async def find_by_user(
        self,
        user_id: str,
        status: WorkflowStatus | None = None,
    ) -> Sequence[WorkflowInstanceModel]:
        """Find instances created by a specific user.

        Results are ordered by ``created_at`` descending.

        Args:
            user_id: The user ID to filter by (matched against ``created_by``).
            status: Optional status filter.

        Returns:
            List of workflow instances.
        """
        conditions = [WorkflowInstanceModel.created_by == user_id]
        if status:
            conditions.append(WorkflowInstanceModel.status == status)

        stmt = (
            select(WorkflowInstanceModel)
            .where(and_(*conditions))
            .order_by(WorkflowInstanceModel.created_at.desc())
        )
        result = await self.session.execute(stmt)
        return result.scalars().all()

    async def find_by_tenant(
        self,
        tenant_id: str,
        status: WorkflowStatus | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> tuple[Sequence[WorkflowInstanceModel], int]:
        """Find instances by tenant ID.

        Results are ordered by ``created_at`` descending.

        Args:
            tenant_id: The tenant ID to filter by.
            status: Optional status filter.
            limit: Maximum number of results.
            offset: Number of results to skip.

        Returns:
            Tuple of (instances, total_count).
        """
        conditions = [WorkflowInstanceModel.tenant_id == tenant_id]
        if status:
            conditions.append(WorkflowInstanceModel.status == status)

        return await self.list_and_count(
            *conditions,
            LimitOffset(limit=limit, offset=offset),
            OrderBy(field_name="created_at", sort_order="desc"),
        )

    async def find_running(self) -> Sequence[WorkflowInstanceModel]:
        """Find all running or waiting workflow instances.

        Returns:
            Instances whose status is RUNNING or WAITING, ordered by
            ``started_at`` ascending.
        """
        in_flight = [WorkflowStatus.RUNNING, WorkflowStatus.WAITING]
        stmt = (
            select(WorkflowInstanceModel)
            .where(WorkflowInstanceModel.status.in_(in_flight))
            .order_by(WorkflowInstanceModel.started_at)
        )
        result = await self.session.execute(stmt)
        return result.scalars().all()

    async def update_status(
        self,
        instance_id: UUID,
        status: WorkflowStatus,
        *,
        current_step: str | None = None,
        error: str | None = None,
    ) -> WorkflowInstanceModel | None:
        """Update the status of a workflow instance.

        Args:
            instance_id: The instance ID.
            status: The new status.
            current_step: Optional current step name. Left untouched when None.
            error: Optional error message. Left untouched when None.

        Returns:
            The updated instance or None if not found.
        """
        # The repository's ``get()`` raises NotFoundError for a missing row,
        # which would break the documented "None if not found" contract.
        # ``get_one_or_none`` performs the same primary-key lookup but
        # returns None on a miss.
        instance = await self.get_one_or_none(**{self.id_attribute: instance_id})
        if instance is None:
            return None

        instance.status = status
        if current_step is not None:
            instance.current_step = current_step
        if error is not None:
            instance.error = error
        await self.session.flush()
        return instance
class StepExecutionRepository(SQLAlchemyAsyncRepository[StepExecutionModel]):
    """Async repository for step execution records."""

    model_type = StepExecutionModel

    async def find_by_instance(
        self,
        instance_id: UUID,
    ) -> Sequence[StepExecutionModel]:
        """Return every step execution recorded for a workflow instance.

        Args:
            instance_id: ID of the workflow instance.

        Returns:
            Step executions sorted by ``started_at`` ascending.
        """
        belongs_to_instance = StepExecutionModel.instance_id == instance_id
        query = select(StepExecutionModel).where(belongs_to_instance).order_by(StepExecutionModel.started_at)
        rows = await self.session.execute(query)
        return rows.scalars().all()

    async def find_by_step_name(
        self,
        instance_id: UUID,
        step_name: str,
    ) -> StepExecutionModel | None:
        """Return the most recently started execution of a named step.

        Args:
            instance_id: ID of the workflow instance.
            step_name: Name of the step.

        Returns:
            The execution with the latest ``started_at`` for that step, or
            None when the step has no execution record.
        """
        filters = (
            StepExecutionModel.instance_id == instance_id,
            StepExecutionModel.step_name == step_name,
        )
        # Latest start first, and only that row is fetched.
        latest_first = StepExecutionModel.started_at.desc()
        query = select(StepExecutionModel).where(*filters).order_by(latest_first).limit(1)
        rows = await self.session.execute(query)
        return rows.scalar_one_or_none()

    async def find_failed(
        self,
        instance_id: UUID | None = None,
    ) -> Sequence[StepExecutionModel]:
        """Return step executions whose status is FAILED.

        Args:
            instance_id: When given, restrict results to this instance.

        Returns:
            Failed step executions sorted by ``completed_at`` descending.
        """
        filters = [StepExecutionModel.status == StepStatus.FAILED]
        if instance_id:
            filters.append(StepExecutionModel.instance_id == instance_id)

        query = select(StepExecutionModel).where(*filters).order_by(StepExecutionModel.completed_at.desc())
        rows = await self.session.execute(query)
        return rows.scalars().all()
class HumanTaskRepository(SQLAlchemyAsyncRepository[HumanTaskModel]):
    """Repository for human task CRUD operations.

    Provides methods for querying pending human tasks by assignee, group,
    and due date, and for completing or canceling a task.
    """

    model_type = HumanTaskModel

    async def find_pending(
        self,
        assignee_id: str | None = None,
        assignee_group: str | None = None,
    ) -> Sequence[HumanTaskModel]:
        """Find pending human tasks.

        Args:
            assignee_id: Optional assignee ID filter. Tasks with no assignee
                are also included when this filter is set.
            assignee_group: Optional group filter. Tasks with no group are
                also included when this filter is set.

        Returns:
            Pending tasks ordered by ``due_at`` ascending (tasks without a
            due date last), then by ``created_at``.
        """
        conditions = [HumanTaskModel.status == "pending"]

        if assignee_id:
            # Include tasks assigned to the user or unassigned.
            conditions.append(
                or_(
                    HumanTaskModel.assignee_id == assignee_id,
                    HumanTaskModel.assignee_id.is_(None),
                )
            )

        if assignee_group:
            # Include tasks for the group or without any group.
            conditions.append(
                or_(
                    HumanTaskModel.assignee_group == assignee_group,
                    HumanTaskModel.assignee_group.is_(None),
                )
            )

        stmt = (
            select(HumanTaskModel)
            .where(and_(*conditions))
            .order_by(HumanTaskModel.due_at.asc().nullslast(), HumanTaskModel.created_at)
        )
        result = await self.session.execute(stmt)
        return result.scalars().all()

    async def find_by_instance(
        self,
        instance_id: UUID,
    ) -> Sequence[HumanTaskModel]:
        """Find all human tasks for an instance.

        Args:
            instance_id: The workflow instance ID.

        Returns:
            Human tasks ordered by ``created_at`` ascending.
        """
        stmt = (
            select(HumanTaskModel)
            .where(HumanTaskModel.instance_id == instance_id)
            .order_by(HumanTaskModel.created_at)
        )
        result = await self.session.execute(stmt)
        return result.scalars().all()

    async def find_overdue(self) -> Sequence[HumanTaskModel]:
        """Find overdue pending human tasks.

        A task is overdue when it is pending, has a due date, and that due
        date is earlier than the current UTC time.

        Returns:
            Overdue tasks ordered by ``due_at`` ascending.
        """
        from datetime import datetime, timezone

        now = datetime.now(timezone.utc)
        stmt = (
            select(HumanTaskModel)
            .where(
                and_(
                    HumanTaskModel.status == "pending",
                    HumanTaskModel.due_at.isnot(None),
                    HumanTaskModel.due_at < now,
                )
            )
            .order_by(HumanTaskModel.due_at)
        )
        result = await self.session.execute(stmt)
        return result.scalars().all()

    async def complete_task(
        self,
        task_id: UUID,
        completed_by: str,
    ) -> HumanTaskModel | None:
        """Mark a human task as completed.

        Only a task whose status is ``"pending"`` is modified; a task in any
        other state is returned unchanged.

        Args:
            task_id: The task ID.
            completed_by: User ID who completed the task.

        Returns:
            The task (updated when it was pending) or None if not found.
        """
        from datetime import datetime, timezone

        # ``get()`` raises NotFoundError for a missing row; use the
        # None-returning lookup so the documented contract holds.
        task = await self.get_one_or_none(**{self.id_attribute: task_id})
        if task is None:
            return None

        if task.status == "pending":
            task.status = "completed"
            task.completed_at = datetime.now(timezone.utc)
            task.completed_by = completed_by
            await self.session.flush()
        return task

    async def cancel_task(self, task_id: UUID) -> HumanTaskModel | None:
        """Cancel a pending human task.

        Only a task whose status is ``"pending"`` is modified; a task in any
        other state is returned unchanged.

        Args:
            task_id: The task ID.

        Returns:
            The task (updated when it was pending) or None if not found.
        """
        # ``get()`` raises NotFoundError for a missing row; use the
        # None-returning lookup so the documented contract holds.
        task = await self.get_one_or_none(**{self.id_attribute: task_id})
        if task is None:
            return None

        if task.status == "pending":
            task.status = "canceled"
            await self.session.flush()
        return task