# Source code for loom.core.repository.sqlalchemy.transactional

from __future__ import annotations

import contextvars
from collections.abc import Awaitable, Callable
from functools import wraps
from typing import Any, Concatenate, ParamSpec, Protocol, TypeVar, cast, runtime_checkable

from sqlalchemy.ext.asyncio import AsyncSession

from loom.core.logger import get_logger
from loom.core.repository.mutation import MutationEvent


@runtime_checkable
class SupportsPostCommit(Protocol):
    """Structural interface for objects notified after a transaction commits.

    Any object exposing an ``on_transaction_committed`` coroutine method
    satisfies this protocol. Because the protocol is ``runtime_checkable``,
    ``isinstance`` checks only verify that the attribute is present; they do
    not inspect its signature.
    """

    async def on_transaction_committed(
        self,
        events: tuple[MutationEvent, ...],
    ) -> None:
        """React to the mutation events recorded by a committed transaction."""
T = TypeVar("T")
P = ParamSpec("P")

# Session bound to the current async context by ``@transactional`` or by a
# unit of work; ``None`` when no transactional scope is active.
_active_session: contextvars.ContextVar[AsyncSession | None] = contextvars.ContextVar(
    "_active_session",
    default=None,
)

# Mutation events collected for the current transaction scope. ``None`` means
# no scope is active, in which case ``record_mutation`` is a no-op.
_mutations: contextvars.ContextVar[list[MutationEvent] | None] = contextvars.ContextVar(
    "_mutations",
    default=None,
)

_log = get_logger(__name__).bind(component="transactional")


def get_active_session() -> AsyncSession | None:
    """Look up the ``AsyncSession`` bound to the current async context.

    Returns:
        The session installed by an enclosing ``@transactional`` scope or by a
        :class:`~loom.core.repository.sqlalchemy.uow.SQLAlchemyUnitOfWork`
        managed by :class:`~loom.core.engine.executor.RuntimeExecutor`, or
        ``None`` when no session is bound.
    """
    current = _active_session.get()
    return current


def set_active_session(
    session: AsyncSession,
) -> contextvars.Token[AsyncSession | None]:
    """Install ``session`` as the active session for the current async context.

    :class:`~loom.core.repository.sqlalchemy.uow.SQLAlchemyUnitOfWork` calls
    this so that :func:`get_active_session` yields the UoW session. As a
    result, repository ``_session_scope`` and :func:`transactional` both join
    the same transaction.

    Args:
        session: The ``AsyncSession`` to install.

    Returns:
        A token to hand back to :func:`reset_active_session` on exit.
    """
    token = _active_session.set(session)
    return token


def reset_active_session(token: contextvars.Token[AsyncSession | None]) -> None:
    """Undo a prior :func:`set_active_session` call.

    Args:
        token: The token that :func:`set_active_session` returned.
    """
    _active_session.reset(token)


MutationsToken = contextvars.Token[list[MutationEvent] | None]


def set_active_mutations() -> tuple[list[MutationEvent], MutationsToken]:
    """Start collecting mutation events in a new, empty list for this context.

    Returns:
        A ``(mutations_list, reset_token)`` pair. The list accumulates
        :class:`~loom.core.repository.mutation.MutationEvent` objects. The
        token must be passed to :func:`reset_active_mutations` on exit.
    """
    fresh: list[MutationEvent] = []
    return fresh, _mutations.set(fresh)


def reset_active_mutations(
    token: MutationsToken,
) -> None:
    """Undo a prior :func:`set_active_mutations` call.

    Args:
        token: The token that :func:`set_active_mutations` returned.
    """
    _mutations.reset(token)


def record_mutation(event: MutationEvent) -> None:
    """Queue ``event`` on the current transaction's pending mutations.

    If no ``@transactional`` scope is active, the event is dropped silently.

    Args:
        event: The mutation event to queue.
    """
    pending = _mutations.get()
    if pending is not None:
        pending.append(event)


def get_pending_mutations() -> tuple[MutationEvent, ...]:
    """Snapshot the mutation events recorded in the current transaction scope.

    Returns:
        A tuple of ``MutationEvent`` instances. It is empty when nothing was
        recorded or no scope is active.
    """
    pending = _mutations.get()
    return tuple(pending) if pending else ()
def transactional(
    method: Callable[Concatenate[Any, P], Awaitable[T]],
) -> Callable[Concatenate[Any, P], Awaitable[T]]:
    """Create a single transaction boundary for service/orchestrator use cases.

    The decorated coroutine method runs inside one ``AsyncSession`` obtained
    from ``self.session_manager.session()``. For the duration of the call,
    that session and a fresh mutation list are bound to the current context.
    This lets repositories and nested ``@transactional`` methods share the
    same transaction.

    If the method and the commit both succeed, post-commit hooks receive the
    recorded mutation events. The hooks are ``self`` and every
    :class:`SupportsPostCommit` instance attribute of ``self``. If the method
    or the commit raises, the session is rolled back and the exception is
    re-raised.

    An exception raised by a post-commit hook propagates to the caller. The
    transaction stays committed: no rollback is attempted and no rollback is
    logged.

    If a session is already active in the current context, the method joins
    it and is awaited directly. That session may come from an outer
    ``@transactional`` scope or a unit of work. In that case no commit,
    rollback or post-commit notification happens here.

    Args:
        method: The async method to wrap. Its first positional argument is
            the owning instance.

    Returns:
        The wrapped coroutine function.

    Raises:
        TypeError: When the wrapped method is called on a
            ``RepositorySQLAlchemy`` instance. Also raised when the owner has
            no ``session_manager`` exposing a callable ``session``.
    """

    @wraps(method)
    async def wrapper(self: Any, *args: Any, **kwargs: Any) -> T:
        # Deferred import; presumably avoids an import cycle with the
        # repository module (which uses this module) - confirm.
        from loom.core.repository.sqlalchemy.repository import RepositorySQLAlchemy

        if isinstance(self, RepositorySQLAlchemy):
            raise TypeError(
                "@transactional is intended for service/orchestrator boundaries, "
                "not repository methods.",
            )

        owner_name = self.__class__.__name__

        # Join an enclosing transaction instead of opening a nested one; the
        # outer scope owns commit/rollback and post-commit dispatch.
        if get_active_session() is not None:
            _log.debug(
                "TransactionalSessionReused",
                owner=owner_name,
                method=method.__name__,
            )
            return await method(self, *args, **kwargs)

        session_manager = getattr(self, "session_manager", None)
        if session_manager is None or not callable(getattr(session_manager, "session", None)):
            raise TypeError(
                f"{self.__class__.__name__} must have a 'session_manager' attribute "
                f"with a .session() context manager to use @transactional.",
            )

        async with session_manager.session() as session:
            session_token = _active_session.set(session)
            mutations_token = _mutations.set([])
            try:
                # Only the method body and the commit can cause a rollback.
                try:
                    result = await method(self, *args, **kwargs)
                    await session.commit()
                except Exception:
                    await session.rollback()
                    _log.exception(
                        "TransactionRolledBack",
                        owner=owner_name,
                        method=method.__name__,
                    )
                    raise

                # The commit has succeeded. Hooks run outside the rollback
                # handler, so a failing hook cannot trigger a rollback of
                # committed data or a misleading "TransactionRolledBack" log.
                pending = get_pending_mutations()
                _log.info(
                    "TransactionCommitted",
                    owner=owner_name,
                    method=method.__name__,
                    mutation_count=len(pending),
                )
                if isinstance(self, SupportsPostCommit):
                    await self.on_transaction_committed(pending)
                for dependency in _iter_post_commit_dependencies(self):
                    await dependency.on_transaction_committed(pending)
                return result
            finally:
                # Always restore the context, even when the method, the
                # commit or a hook raises.
                _active_session.reset(session_token)
                _mutations.reset(mutations_token)

    return cast(Callable[Concatenate[Any, P], Awaitable[T]], wrapper)
def _iter_post_commit_dependencies(owner: Any) -> list[SupportsPostCommit]:
    """Collect the ``SupportsPostCommit`` objects held as attributes of ``owner``.

    Only the instance ``__dict__`` of ``owner`` is scanned. ``owner`` itself
    is excluded because :func:`transactional` notifies it separately.

    Each object is returned at most once, in first-seen order, even when
    several attributes reference it. This prevents its hook from firing twice
    for one commit.

    Owners without a ``__dict__`` (e.g. fully slotted classes) yield an empty
    list instead of raising. This matters because the helper runs after the
    commit has already happened.

    Args:
        owner: The object whose instance attributes are inspected.

    Returns:
        The distinct post-commit listeners found among the attributes.
    """
    attributes = getattr(owner, "__dict__", None)
    if not attributes:
        return []

    dependencies: list[SupportsPostCommit] = []
    # Identity-based dedupe: listeners need not be hashable.
    seen_ids: set[int] = {id(owner)}
    for value in attributes.values():
        if id(value) in seen_ids:
            continue
        if isinstance(value, SupportsPostCommit):
            seen_ids.add(id(value))
            dependencies.append(value)
    return dependencies