"""Generic in-memory repository for use in tests.
Provides :class:`InMemoryRepository` — a fully functional, zero-dependency
fake repository that stores :class:`msgspec.Struct` entities in a plain dict.
Designed to be used with :class:`~loom.testing.http_harness.HttpTestHarness`
or standalone in unit tests, eliminating the need to write a ``FakeXxxRepo``
class for every domain model.
Usage::
repo = InMemoryRepository(Product, id_field="id")
repo.seed(Product(id=1, name="Widget"))
product = await repo.get_by_id(1) # → Product(id=1, name="Widget")
missing = await repo.get_by_id(99) # → None
created = await repo.create(CreateProductCmd(name="Gadget"))
# → Product(id=2, name="Gadget")
"""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import replace
from typing import Any, Generic, TypeVar
import msgspec
from loom.core.model.introspection import get_projections
from loom.core.projection.loaders import (
CountLoader,
ExistsLoader,
JoinFieldsLoader,
find_relation_name_for_loader,
make_memory_loader,
)
from loom.core.projection.runtime import (
ProjectionPlan,
build_projection_plan,
execute_projection_plan,
)
# Entity type handled by a repository instance; bound to ``msgspec.Struct``
# so only Struct subclasses are accepted by type checkers.
T = TypeVar("T", bound=msgspec.Struct)
class InMemoryRepository(Generic[T]):
    """Generic in-memory repository for testing any ``msgspec.Struct`` model.

    Stores entities in a plain dict keyed by their id field value. Provides
    the standard repository surface (``get_by_id``, ``create``, ``update``,
    ``delete``, ``list_paginated``) without any database dependency.

    The ``create`` method derives entity fields from the command automatically
    when no ``creator`` callable is provided: fields present on the command
    (attributes, or keys when the command is a ``dict``) are copied to the
    entity, and the id field is assigned from an internal auto-increment
    counter.

    Args:
        entity_type: The ``msgspec.Struct`` subclass this repository stores.
        id_field: Name of the identity field on the entity. Defaults to
            ``"id"``.
        creator: Optional ``(cmd, next_id) -> T`` callable used by
            :meth:`create`. When provided, the automatic field-mapping is
            bypassed entirely.

    Example::

        repo = InMemoryRepository(Product, id_field="id")
        repo.seed(Product(id=1, name="Widget"), Product(id=2, name="Gadget"))

        harness = HttpTestHarness()
        harness.inject_repo(Product, repo)
        client = harness.build_app(interfaces=[ProductRestInterface])
    """

    def __init__(
        self,
        entity_type: type[T],
        *,
        id_field: str = "id",
        creator: Callable[[Any, int], T] | None = None,
    ) -> None:
        self._entity_type = entity_type
        self._id_field = id_field
        self._creator = creator
        # Entities keyed by id value; dict ordering provides insertion order.
        self._store: dict[Any, T] = {}
        # Next value handed to ``creator`` / used by ``_auto_create``.
        self._next_id: int = 1
        # Compiled plan per profile; ``None`` caches "no projection applies".
        self._projection_plans: dict[str, ProjectionPlan | None] = {}

    def seed(self, *entities: T) -> None:
        """Pre-load entities into the store.

        The internal id counter is advanced past the highest integer id seen
        so that subsequent :meth:`create` calls do not collide.

        Args:
            *entities: Entity instances to load. An entity whose id is already
                stored replaces the previous one.

        Example::

            repo.seed(Product(id=1, name="A"), Product(id=2, name="B"))
        """
        for entity in entities:
            id_val = getattr(entity, self._id_field)
            self._store[id_val] = entity
            # Non-integer ids (e.g. strings) never move the counter.
            if isinstance(id_val, int) and id_val >= self._next_id:
                self._next_id = id_val + 1

    async def get_by_id(self, obj_id: Any, profile: str = "default") -> T | None:
        """Return the entity with ``obj_id``, or ``None`` if not found.

        Args:
            obj_id: The identity value to look up.
            profile: Projection profile name. Model projections that list this
                profile are computed and merged into the returned entity.

        Returns:
            Entity instance (with projection values applied when the profile
            has any), or ``None`` if no entity has that id.
        """
        entity = self._store.get(obj_id)
        if entity is None:
            return None
        return await self._with_projections(entity, profile=profile)

    async def create(self, cmd: Any) -> T:
        """Create and store a new entity from ``cmd``.

        If a ``creator`` callable was provided at construction it is called as
        ``creator(cmd, next_id)``. Otherwise, command attributes (or dict
        keys) whose names match entity fields are copied automatically, and
        the id field is set from the internal auto-increment counter.

        Args:
            cmd: Command or payload object carrying the new entity's data.

        Returns:
            The created and stored entity.
        """
        if self._creator is not None:
            entity = self._creator(cmd, self._next_id)
            # Advance only after the creator succeeded.
            self._next_id += 1
        else:
            entity = self._auto_create(cmd)
        id_val = getattr(entity, self._id_field)
        self._store[id_val] = entity
        return entity

    async def update(self, obj_id: Any, data: Any) -> T | None:
        """Update the entity at ``obj_id`` with fields from ``data``.

        Only fields present on both ``data`` and the entity are overwritten,
        and only when the incoming value is neither ``None`` nor
        ``msgspec.UNSET``; the id field is never changed.

        Args:
            obj_id: Identity value of the entity to update.
            data: ``dict``, ``msgspec.Struct`` instance, or any other object
                exposing the new values as attributes.

        Returns:
            The updated entity as stored (projections are not applied), or
            ``None`` if no entity has that id.
        """
        entity = self._store.get(obj_id)
        if entity is None:
            return None

        current = msgspec.structs.asdict(entity)
        incoming = self._incoming_values(data, tuple(current))
        for key, value in incoming.items():
            # ``None`` and ``UNSET`` both mean "field not supplied".
            if value is None or value is msgspec.UNSET:
                continue
            if key in current and key != self._id_field:
                current[key] = value

        # Rebuild through ``convert`` so the merged values are validated.
        updated = msgspec.convert(current, self._entity_type)
        self._store[obj_id] = updated
        return updated

    async def delete(self, obj_id: Any) -> bool:
        """Delete the entity at ``obj_id``.

        Args:
            obj_id: Identity value to delete.

        Returns:
            ``True`` if the entity existed and was removed, ``False`` if not
            found.
        """
        try:
            del self._store[obj_id]
        except KeyError:
            return False
        return True

    async def list_paginated(self, *args: Any, **kwargs: Any) -> list[T]:
        """Return all stored entities.

        No pagination, filtering or sorting is performed.

        Args:
            *args: Ignored; present for repository interface compatibility.
            **kwargs: Only ``profile`` (default ``"default"``) is honoured; it
                selects the projections applied to every returned entity.
                Other keyword arguments are ignored.

        Returns:
            List of all entities in insertion order.
        """
        profile = kwargs.get("profile", "default")
        entities = list(self._store.values())
        if not entities:
            return []

        plan = self._projection_plan_for_profile(profile)
        if plan is None:
            return entities

        projection_values = await execute_projection_plan(
            plan,
            objs=entities,
            id_attr=self._id_field,
            backend_context=None,
        )
        # Projection values are looked up by the entity's position in ``objs``.
        return [
            self._apply_projection_values(entity, projection_values.get(index))
            for index, entity in enumerate(entities)
        ]

    def _auto_create(self, cmd: Any) -> T:
        """Derive an entity from ``cmd`` and assign the next auto-increment id.

        Values are read from dict keys when ``cmd`` is a ``dict`` and from
        attributes otherwise. ``msgspec.UNSET`` values are skipped so the
        entity's own defaults apply. The id counter advances only once the
        entity has been built successfully.
        """
        from_dict = isinstance(cmd, dict)
        data: dict[str, Any] = {}
        for field in msgspec.structs.fields(self._entity_type):
            name = field.name
            if name == self._id_field:
                data[name] = self._next_id
                continue
            if from_dict:
                if name not in cmd:
                    continue
                value = cmd[name]
            elif hasattr(cmd, name):
                value = getattr(cmd, name)
            else:
                continue
            if value is not msgspec.UNSET:
                data[name] = value

        entity = msgspec.convert(data, self._entity_type)
        self._next_id += 1
        return entity

    @staticmethod
    def _incoming_values(source: Any, names: tuple[str, ...]) -> dict[str, Any]:
        """Normalise an update payload into a ``{field: value}`` mapping.

        Dicts are returned as-is, Struct instances are read through their
        declared fields, and any other object is probed for the attributes
        listed in ``names``.
        """
        if isinstance(source, dict):
            return source
        if isinstance(source, msgspec.Struct):
            return {
                field.name: getattr(source, field.name)
                for field in msgspec.structs.fields(source)
            }
        return {name: getattr(source, name) for name in names if hasattr(source, name)}

    def _projection_plan_for_profile(self, profile: str) -> ProjectionPlan | None:
        """Return the cached projection plan for ``profile``, building it once.

        ``None`` is returned (and cached) when no projection of the entity
        type lists ``profile``.
        """
        if profile in self._projection_plans:
            return self._projection_plans[profile]

        model_projections = get_projections(self._entity_type)
        active = {
            name: projection
            for name, projection in model_projections.items()
            if profile in projection.profiles
        }
        if not active:
            self._projection_plans[profile] = None
            return None

        compiled = build_projection_plan(self._resolve_memory_loaders(active))
        self._projection_plans[profile] = compiled
        return compiled

    def _resolve_memory_loaders(self, projections: dict[str, Any]) -> dict[str, Any]:
        """Replace public loader descriptors with memory-path loaders.

        Raises:
            ValueError: If the relation a ``CountLoader``, ``ExistsLoader`` or
                ``JoinFieldsLoader`` refers to cannot be determined.
        """
        result: dict[str, Any] = {}
        for name, proj in projections.items():
            loader = proj.loader
            if not isinstance(loader, (CountLoader, ExistsLoader, JoinFieldsLoader)):
                # Any other loader is passed through untouched.
                result[name] = proj
                continue
            rel_name = find_relation_name_for_loader(loader, self._entity_type)
            if rel_name is None:
                raise ValueError(
                    f"InMemoryRepository: cannot resolve '{name}' loader "
                    f"({type(loader).__name__}(model={loader.model.__name__})). "
                    "Provide 'via' on the loader or annotate the relation with the target type."
                )
            result[name] = replace(proj, loader=make_memory_loader(loader, rel_name))
        return result

    async def _with_projections(self, entity: T, *, profile: str) -> T:
        """Apply the projections of ``profile`` to a single entity."""
        plan = self._projection_plan_for_profile(profile)
        if plan is None:
            return entity
        result = await execute_projection_plan(
            plan,
            objs=[entity],
            id_attr=self._id_field,
            backend_context=None,
        )
        # Single-element batch: the entity's values are at position 0.
        return self._apply_projection_values(entity, result.get(0))

    def _apply_projection_values(self, entity: T, values: dict[str, Any] | None) -> T:
        """Return a copy of ``entity`` with ``values`` merged over its fields.

        The original entity is returned unchanged when ``values`` is empty or
        ``None``; the stored instance is never mutated.
        """
        if not values:
            return entity
        data = msgspec.structs.asdict(entity)
        data.update(values)
        return self._entity_type(**data)