"""loom.core.engine.executor — runtime execution of compiled use-case plans."""

from __future__ import annotations

import inspect
import time
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, TypeVar, cast, overload

from loom.core.engine.compilable import Compilable
from loom.core.engine.compiler import UseCaseCompiler
from loom.core.engine.events import EventKind, RuntimeEvent
from loom.core.engine.metrics import MetricsAdapter
from loom.core.engine.plan import ExecutionPlan, ExistsStep, LoadStep
from loom.core.errors import NotFound
from loom.core.job.context import clear_pending_dispatches, flush_pending_dispatches
from loom.core.logger import LoggerPort, get_logger
from loom.core.tracing import get_trace_id
from loom.core.uow.abc import UnitOfWorkFactory
from loom.core.uow.context import _active_uow
from loom.core.use_case.markers import LookupKind, OnMissing, SourceKind
from loom.core.use_case.rule import RuleViolation, RuleViolations

if TYPE_CHECKING:
    from loom.core.job.job import Job
    from loom.core.use_case.use_case import UseCase

# Generic result type produced by a UseCase/Job ``execute`` implementation;
# threads through the ``execute`` overloads so callers get a typed result.
ResultT = TypeVar("ResultT")

# Log-line prefixes marking execution lifecycle stages (start / success / failure).
_LOG_EXEC = "[EXEC]"
_LOG_DONE = "[DONE]"
_LOG_FAIL = "[FAIL]"
# ``status`` field values attached to structured logs and RuntimeEvent emissions.
_STATUS_SUCCESS = "success"
_STATUS_FAILURE = "failure"
_STATUS_RULE_FAILURE = "rule_failure"


class ParameterBindingError(ValueError):
    """Signals that a primitive ``execute`` parameter could not be coerced to its declared type."""


class RuntimeExecutor:
    """Executes UseCases from their compiled ExecutionPlan without reflection.

    Receives a fully constructed UseCase instance and drives execution through
    the fixed pipeline: bind params → build command → load entities → apply
    computes → check rules → call execute.

    Optionally manages a :class:`~loom.core.uow.abc.UnitOfWork` lifecycle
    around each execution when a ``uow_factory`` is provided. Nested calls
    detected via a ``contextvars.ContextVar`` share the outer transaction and
    never open an additional UoW.

    No signature inspection occurs at runtime. All structural information
    comes from the cached ExecutionPlan produced by UseCaseCompiler.

    Emits ``RuntimeEvent`` objects to the optional ``MetricsAdapter`` at key
    lifecycle points (start, done, error). Enriches log calls with structured
    fields (``usecase``, ``duration_ms``, ``status``) for structured log
    consumers.

    Args:
        compiler: Compiler used to retrieve cached plans.
        uow_factory: Optional UoW factory. When provided, each top-level
            :meth:`execute` call is wrapped in a single atomic transaction
            (begin → commit on success, rollback on exception). Nested
            executions within the same async context reuse the outer UoW.
        debug_execution: When ``True``, emits ``[STEP]`` logs for every
            pipeline stage. Defaults to ``False`` (summary logs only).
        logger: Optional logger. Defaults to the framework logger.
        metrics: Optional metrics adapter. When provided, receives
            ``EXEC_START``, ``EXEC_DONE``, and ``EXEC_ERROR`` events.
        repo_resolver: Optional callable that resolves a repository instance
            from an entity model type. Used by ``Load``/``Exists`` when
            ``dependencies`` override is not passed to :meth:`execute`.

    Example::

        executor = RuntimeExecutor(
            compiler,
            uow_factory=SQLAlchemyUnitOfWorkFactory(session_manager),
            metrics=prometheus_adapter,
        )
        result = await executor.execute(
            use_case,
            params={"user_id": 1},
            payload={"email": "new@corp.com"},
        )
    """

    def __init__(
        self,
        compiler: UseCaseCompiler,
        *,
        uow_factory: UnitOfWorkFactory | None = None,
        debug_execution: bool = False,
        logger: LoggerPort | None = None,
        metrics: MetricsAdapter | None = None,
        repo_resolver: Callable[[type[Any]], Any] | None = None,
    ) -> None:
        self._compiler = compiler
        self._uow_factory = uow_factory
        self._debug = debug_execution
        self._logger = logger or get_logger(__name__)
        self._metrics = metrics
        self._repo_resolver = repo_resolver

    @overload
    async def execute(
        self,
        compilable: UseCase[Any, ResultT],
        *,
        params: dict[str, Any] | None = ...,
        payload: dict[str, Any] | None = ...,
        dependencies: dict[type[Any], Any] | None = ...,
        load_overrides: dict[type[Any], Any] | None = ...,
        read_only: bool = ...,
    ) -> ResultT: ...

    @overload
    async def execute(
        self,
        compilable: Job[ResultT],
        *,
        params: dict[str, Any] | None = ...,
        payload: dict[str, Any] | None = ...,
        dependencies: dict[type[Any], Any] | None = ...,
        load_overrides: dict[type[Any], Any] | None = ...,
        read_only: bool = ...,
    ) -> ResultT: ...

    @overload
    async def execute(
        self,
        compilable: Compilable,
        *,
        params: dict[str, Any] | None = ...,
        payload: dict[str, Any] | None = ...,
        dependencies: dict[type[Any], Any] | None = ...,
        load_overrides: dict[type[Any], Any] | None = ...,
        read_only: bool = ...,
    ) -> Any: ...

    async def execute(
        self,
        compilable: Compilable,
        *,
        params: dict[str, Any] | None = None,
        payload: dict[str, Any] | None = None,
        dependencies: dict[type[Any], Any] | None = None,
        load_overrides: dict[type[Any], Any] | None = None,
        read_only: bool = False,
    ) -> Any:
        """Execute a compiled instance via its ExecutionPlan.

        Accepts any object satisfying the
        :class:`~loom.core.engine.compilable.Compilable` protocol — both
        :class:`~loom.core.use_case.use_case.UseCase` and
        :class:`~loom.core.job.job.Job` instances are valid inputs.

        When a ``uow_factory`` was provided at construction and no UoW is
        already active in the current async context, opens a fresh UoW
        (begin), runs the pipeline, and commits on success or rolls back on
        any exception. Nested calls reuse the existing UoW transparently.

        Args:
            compilable: Constructed instance to execute.
            params: Primitive parameter values keyed by name.
            payload: Raw dict for command construction via ``Input()``.
            dependencies: Mapping of entity type to repository, used for
                ``LoadById()`` / ``Load()`` / ``Exists()`` steps.
            load_overrides: Pre-loaded entities by type, bypassing repo
                calls. Used by test harnesses.
            read_only: When ``True``, skips opening a ``UnitOfWork``
                transaction even if a ``uow_factory`` was provided.
                Automatically set to ``True`` by the HTTP layer for GET
                routes. Also honoured when ``plan.read_only`` is ``True``.

        Returns:
            The result produced by ``execute()``.

        Raises:
            loom.core.errors.RuleViolations: If one or more rule steps fail.
            NotFound: If a Load step finds no entity in the repository.
        """
        uc_type = type(compilable)
        plan = uc_type.__execution_plan__
        if plan is None:
            # Cache miss: compile lazily; the compiler memoizes per type.
            plan = self._compiler.compile(uc_type)
        uc_name = uc_type.__qualname__
        start = time.perf_counter()
        _is_read_only = read_only or plan.read_only
        # Own a UoW only at the outermost, writable call in this async context.
        _owns_uow = (
            self._uow_factory is not None
            and _active_uow.get() is None
            and not _is_read_only
        )
        if not _owns_uow:
            return await self._run_pipeline(
                plan, compilable, uc_name, start, params, payload, dependencies, load_overrides
            )
        uow = self._uow_factory.create()  # type: ignore[union-attr]
        token = _active_uow.set(uow)
        try:
            await uow.begin()
            result: Any = await self._run_pipeline(
                plan, compilable, uc_name, start, params, payload, dependencies, load_overrides
            )
            await uow.commit()
            # Dispatches are deferred until after a successful commit.
            await flush_pending_dispatches()
            return result
        except Exception:
            await uow.rollback()
            clear_pending_dispatches()
            raise
        finally:
            _active_uow.reset(token)

    # ------------------------------------------------------------------
    # Pipeline
    # ------------------------------------------------------------------
    async def _run_pipeline(
        self,
        plan: ExecutionPlan,
        compilable: Compilable,
        uc_name: str,
        start: float,
        params: dict[str, Any] | None,
        payload: dict[str, Any] | None,
        dependencies: dict[type[Any], Any] | None,
        load_overrides: dict[type[Any], Any] | None,
    ) -> Any:
        """Run the core pipeline with lifecycle logging and metric emission."""
        trace_id, logger = self._begin_execution(uc_name)
        try:
            result = await self._run_core_pipeline(
                plan=plan,
                compilable=compilable,
                params=params,
                payload=payload,
                dependencies=dependencies,
                load_overrides=load_overrides,
            )
        except RuleViolations as exc:
            # Rule failures are expected domain outcomes: warn, not error.
            self._handle_rule_failure(
                logger=logger,
                use_case_name=uc_name,
                start=start,
                error=exc,
                trace_id=trace_id,
            )
            raise
        except Exception as exc:
            self._handle_failure(
                logger=logger,
                use_case_name=uc_name,
                start=start,
                error=exc,
                trace_id=trace_id,
            )
            raise
        self._handle_success(
            logger=logger,
            use_case_name=uc_name,
            start=start,
            trace_id=trace_id,
        )
        return result

    def _begin_execution(self, use_case_name: str) -> tuple[str | None, LoggerPort]:
        """Log the start of an execution and emit ``EXEC_START``."""
        trace_id = get_trace_id()
        logger = self._logger.bind(trace_id=trace_id) if trace_id else self._logger
        logger.info(f"{_LOG_EXEC} {use_case_name}", usecase=use_case_name)
        self._emit(
            RuntimeEvent(
                kind=EventKind.EXEC_START,
                use_case_name=use_case_name,
                trace_id=trace_id,
            )
        )
        return trace_id, logger

    async def _run_core_pipeline(
        self,
        *,
        plan: ExecutionPlan,
        compilable: Compilable,
        params: dict[str, Any] | None,
        payload: dict[str, Any] | None,
        dependencies: dict[type[Any], Any] | None,
        load_overrides: dict[type[Any], Any] | None,
    ) -> Any:
        """Drive the fixed stage order: params → command → loads → exists → computes → rules → execute."""
        bound: dict[str, Any] = {}
        self._bind_params(plan, params or {}, bound)
        fields_set = self._build_command(plan, payload, bound)
        await self._execute_loads(plan, compilable, bound, dependencies, load_overrides)
        await self._execute_exists(plan, compilable, bound, dependencies)
        self._apply_computes(plan, bound, fields_set)
        self._check_rules(plan, bound, fields_set)
        return await self._invoke_execute(compilable, bound)

    async def _invoke_execute(
        self,
        compilable: Compilable,
        bound: dict[str, Any],
    ) -> Any:
        """Call ``execute(**bound)``, awaiting it only if it is a coroutine function."""
        self._log_step("Execute core logic")
        execute_fn = compilable.execute
        if inspect.iscoroutinefunction(execute_fn):
            return await execute_fn(**bound)
        return execute_fn(**bound)

    def _handle_rule_failure(
        self,
        *,
        logger: LoggerPort,
        use_case_name: str,
        start: float,
        error: RuleViolations,
        trace_id: str | None,
    ) -> None:
        """Log a rule-violation outcome and emit ``EXEC_ERROR`` with rule status."""
        elapsed_ms = (time.perf_counter() - start) * 1000
        logger.warning(
            f"{_LOG_FAIL} {use_case_name}",
            usecase=use_case_name,
            duration_ms=elapsed_ms,
            status=_STATUS_RULE_FAILURE,
        )
        self._emit(
            RuntimeEvent(
                kind=EventKind.EXEC_ERROR,
                use_case_name=use_case_name,
                duration_ms=elapsed_ms,
                status=_STATUS_RULE_FAILURE,
                error=error,
                trace_id=trace_id,
            )
        )

    def _handle_failure(
        self,
        *,
        logger: LoggerPort,
        use_case_name: str,
        start: float,
        error: Exception,
        trace_id: str | None,
    ) -> None:
        """Log an unexpected failure and emit ``EXEC_ERROR``."""
        elapsed_ms = (time.perf_counter() - start) * 1000
        logger.error(
            f"{_LOG_FAIL} {use_case_name}",
            usecase=use_case_name,
            duration_ms=elapsed_ms,
            status=_STATUS_FAILURE,
            error=str(error),
        )
        self._emit(
            RuntimeEvent(
                kind=EventKind.EXEC_ERROR,
                use_case_name=use_case_name,
                duration_ms=elapsed_ms,
                status=_STATUS_FAILURE,
                error=error,
                trace_id=trace_id,
            )
        )

    def _handle_success(
        self,
        *,
        logger: LoggerPort,
        use_case_name: str,
        start: float,
        trace_id: str | None,
    ) -> None:
        """Log the success summary and emit ``EXEC_DONE``."""
        elapsed_ms = (time.perf_counter() - start) * 1000
        logger.info(
            f"{_LOG_DONE} {elapsed_ms:.1f}ms",
            usecase=use_case_name,
            duration_ms=elapsed_ms,
            status=_STATUS_SUCCESS,
        )
        self._emit(
            RuntimeEvent(
                kind=EventKind.EXEC_DONE,
                use_case_name=use_case_name,
                duration_ms=elapsed_ms,
                status=_STATUS_SUCCESS,
                trace_id=trace_id,
            )
        )

    # ------------------------------------------------------------------
    # Pipeline stages
    # ------------------------------------------------------------------
    def _bind_params(
        self,
        plan: ExecutionPlan,
        params: dict[str, Any],
        bound: dict[str, Any],
    ) -> None:
        """Coerce and bind every declared primitive parameter into ``bound``.

        Raises:
            ValueError: If a declared parameter is absent from ``params``.
            ParameterBindingError: If a value cannot be coerced (via
                :meth:`_coerce_param`).
        """
        for pb in plan.param_bindings:
            if pb.name not in params:
                raise ValueError(
                    f"{plan.use_case_type.__qualname__}: missing required parameter '{pb.name}'"
                )
            raw = params[pb.name]
            bound[pb.name] = self._coerce_param(
                param_name=pb.name,
                annotation=pb.annotation,
                raw=raw,
                use_case_name=plan.use_case_type.__qualname__,
            )

    @staticmethod
    def _coerce_param(
        *,
        param_name: str,
        annotation: Any,
        raw: Any,
        use_case_name: str,
    ) -> Any:
        """Coerce ``raw`` to ``annotation`` when it is a concrete type.

        ``Any`` and non-type annotations (e.g. generic aliases) pass the value
        through untouched; values that already match the type are returned
        as-is.

        Raises:
            ParameterBindingError: If the type's constructor rejects ``raw``.
        """
        if annotation is Any:
            return raw
        if not isinstance(annotation, type):
            return raw
        if isinstance(raw, annotation):
            return raw
        try:
            return annotation(raw)
        except (TypeError, ValueError) as exc:
            raise ParameterBindingError(
                f"{use_case_name}: invalid value for parameter '{param_name}': "
                f"expected {annotation.__name__}, got {type(raw).__name__} ({raw!r})"
            ) from exc

    def _build_command(
        self,
        plan: ExecutionPlan,
        payload: dict[str, Any] | None,
        bound: dict[str, Any],
    ) -> frozenset[str]:
        """Construct the command object from ``payload`` via ``from_payload``.

        Returns the set of field names explicitly present in the payload
        (empty when the plan declares no ``Input()`` binding).
        """
        if plan.input_binding is None:
            return frozenset()
        if payload is None:
            raise ValueError(
                f"{plan.use_case_type.__qualname__}: payload is required for a UseCase with Input()"
            )
        cmd_type = plan.input_binding.command_type
        if hasattr(cmd_type, "from_payload"):
            command, fields_set = cmd_type.from_payload(payload)
        else:
            raise TypeError(
                f"{plan.use_case_type.__qualname__}: "
                f"command type {cmd_type!r} must implement from_payload()"
            )
        bound[plan.input_binding.name] = command
        self._log_step("Bind Input")
        return cast(frozenset[str], fields_set)

    async def _execute_loads(
        self,
        plan: ExecutionPlan,
        compilable: Compilable,
        bound: dict[str, Any],
        dependencies: dict[type[Any], Any] | None,
        load_overrides: dict[type[Any], Any] | None,
    ) -> None:
        """Resolve each ``Load``/``LoadById`` step into ``bound``."""
        for ls in plan.load_steps:
            bound[ls.name] = await self._resolve_load(
                ls, plan, compilable, bound, dependencies, load_overrides
            )
            self._log_step(f"Load {ls.entity_type.__name__}")

    async def _execute_exists(
        self,
        plan: ExecutionPlan,
        compilable: Compilable,
        bound: dict[str, Any],
        dependencies: dict[type[Any], Any] | None,
    ) -> None:
        """Resolve each ``Exists`` step into ``bound`` as a boolean."""
        for es in plan.exists_steps:
            bound[es.name] = await self._resolve_exists(es, plan, compilable, bound, dependencies)
            self._log_step(f"Exists {es.entity_type.__name__}")

    def _apply_computes(
        self,
        plan: ExecutionPlan,
        bound: dict[str, Any],
        fields_set: frozenset[str],
    ) -> None:
        """Thread the command through each compute step, rebinding the result."""
        if plan.input_binding is None or not plan.compute_steps:
            return
        command = bound[plan.input_binding.name]
        for cs in plan.compute_steps:
            compute_fn = cast(Any, cs.fn)
            label = getattr(cs.fn, "__name__", type(cs.fn).__name__)
            if cs.accepts_context:
                command = compute_fn(command, fields_set, bound)
            else:
                command = compute_fn(command, fields_set)
            self._log_step(f"Compute {label}")
        bound[plan.input_binding.name] = command

    def _check_rules(
        self,
        plan: ExecutionPlan,
        bound: dict[str, Any],
        fields_set: frozenset[str],
    ) -> None:
        """Run every rule step, collecting violations before raising.

        All rules are evaluated (none short-circuits) so callers receive the
        full violation set in one ``RuleViolations``.
        """
        if plan.input_binding is None or not plan.rule_steps:
            return
        command = bound[plan.input_binding.name]
        violations: list[RuleViolation] = []
        for rs in plan.rule_steps:
            label = getattr(rs.fn, "__name__", type(rs.fn).__name__)
            try:
                if rs.accepts_context:
                    rs.fn(command, fields_set, bound)
                else:
                    rs.fn(command, fields_set)
                self._log_step(f"Rule {label}")
            except RuleViolation as exc:
                violations.append(exc)
                self._logger.warning(
                    f"[RULE] {label} failed: {exc.field}: {exc.message}",
                    usecase=plan.use_case_type.__qualname__,
                    field=exc.field,
                )
        if violations:
            raise RuleViolations(violations)

    # ------------------------------------------------------------------
    # Load resolution
    # ------------------------------------------------------------------
    async def _resolve_load(
        self,
        step: LoadStep,
        plan: ExecutionPlan,
        compilable: Compilable,
        bound: dict[str, Any],
        dependencies: dict[type[Any], Any] | None,
        load_overrides: dict[type[Any], Any] | None,
    ) -> Any:
        """Fetch the entity for one load step (overrides → dependencies → resolver)."""
        del compilable
        if load_overrides and step.entity_type in load_overrides:
            return load_overrides[step.entity_type]
        if dependencies is None:
            repo = self._resolve_repo(step.entity_type)
            # NOTE(review): if a repo_resolver is set but returns None, repo
            # stays None and _get_entity fails with AttributeError rather
            # than this explicit RuntimeError — confirm intended.
            if repo is None and self._repo_resolver is None:
                raise RuntimeError(
                    f"No dependencies provided for LoadById({step.entity_type.__name__})"
                )
        else:
            repo = dependencies.get(step.entity_type)
            if repo is None:
                repo = self._resolve_repo(step.entity_type)
            if repo is None:
                raise RuntimeError(
                    f"No repository registered for entity type '{step.entity_type.__name__}'"
                )
        value = self._resolve_lookup_value(step.source_kind, step.source_name, plan, bound)
        entity = await self._get_entity(repo, step, value)
        if entity is None:
            if step.on_missing is OnMissing.RETURN_NONE:
                return None
            if step.on_missing is OnMissing.RETURN_FALSE:
                return False
            raise NotFound(step.entity_type.__name__, id=value)
        return entity

    async def _resolve_exists(
        self,
        step: ExistsStep,
        plan: ExecutionPlan,
        compilable: Compilable,
        bound: dict[str, Any],
        dependencies: dict[type[Any], Any] | None,
    ) -> bool:
        """Evaluate one ``Exists`` step against the repository's ``exists_by``."""
        del compilable
        if dependencies is None:
            repo = self._resolve_repo(step.entity_type)
            if repo is None and self._repo_resolver is None:
                raise RuntimeError(
                    f"No dependencies provided for Exists({step.entity_type.__name__})"
                )
        else:
            repo = dependencies.get(step.entity_type)
            if repo is None:
                repo = self._resolve_repo(step.entity_type)
            if repo is None:
                raise RuntimeError(
                    f"No repository registered for entity type '{step.entity_type.__name__}'"
                )
        value = self._resolve_lookup_value(step.source_kind, step.source_name, plan, bound)
        exists_by = getattr(repo, "exists_by", None)
        if not callable(exists_by):
            raise RuntimeError(
                "Repository for "
                f"'{step.entity_type.__name__}' must implement exists_by(field, value)"
            )
        exists_by_async = cast(Callable[[str, Any], Awaitable[bool]], exists_by)
        found = bool(await exists_by_async(step.against, value))
        if found:
            return True
        if step.on_missing is OnMissing.RAISE:
            raise NotFound(step.entity_type.__name__, id=value)
        return False

    async def _get_entity(self, repo: Any, step: LoadStep, value: Any) -> Any | None:
        """Dispatch entity retrieval to ``get_by_id`` or ``get_by`` per the lookup kind."""
        if step.lookup_kind is LookupKind.BY_ID:
            return await repo.get_by_id(value, profile=step.profile)
        get_by = getattr(repo, "get_by", None)
        if not callable(get_by):
            raise RuntimeError(
                f"Repository for '{step.entity_type.__name__}' must implement get_by(field, value)"
            )
        get_by_async = cast(Callable[..., Awaitable[Any | None]], get_by)
        return await get_by_async(step.against, value, profile=step.profile)

    @staticmethod
    def _resolve_lookup_value(
        source_kind: SourceKind,
        source_name: str,
        plan: ExecutionPlan,
        bound: dict[str, Any],
    ) -> Any:
        """Read the lookup value from a bound param or from a command field."""
        if source_kind is SourceKind.PARAM:
            return bound[source_name]
        if plan.input_binding is None:
            raise RuntimeError("from_command source requires Input() binding")
        command = bound[plan.input_binding.name]
        if not hasattr(command, source_name):
            raise RuntimeError(f"Command '{type(command).__name__}' has no field '{source_name}'")
        return getattr(command, source_name)

    def _resolve_repo(self, entity_type: type[Any]) -> Any | None:
        """Resolve a repository via the optional ``repo_resolver`` (None if unset)."""
        if self._repo_resolver is None:
            return None
        return self._repo_resolver(entity_type)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _emit(self, event: RuntimeEvent) -> None:
        """Forward an event to the metrics adapter when one is configured."""
        if self._metrics is not None:
            self._metrics.on_event(event)

    def _log_step(self, label: str) -> None:
        """Emit a per-stage ``[STEP]`` log line in debug-execution mode only."""
        if self._debug:
            self._logger.info(f"[STEP] {label}")