Source code for loom.core.repository.sqlalchemy.session_manager

from __future__ import annotations

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any

from sqlalchemy import event
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)

from loom.core.logger import get_logger
from loom.core.tracing import get_trace_id


class SessionManager:
    """Owns one async SQLAlchemy engine and hands out sessions bound to it.

    The manager builds the engine and an ``async_sessionmaker`` once, at
    construction time.  Sessions are obtained through :meth:`session`, and
    the engine's pooled connections are released with :meth:`dispose`.
    """

    def __init__(
        self,
        url: str,
        *,
        echo: bool = False,
        pool_pre_ping: bool = True,
        pool_size: int | None = 10,
        max_overflow: int | None = 20,
        pool_timeout: int | None = 30,
        pool_recycle: int | None = 1800,
        connect_args: dict[str, object] | None = None,
        inject_trace_id: bool = True,
        **engine_kwargs: object,
    ) -> None:
        """Build the engine, the session factory and the optional trace hook.

        Args:
            url: Database connection URL (e.g. ``"postgresql+asyncpg://..."``).
            echo: If ``True``, log all generated SQL statements.
            pool_pre_ping: Test connections before checkout to detect stale
                ones.
            pool_size: Number of permanent connections in the pool.  ``None``
                leaves the option out of the engine configuration.
            max_overflow: Maximum additional connections beyond
                ``pool_size``.  ``None`` leaves the option out.
            pool_timeout: Seconds to wait before raising on pool exhaustion.
                ``None`` leaves the option out.
            pool_recycle: Seconds after which a connection is recycled.
                ``None`` leaves the option out.
            connect_args: Extra keyword arguments passed to the DBAPI
                ``connect()`` call.  ``None`` leaves the option out.
            inject_trace_id: When ``True``, registers the listener that
                prefixes SQL statements with a ``/* trace_id=<id> */``
                comment whenever a trace identifier is active.  Defaults to
                ``True``.
            **engine_kwargs: Additional keyword arguments forwarded to
                ``create_async_engine``.
        """
        # Settings that are only forwarded when the caller did not pass None.
        optional_settings: dict[str, object | None] = {
            "pool_size": pool_size,
            "max_overflow": max_overflow,
            "pool_timeout": pool_timeout,
            "pool_recycle": pool_recycle,
            "connect_args": connect_args,
        }

        options: dict[str, object] = {"echo": echo, "pool_pre_ping": pool_pre_ping}
        options.update(engine_kwargs)
        options.update(
            {key: value for key, value in optional_settings.items() if value is not None}
        )

        self._log = get_logger(__name__).bind(module="session_manager")
        self._engine = create_async_engine(url, **options)
        # expire_on_commit=False keeps loaded attributes usable after commit.
        self._session_factory = async_sessionmaker(
            bind=self._engine,
            class_=AsyncSession,
            expire_on_commit=False,
        )

        if inject_trace_id:
            _register_trace_id_listener(self._engine)

        engine_url = self._engine.url
        self._log.info(
            "SessionManagerInitialized",
            backend=engine_url.get_backend_name(),
            driver=engine_url.get_driver_name(),
            inject_trace_id=inject_trace_id,
        )

    @asynccontextmanager
    async def session(self) -> AsyncIterator[AsyncSession]:
        """Open a session for the duration of an ``async with`` block.

        The session is always closed when the block exits, whether normally
        or through an exception.  No commit is issued here; transaction
        handling is left to the caller.

        Yields:
            An ``AsyncSession`` created by the managed session factory.
        """
        self._log.debug("SessionScopeOpened")
        scoped = self._session_factory()
        try:
            yield scoped
        finally:
            await scoped.close()
            self._log.debug("SessionScopeClosed")

    async def dispose(self) -> None:
        """Dispose of the engine, releasing its pooled connections."""
        await self._engine.dispose()
        self._log.info("SessionManagerDisposed")

    @property
    def engine(self) -> AsyncEngine:
        """The async SQLAlchemy engine created by this manager."""
        return self._engine

    @property
    def session_factory(self) -> async_sessionmaker[AsyncSession]:
        """The ``async_sessionmaker`` bound to the managed engine."""
        return self._session_factory
# Characters a trace id may contribute to the injected SQL comment.  The set
# deliberately excludes "*" (so neither "*/" nor a nested "/*" can be formed)
# as well as whitespace and control characters.
_TRACE_ID_SAFE_CHARS = frozenset(
    "abcdefghijklmnopqrstuvwxyz"
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "0123456789"
    "-_.:+/="
)


def _register_trace_id_listener(async_engine: AsyncEngine) -> None:
    """Register a ``before_cursor_execute`` listener that injects trace comments.

    The listener is attached to the underlying sync engine so it fires for
    every SQL statement executed through the async engine.  When a trace-id
    is active in the current async context, the statement is prefixed with
    ``/* trace_id=<id> */``.

    The trace id is filtered through an allow-list before it is embedded:
    any character outside ASCII letters, digits and ``-_.:+/=`` is dropped.
    Without this, an id containing ``*/`` would close the comment early and
    the remainder of the id would be executed as SQL.  Ids made only of
    allowed characters are emitted unchanged; if nothing remains after
    filtering, the statement is left untouched.

    Args:
        async_engine: The :class:`~sqlalchemy.ext.asyncio.AsyncEngine` whose
            sync engine will receive the listener.
    """

    # retval=True: the (statement, parameters) tuple returned by the listener
    # replaces the values SQLAlchemy hands to the cursor.
    @event.listens_for(async_engine.sync_engine, "before_cursor_execute", retval=True)
    def _inject(
        conn: Any,
        cursor: Any,
        statement: str,
        parameters: Any,
        context: Any,
        executemany: bool,
    ) -> tuple[str, Any]:
        tid = get_trace_id()
        if not tid:
            # No active trace: pass the statement through unchanged.
            return statement, parameters

        safe_tid = "".join(ch for ch in str(tid) if ch in _TRACE_ID_SAFE_CHARS)
        if safe_tid:
            statement = f"/* trace_id={safe_tid} */ {statement}"
        return statement, parameters