Source code for loom.core.repository.sqlalchemy.repository
from __future__ import annotations
from collections.abc import AsyncIterator, Awaitable, Callable
from contextlib import asynccontextmanager
from functools import wraps
from typing import Any, Generic, TypeVar, cast
from sqlalchemy.ext.asyncio import AsyncSession
from loom.core.logger import get_logger
from loom.core.repository.abc import IdT, OutputT
from loom.core.repository.mutation import MutationEvent
from loom.core.repository.sqlalchemy.mixins import (
SQLAlchemyCreateMixin,
SQLAlchemyDeleteMixin,
SQLAlchemyReadMixin,
SQLAlchemyUpdateMixin,
)
from loom.core.repository.sqlalchemy.session_manager import SessionManager
from loom.core.repository.sqlalchemy.transactional import get_active_session
R = TypeVar("R")
[docs]
def with_session_scope(
method: Callable[..., Awaitable[R]],
) -> Callable[..., Awaitable[R]]:
"""Inject repository-managed session into custom repository methods."""
@wraps(method)
async def wrapper(
self: RepositorySQLAlchemy[Any, Any],
*args: Any,
session: AsyncSession | None = None,
**kwargs: Any,
) -> R:
async with self._session_scope(session) as scoped_session:
return await method(self, scoped_session, *args, **kwargs)
return cast(Callable[..., Awaitable[R]], wrapper)
[docs]
class RepositorySQLAlchemy(
SQLAlchemyCreateMixin[OutputT, IdT],
SQLAlchemyReadMixin[OutputT, IdT],
SQLAlchemyUpdateMixin[OutputT, IdT],
SQLAlchemyDeleteMixin[OutputT, IdT],
Generic[OutputT, IdT],
):
"""Base SQLAlchemy repository with context-aware session management.
Pass ``model`` (a Struct-based ``BaseModel``) to ``__init__``; the
repository uses the compiled SA class for queries and returns the
Struct directly.
"""
def __init__(
self,
session_manager: SessionManager,
model: type,
) -> None:
self.session_manager = session_manager
self.model = model
self._init_struct_model()
self.log = get_logger(__name__).bind(repository=self.__class__.__name__)
[docs]
async def on_transaction_committed(self, events: tuple[MutationEvent, ...]) -> None:
"""Handle post-commit mutation events (cache invalidation hook)."""
self.log.debug("RepositoryTransactionCommitted", mutation_count=len(events))
@asynccontextmanager
async def _session_scope(
self, session: AsyncSession | None = None
) -> AsyncIterator[AsyncSession]:
"""Reuse active transaction session or create a scoped one."""
if session is not None:
yield session
return
context_session = get_active_session()
if context_session is not None:
yield context_session
return
async with self.session_manager.session() as new_session:
try:
yield new_session
await new_session.commit()
except Exception:
await new_session.rollback()
self.log.exception("RepositorySessionRollback")
raise