Source code for loom.testing.repository_harness
from __future__ import annotations
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any
import msgspec
from loom.core.repository.sqlalchemy import get_repository_registration
from loom.core.repository.sqlalchemy.repository import RepositorySQLAlchemy
from loom.core.repository.sqlalchemy.session_manager import SessionManager
ScenarioDict = dict[str, list[msgspec.Struct]]
@dataclass(slots=True)
class EntityHarness:
repository: Any
service: Any | None = None
[docs]
class RepositoryIntegrationHarness:
"""Test harness for integration tests over repository implementations."""
def __init__(
self,
*,
session_manager: SessionManager,
entities: Mapping[str, EntityHarness],
load_order: tuple[str, ...] = (),
) -> None:
self.session_manager = session_manager
self._entities: dict[str, EntityHarness] = dict(entities)
self._load_order = load_order
for entity_name, context in self._entities.items():
setattr(self, entity_name, context)
def __getattr__(self, name: str) -> EntityHarness:
entity = self._entities.get(name)
if entity is None:
raise AttributeError(name)
return entity
async def load(self, scenario: ScenarioDict) -> None:
execution_order: list[str] = [name for name in self._load_order if name in scenario]
execution_order.extend(name for name in scenario if name not in execution_order)
for entity_name in execution_order:
entity_context = self._entities.get(entity_name)
if entity_context is None:
raise ValueError(
f"Entity '{entity_name}' is not registered in RepositoryIntegrationHarness"
)
for contract in scenario.get(entity_name, []):
await entity_context.repository.create(contract)
[docs]
def build_repository_harness(
*,
session_manager: SessionManager,
models: Mapping[str, type[Any]],
repositories: Mapping[str, Any] | None = None,
load_order: tuple[str, ...] = (),
) -> RepositoryIntegrationHarness:
"""Build a repository integration harness with generic/default repositories.
Args:
session_manager: Shared SQLAlchemy session manager.
models: Mapping ``entity_key -> model class``.
repositories: Optional mapping ``entity_key -> repository instance``.
load_order: Optional seed load order.
Returns:
Configured repository integration harness.
"""
entities: dict[str, EntityHarness] = {}
for entity_name, model in models.items():
repository = None
if repositories is not None:
repository = repositories.get(entity_name)
if repository is None:
registration = get_repository_registration(model)
repository_type = (
registration.repository_type if registration is not None else RepositorySQLAlchemy
)
repository = repository_type(session_manager=session_manager, model=model)
entities[entity_name] = EntityHarness(repository=repository)
return RepositoryIntegrationHarness(
session_manager=session_manager,
entities=entities,
load_order=load_order,
)